├── ann ├── README.md ├── src │ ├── lib.rs │ ├── optim.rs │ ├── init.rs │ └── minibatch.rs ├── Cargo.toml └── examples │ ├── mlp_mnist_acc_on_test.rs │ └── mlp_on_mnist.rs ├── data-pipe ├── README.md ├── src │ ├── lib.rs │ └── dataloader │ │ ├── mod.rs │ │ └── mnist.rs └── Cargo.toml ├── ops ├── src │ ├── lib.rs │ └── stn.rs └── Cargo.toml ├── macros ├── README.md ├── Cargo.toml └── src │ └── lib.rs ├── tensor-rs ├── src │ ├── tensor_impl │ │ ├── cuda_tensor │ │ │ ├── linalg.rs │ │ │ ├── compare_tensor.rs │ │ │ ├── index_slicing.rs │ │ │ ├── reduction.rs │ │ │ ├── convolution.rs │ │ │ └── cuda_helper.rs │ │ ├── lapack_tensor │ │ │ ├── index_slicing.rs │ │ │ ├── reduction.rs │ │ │ ├── compare_tensor.rs │ │ │ ├── linalg.rs │ │ │ ├── mod.rs │ │ │ ├── lapack_api.rs │ │ │ └── elemwise.rs │ │ ├── mod.rs │ │ └── gen_tensor │ │ │ ├── compare_tensor.rs │ │ │ └── rand.rs │ ├── serde │ │ ├── mod.rs │ │ ├── typed_tensor.rs │ │ ├── gen_tensor.rs │ │ └── tensor.rs │ ├── tensor_trait │ │ ├── mod.rs │ │ ├── compare_tensor.rs │ │ ├── linalg.rs │ │ ├── rand.rs │ │ ├── reduction.rs │ │ ├── convolution.rs │ │ ├── elemwise.rs │ │ └── index_slicing.rs │ ├── lib.rs │ └── quaternion.rs ├── CHANGELOG.md ├── COPYRIGHT ├── README.md ├── benches │ └── test_mm_benchmark.rs ├── LICENSE-MIT └── Cargo.toml ├── extension-op ├── src │ └── lib.rs └── Cargo.toml ├── .gitignore ├── tensorboard-rs ├── .gitignore ├── environment.yml ├── examples │ ├── stop.jpg │ ├── draw_image.rs │ ├── draw_scalar.rs │ ├── draw_graph.rs │ └── draw_histo.rs ├── README-dev.md ├── COPYRIGHT ├── src │ ├── lib.rs │ ├── record_writer.rs │ ├── summary.rs │ ├── event_file_writer.rs │ ├── masked_crc32c.rs │ └── summary_writer.rs ├── README.md ├── Cargo.toml └── LICENSE-MIT ├── auto-diff ├── src │ ├── collection │ │ ├── mod.rs │ │ └── undirected_graph.rs │ ├── serde │ │ ├── mod.rs │ │ ├── compute_graph.rs │ │ ├── directed_graph.rs │ │ ├── generational_index.rs │ │ ├── var.rs │ │ ├── op.rs │ │ └── var_inner.rs │ ├── op │ │ ├── normalization.rs │ │ ├── local.rs │ │ ├── pooling.rs │ │ ├── vision.rs │ │ ├── comparison.rs │ │ ├── reduction.rs │ │ └── linalg.rs │ ├── err.rs │ ├── lib.rs │ └── optim.rs ├── COPYRIGHT ├── examples │ ├── data │ │ └── download.sh │ ├── linear_regression.rs │ ├── alexnet.rs │ ├── mnist.rs │ ├── logistic_regression.rs │ ├── mlp.rs │ ├── mlp_mnist.rs │ └── cnn_mnist.rs ├── CHANGELOG.md ├── LICENSE-MIT ├── Cargo.toml └── README.md ├── Cargo.toml ├── cargo_publish.sh ├── .github └── workflows │ └── rust.yml ├── benches ├── Cargo.toml └── benches │ ├── tensor_benchmark.rs │ ├── convolution_benchmark.rs │ └── elemwise_benchmark.rs ├── bump_version.sh └── README.md /ann/README.md: -------------------------------------------------------------------------------- 1 | Neural network tools -------------------------------------------------------------------------------- /data-pipe/README.md: -------------------------------------------------------------------------------- 1 | A data pipeline. -------------------------------------------------------------------------------- /ops/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | pub mod stn; 3 | 4 | -------------------------------------------------------------------------------- /macros/README.md: -------------------------------------------------------------------------------- 1 | Macros for auto-diff. 
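For example, the workspace's `ops` crate applies the `add_op_handle` attribute macro when defining a new operator. A sketch mirroring `ops/src/stn.rs` (`MyOp` is a made-up name for illustration):

```rust
use auto_diff::op::{OpTrait, OpHandle};
use auto_diff_macros::add_op_handle;

// The attribute presumably injects the OpHandle boilerplate
// that every op struct needs.
#[add_op_handle]
pub struct MyOp {}
```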
2 | 3 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/linalg.rs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/index_slicing.rs: -------------------------------------------------------------------------------- 1 | // 2 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/reduction.rs: -------------------------------------------------------------------------------- 1 | // 2 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/compare_tensor.rs: -------------------------------------------------------------------------------- 1 | // 2 | 3 | -------------------------------------------------------------------------------- /ann/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod init; 2 | pub mod optim; 3 | pub mod minibatch; 4 | -------------------------------------------------------------------------------- /extension-op/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! More ops for auto-diff. 2 | 3 | pub struct A {} 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | target/ 3 | *~ 4 | */examples/data 5 | logdir 6 | saved_model/ 7 | -------------------------------------------------------------------------------- /data-pipe/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A data loader for machine learning.
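//!
//! A minimal sketch of the intended use, based on the `DataLoader` trait in
//! `dataloader/mod.rs` and the MNIST examples elsewhere in this workspace
//! (the data path is illustrative):
//!
//! ```no_run
//! use auto_diff_data_pipe::dataloader::{mnist::Mnist, DataLoader, DataSlice};
//! use std::path::Path;
//!
//! let mnist = Mnist::load(&Path::new("data/mnist"));
//! // Fetch samples 0..16 of the training slice as (data, label) Vars.
//! let (data, label) = mnist.get_batch(0, 16, Some(DataSlice::Train)).unwrap();
//! ```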
2 | 3 | 4 | pub mod dataloader; 5 | -------------------------------------------------------------------------------- /tensor-rs/src/serde/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod gen_tensor; 2 | pub mod typed_tensor; 3 | pub mod tensor; 4 | -------------------------------------------------------------------------------- /tensorboard-rs/.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | target/ 3 | *~ 4 | examples/data/*.data 5 | logdir 6 | -------------------------------------------------------------------------------- /auto-diff/src/collection/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod directed_graph; 2 | pub mod generational_index; 3 | pub mod undirected_graph; 4 | -------------------------------------------------------------------------------- /tensorboard-rs/environment.yml: -------------------------------------------------------------------------------- 1 | name: tensorboardrs 2 | dependencies: 3 | - numpy 4 | - pandas 5 | - tensorboard 6 | -------------------------------------------------------------------------------- /tensorboard-rs/examples/stop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipehappy1/auto-diff/HEAD/tensorboard-rs/examples/stop.jpg -------------------------------------------------------------------------------- /auto-diff/src/serde/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod compute_graph; 2 | pub mod directed_graph; 3 | pub mod generational_index; 4 | pub mod op; 5 | pub mod var; 6 | pub mod var_inner; 7 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod compare_tensor; 2 | pub mod convolution; 3 | pub mod elemwise; 4 | pub mod index_slicing; 5 | pub mod linalg; 6 | pub mod reduction; 7 | pub mod rand; 8 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod gen_tensor; 2 | #[cfg(feature = "use-cuda")] 3 | pub mod cuda_tensor; 4 | #[cfg(feature = "use-cuda")] 5 | pub mod cuda_helper; 6 | #[cfg(feature = "use-blas-lapack")] 7 | pub mod lapack_tensor; 8 | -------------------------------------------------------------------------------- /tensorboard-rs/README-dev.md: -------------------------------------------------------------------------------- 1 | ### Install tensorboard ### 2 | 3 | ```sh,no_run 4 | conda env create -f environment.yml 5 | ``` 6 | 7 | ### Rust protobuf gen ### 8 | 9 | apt-get install protobuf-compiler 10 | 11 | cargo install protobuf-codegen -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | 3 | members = [ 4 | "auto-diff", 5 | "tensor-rs", 6 | "benches", 7 | "macros", 8 | "extension-op", 9 | "tensorboard-rs", 10 | "ann", 11 | "ops", 12 | "data-pipe", 13 | ] 14 | 15 | exclude = [ 16 | 17 | ] -------------------------------------------------------------------------------- /auto-diff/src/op/normalization.rs: -------------------------------------------------------------------------------- 1 | 
use tensor_rs::tensor::Tensor; 2 | use super::OpTrait; 3 | 4 | // BatchNorm1d 5 | // BatchNorm2d 6 | // BatchNorm3d 7 | // GroupNorm 8 | // SyncBatchNorm 9 | // InstanceNorm1d 10 | // InstanceNorm2d 11 | // InstanceNorm3d 12 | // LayerNorm 13 | // LocalResponseNorm 14 | 15 | 16 | -------------------------------------------------------------------------------- /extension-op/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "extension-op" 3 | version = "0.5.9" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | 10 | auto-diff = { path="../auto-diff", version = "0.5.9"} 11 | auto-diff-macros = { path="../macros", version = "0.5.9" } 12 | -------------------------------------------------------------------------------- /cargo_publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | cd ./tensorboard-rs 5 | cargo publish 6 | 7 | cd ../tensor-rs 8 | cargo publish 9 | 10 | sleep 2 11 | 12 | cd ../macros 13 | cargo publish 14 | 15 | cd ../auto-diff 16 | cargo publish 17 | 18 | sleep 2 19 | 20 | cd ../data-pipe 21 | cargo publish 22 | 23 | sleep 2 24 | 25 | cd ../ann 26 | cargo publish 27 | 28 | 29 | -------------------------------------------------------------------------------- /ann/src/optim.rs: -------------------------------------------------------------------------------- 1 | //use auto_diff::Var; 2 | use auto_diff::optim::Optimizer; 3 | use auto_diff::compute_graph::Net; 4 | use std::rc::Rc; 5 | use std::cell::RefCell; 6 | 7 | pub struct Momentum { 8 | 9 | } 10 | 11 | impl Momentum { 12 | 13 | } 14 | 15 | impl Optimizer for Momentum { 16 | fn step(&mut self, _net: Rc<RefCell<Net>>) { 17 | unimplemented!() 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /tensor-rs/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | 9 | ## [0.3.0] - 2020-05-01 10 | ### Added 11 | - linear regression example works. 12 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/compare_tensor.rs: -------------------------------------------------------------------------------- 1 | pub trait CompareTensor { 2 | type TensorType; 3 | type ElementType; 4 | 5 | fn max_pair(&self, o: &Self::TensorType) -> Self::TensorType; 6 | fn min_pair(&self, o: &Self::TensorType) -> Self::TensorType; 7 | fn all(&self, f: &dyn Fn(Self::ElementType) -> bool) -> bool; 8 | fn any(&self, f: &dyn Fn(Self::ElementType) -> bool) -> bool; 9 | } 10 | -------------------------------------------------------------------------------- /auto-diff/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyrights in the auto-diff project are retained by their contributors. No 2 | copyright assignment is required to contribute to the auto-diff project. 3 | 4 | For full authorship information, see the version control history.
5 | 6 | Except as otherwise noted (below and/or in individual files), auto-diff is 7 | licensed under the MIT license <LICENSE-MIT> 8 | or <https://opensource.org/licenses/MIT>, at your option. 9 | -------------------------------------------------------------------------------- /tensor-rs/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyrights in the auto-diff project are retained by their contributors. No 2 | copyright assignment is required to contribute to the auto-diff project. 3 | 4 | For full authorship information, see the version control history. 5 | 6 | Except as otherwise noted (below and/or in individual files), auto-diff is 7 | licensed under the MIT license <LICENSE-MIT> 8 | or <https://opensource.org/licenses/MIT>, at your option. 9 | -------------------------------------------------------------------------------- /tensorboard-rs/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyrights in the auto-diff project are retained by their contributors. No 2 | copyright assignment is required to contribute to the auto-diff project. 3 | 4 | For full authorship information, see the version control history. 5 | 6 | Except as otherwise noted (below and/or in individual files), auto-diff is 7 | licensed under the MIT license <LICENSE-MIT> 8 | or <https://opensource.org/licenses/MIT>, at your option. 9 | -------------------------------------------------------------------------------- /ann/src/init.rs: -------------------------------------------------------------------------------- 1 | use ::rand::prelude::StdRng; 2 | 3 | use auto_diff::{Var, AutoDiffError}; 4 | use tensor_rs::tensor::Tensor; 5 | 6 | pub fn normal(data: &Tensor, mean: Option<Var>, std: Option<Var>, rng: &mut StdRng) -> Result<(), AutoDiffError>{ 7 | let size = data.size(); 8 | let mean = if let Some(v) = mean {f64::try_from(v)?} else {0.}; 9 | let std = if let Some(v) = std {f64::try_from(v)?} else {1.}; 10 | data.swap(&Var::normal(rng, &size, mean, std).val()); 11 | Ok(()) 12 | } 13 | -------------------------------------------------------------------------------- /ops/src/stn.rs: -------------------------------------------------------------------------------- 1 | use tensor_rs::tensor::Tensor; 2 | use auto_diff::op::{OpTrait, OpHandle}; 3 | use auto_diff_macros::add_op_handle; 4 | 5 | #[add_op_handle] 6 | pub struct AffineGrid { 7 | } 8 | impl AffineGrid { 9 | 10 | } 11 | 12 | #[add_op_handle] 13 | pub struct AffineSample { 14 | } 15 | impl AffineSample { 16 | 17 | } 18 | 19 | 20 | #[cfg(test)] 21 | mod tests { 22 | use super::*; 23 | 24 | #[test] 25 | fn it_works() { 26 | //let demo = H{}; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /tensorboard-rs/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | //! Write data for Tensorboard from Rust. 3 | //! ============================================================= 4 | //! 5 | //! 6 | //! Introduction 7 | //! ------------ 8 | //! 9 | //! Install 10 | //! ------------ 11 | //! 12 | //! Example 13 | //! ------------ 14 | //! 15 | //! License 16 | //! 
------------ 17 | 18 | 19 | 20 | 21 | 22 | pub mod masked_crc32c; 23 | pub mod record_writer; 24 | pub mod event_file_writer; 25 | pub mod summary_writer; 26 | pub mod summary; 27 | -------------------------------------------------------------------------------- /tensorboard-rs/examples/draw_image.rs: -------------------------------------------------------------------------------- 1 | use tensorboard_rs::summary_writer::SummaryWriter; 2 | use image::{open, }; 3 | 4 | pub fn main() { 5 | 6 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 7 | 8 | let stop_image = "./examples/stop.jpg"; 9 | let img = open(stop_image).expect(""); 10 | let img = img.into_rgb8(); 11 | let (width, height) = img.dimensions(); 12 | 13 | 14 | writer.add_image(&"test_image".to_string(), &img.into_raw()[..], &vec![3, width as usize, height as usize][..], 12); 15 | writer.flush(); 16 | } 17 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/compare_tensor.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "use-cuda")] 2 | use crate::tensor_impl::cuda_tensor::CudaTensor; 3 | use crate::tensor_trait::compare_tensor::CompareTensor; 4 | 5 | 6 | 7 | 8 | #[cfg(feature = "use-cuda")] 9 | impl CompareTensor for CudaTensor { 10 | type TensorType = CudaTensor; 11 | type ElementType = f32; // assumed element type of the CUDA backend 12 | 13 | fn max_pair(&self, _o: &Self::TensorType) -> Self::TensorType { 14 | unimplemented!(); 15 | } 16 | fn min_pair(&self, _o: &Self::TensorType) -> Self::TensorType { 17 | unimplemented!(); 18 | } 19 | fn all(&self, _f: &dyn Fn(Self::ElementType) -> bool) -> bool { 20 | unimplemented!(); 21 | } 22 | fn any(&self, _f: &dyn Fn(Self::ElementType) -> bool) -> bool { 23 | unimplemented!(); 24 | } 25 | } 26 | 27 | -------------------------------------------------------------------------------- /tensor-rs/README.md: -------------------------------------------------------------------------------- 1 | # A typeless tensor library 2 | 3 | [![crates.io version](https://img.shields.io/crates/v/tensor-rs.svg)](https://crates.io/crates/tensor-rs) 4 | [![License](https://img.shields.io/crates/l/auto-diff.svg)](https://github.com/pipehappy1/auto-diff/blob/master/LICENSE.txt) 5 | 6 | ## Introduction 7 | 8 | A typeless tensor library 9 | 10 | ## Features 11 | 12 | - A typeless tensor. 13 | - A set of ops for it. 14 | 15 | ## Usage 16 | 17 | 18 | ## Example 19 | 20 | 21 | 22 | ## Dependence 23 | 24 | Install gfortran when `openblas-src` is used as the BLAS/LAPACK backend.
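A minimal example, lifted from the doc comment in `src/lib.rs`:

```rust
use tensor_rs::tensor_impl::gen_tensor::*;

fn main() {
    // A 3x5x2 tensor of zeros; the stride follows row-major layout.
    let m1 = GenTensor::<f64>::new_raw(&vec![0.; 3*5*2], &vec![3, 5, 2]);
    assert_eq!(m1.stride(), vec![10, 2, 1]);
}
```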
25 | 26 | 27 | -------------------------------------------------------------------------------- /tensor-rs/src/serde/typed_tensor.rs: -------------------------------------------------------------------------------- 1 | 2 | 3 | #[cfg(all(test, feature = "use-serde"))] 4 | mod tests { 5 | use crate::typed_tensor::TypedTensor; 6 | use crate::tensor_impl::gen_tensor::GenTensor; 7 | 8 | #[test] 9 | fn test_serde() { 10 | let m1 = GenTensor::<f64>::new_raw(&vec![1.,2.,3.,4.,5.,6.], &vec![3,2]); 11 | let m1 = TypedTensor::Typef64(m1); 12 | 13 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 14 | let deserialized = serde_pickle::from_slice(&serialized).unwrap(); 15 | //println!("{:?}", deserialized); 16 | assert_eq!(m1, deserialized); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /tensor-rs/src/serde/gen_tensor.rs: -------------------------------------------------------------------------------- 1 | //use serde::{Serialize, Deserialize, Serializer, Deserializer, ser::SerializeStruct}; 2 | 3 | 4 | 5 | 6 | #[cfg(all(test, feature = "use-serde"))] 7 | mod tests { 8 | use crate::tensor_impl::gen_tensor::GenTensor; 9 | 10 | #[test] 11 | fn test_serde() { 12 | let m1 = GenTensor::<f64>::new_raw(&vec![1.,2.,3.,4.,5.,6.], &vec![3,2]); 13 | 14 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 15 | let deserialized = serde_pickle::from_slice(&serialized).unwrap(); 16 | //println!("{:?}", deserialized); 17 | assert_eq!(m1, deserialized); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /tensorboard-rs/README.md: -------------------------------------------------------------------------------- 1 | # Write to tensorboard in Rust # 2 | 3 | Write TensorBoard events in Rust. 4 | 5 | * Can write `scalar`, `image`, `histogram`. 6 | 7 | ## Example 8 | 9 | * Write multiple scalars in one plot.
10 | 11 | ```rust,no_run 12 | use tensorboard_rs::summary_writer::SummaryWriter; 13 | use std::collections::HashMap; 14 | 15 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 16 | 17 | for n_iter in 0..100 { 18 | let mut map = HashMap::new(); 19 | map.insert("x1".to_string(), (n_iter as f32)); 20 | map.insert("x^2".to_string(), (n_iter as f32) * (n_iter as f32)); 21 | writer.add_scalars("data/scalar_group", &map, n_iter); 22 | } 23 | writer.flush(); 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /auto-diff/examples/data/download.sh: -------------------------------------------------------------------------------- 1 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data -P examples/data/ 2 | 3 | mkdir -p data/mnist 4 | curl http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz --output "./data/mnist/train-images-idx3-ubyte.gz" 5 | curl http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz --output "./data/mnist/train-labels-idx1-ubyte.gz" 6 | curl http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz --output "./data/mnist/t10k-images-idx3-ubyte.gz" 7 | curl http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz --output "./data/mnist/t10k-labels-idx1-ubyte.gz" 8 | gzip -d data/mnist/*.gz 9 | -------------------------------------------------------------------------------- /data-pipe/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "auto-diff-data-pipe" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | A data loader 8 | """ 9 | documentation = "https://docs.rs/auto-diff" 10 | homepage = "https://github.com/pipehappy1/auto-diff" 11 | repository = "https://github.com/pipehappy1/auto-diff" 12 | readme = "README.md" 13 | license = "MIT" 14 | keywords = ["machine-learning", "neural-network", "deep-learning"] 15 | exclude = ["/dev/**"] 16 | 17 | [dependencies] 18 | auto-diff = { path = "../auto-diff", version = "0.5.9" } 19 | 20 | [dev-dependencies] 21 | # one backend 22 | openblas-src = "0.10" # or another backend of your choice -------------------------------------------------------------------------------- /macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "auto-diff-macros" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | Macros for auto-diff. 8 | 9 | 10 | """ 11 | documentation = "https://docs.rs/auto-diff-macros" 12 | homepage = "https://github.com/pipehappy1/auto-diff" 13 | repository = "https://github.com/pipehappy1/auto-diff" 14 | readme = "README.md" 15 | license = "MIT" 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | 18 | [lib] 19 | proc-macro = true 20 | 21 | [dependencies] 22 | 23 | syn = { version = "1", features = ["full", "extra-traits"] } 24 | quote = { version = "1" } 25 | -------------------------------------------------------------------------------- /auto-diff/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | 9 | ## [0.5.9] - 2022-03-28 10 | - Use proc_macro to save some typing.
11 | 12 | ## [0.5.8] - 2022-03-19 13 | - Add data loader and fix mlp_mnist example. 14 | 15 | ## [0.5.7] - 2022-03-12 16 | - Support serde. 17 | 18 | ## [0.5.4] - 2022-02-21 19 | - Add more gradient methods. 20 | 21 | ## [0.5.3] - 2022-02-14 22 | - New active syntax is working. 23 | 24 | ## [0.3.0] - 2020-05-01 25 | ### Added 26 | - linear regression example works. 27 | -------------------------------------------------------------------------------- /auto-diff/src/serde/compute_graph.rs: -------------------------------------------------------------------------------- 1 | #[cfg(all(test, feature = "use-serde"))] 2 | mod tests { 3 | use crate::compute_graph::Net; 4 | use crate::var::Var; 5 | use rand::prelude::*; 6 | 7 | #[test] 8 | fn test_serde_net() { 9 | let mut rng = StdRng::seed_from_u64(671); 10 | let n = 10; 11 | let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.); 12 | let result = data.matmul(&Var::new(&vec![2., 3.], &vec![2, 1])).unwrap() 13 | + Var::new(&vec![1.], &vec![1]); 14 | 15 | let serialized = serde_pickle::to_vec(&*result.dump_net().borrow(), true).unwrap(); 16 | let deserialized: Net = serde_pickle::from_slice(&serialized).unwrap(); 17 | //println!("{:?}", deserialized); 18 | //assert_eq!(*result.dump_net().borrow(), deserialized); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /ops/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "auto-diff-ops" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | Extra operators for auto-diff. 8 | """ 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | documentation = "https://docs.rs/auto-diff-ops" 11 | homepage = "https://github.com/pipehappy1/auto-diff" 12 | repository = "https://github.com/pipehappy1/auto-diff" 13 | readme = "README.md" 14 | license = "MIT" 15 | 16 | [dependencies] 17 | tensor-rs = { path="../tensor-rs", version = "0.5.9"} 18 | auto-diff = { path="../auto-diff", version = "0.5.9"} 19 | auto-diff-macros = { path="../macros", version = "0.5.9"} 20 | 21 | [dev-dependencies] 22 | # one backend 23 | openblas-src = "0.10" # or another backend of your choice 24 | 25 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/linalg.rs: -------------------------------------------------------------------------------- 1 | pub trait LinearAlgbra { 2 | type TensorType; 3 | type ElementType; 4 | 5 | fn norm(&self) -> Self::TensorType; 6 | /// Assuming the input is 2 dimensional array, 7 | /// normalize_unit 8 | fn normalize_unit(&self) -> Self::TensorType; 9 | fn lu(&self) -> Option<[Self::TensorType; 2]>; 10 | fn lu_solve(&self, y: &Self::TensorType) -> Option; 11 | fn qr(&self) -> Option<[Self::TensorType; 2]>; 12 | fn eigen(&self) -> Option<[Self::TensorType; 2]>; 13 | fn cholesky(&self) -> Option; 14 | fn det(&self) -> Option; 15 | fn svd(&self) -> Option<[Self::TensorType; 3]>; 16 | fn inv(&self) -> Option; 17 | fn pinv(&self) -> Self::TensorType; 18 | fn tr(&self) -> Self::TensorType; 19 | } 20 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | 
steps: 15 | - uses: actions/checkout@v2 16 | - name: Install gfortran 17 | run: sudo apt-get install -y gfortran 18 | - name: Build 19 | run: cargo build --verbose 20 | - name: Run test 1 21 | run: cargo test --verbose 22 | working-directory: tensor-rs 23 | - name: Run test 2 24 | run: cargo test --verbose 25 | working-directory: auto-diff 26 | - name: Run examples 27 | run: cargo run --verbose --example linear_regression 28 | working-directory: auto-diff 29 | # - name: Run benchmark 30 | # run: cargo bench --bench elemwise_benchmark --verbose 31 | # working-directory: benches 32 | -------------------------------------------------------------------------------- /auto-diff/src/serde/directed_graph.rs: -------------------------------------------------------------------------------- 1 | #[cfg(all(test, feature = "use-serde"))] 2 | mod tests { 3 | use crate::collection::directed_graph::Graph; 4 | use crate::collection::generational_index::GenKey; 5 | 6 | #[test] 7 | fn test_serde_graph() { 8 | let mut m1 = Graph::<GenKey, GenKey>::new(); 9 | let data1 = GenKey::new(1, 1); 10 | let data2 = GenKey::new(2, 6); 11 | let op1 = GenKey::new(3, 8); 12 | m1.add_data(&data1).unwrap(); 13 | m1.add_data(&data2).unwrap(); 14 | m1.add_op(&op1).unwrap(); 15 | m1.connect(&[data1], &[data2], &op1).unwrap(); 16 | 17 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 18 | let deserialized = serde_pickle::from_slice(&serialized).unwrap(); 19 | //println!("{:?}", deserialized); 20 | assert_eq!(m1, deserialized); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/rand.rs: -------------------------------------------------------------------------------- 1 | use rand::prelude::StdRng; 2 | 3 | pub trait Random { 4 | type TensorType; 5 | type ElementType; 6 | 7 | /// Generate random integers in the half-open interval `[left, right)`.
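/// For example, `rand_usize(&mut rng, &[2, 3], 0, 10)` fills a 2x3 tensor
/// with integers drawn from `[0, 10)`.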
8 | fn rand_usize(rng: &mut StdRng, 9 | dim: &[usize], 10 | left: usize, right: usize) -> Self::TensorType; 11 | fn bernoulli() -> Self::TensorType; 12 | fn cauchy() -> Self::TensorType; 13 | fn exponential() -> Self::TensorType; 14 | fn geometric() -> Self::TensorType; 15 | fn log_normal() -> Self::TensorType; 16 | fn normal(rng: &mut StdRng, 17 | dim: &[usize], 18 | mean: Self::ElementType, 19 | std: Self::ElementType) -> Self::TensorType; 20 | fn uniform(rng: &mut StdRng, 21 | dim: &[usize], 22 | from: Self::ElementType, 23 | to: Self::ElementType) -> Self::TensorType; 24 | } 25 | -------------------------------------------------------------------------------- /tensorboard-rs/examples/draw_scalar.rs: -------------------------------------------------------------------------------- 1 | use tensorboard_rs::summary_writer::SummaryWriter; 2 | use std::collections::HashMap; 3 | 4 | pub fn main() { 5 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 6 | 7 | let name = "run1"; 8 | let mut scalar = 2.3; 9 | let mut step = 12; 10 | for i in 0..2 { 11 | println!("{}", i); 12 | scalar += (i as f32)*0.1; 13 | step += i; 14 | 15 | writer.add_scalar(name, scalar, step); 16 | } 17 | writer.flush(); 18 | 19 | for n_iter in 0..100 { 20 | let mut map = HashMap::new(); 21 | map.insert("xsinx".to_string(), (n_iter as f32) * (n_iter as f32).sin()); 22 | map.insert("xcosx".to_string(), (n_iter as f32) * (n_iter as f32).cos()); 23 | map.insert("arctanx".to_string(), (n_iter as f32).atan()); 24 | writer.add_scalars("data/scalar_group", &map, n_iter); 25 | } 26 | writer.flush(); 27 | } 28 | -------------------------------------------------------------------------------- /auto-diff/src/serde/generational_index.rs: -------------------------------------------------------------------------------- 1 | #[cfg(all(test, feature = "use-serde"))] 2 | mod tests { 3 | use crate::collection::generational_index::{GenIndex, GenKey}; 4 | 5 | #[test] 6 | fn test_serde_genkey() { 7 | let m1 = GenKey::new(1, 3); 8 | 9 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 10 | let deserialized = serde_pickle::from_slice(&serialized).unwrap(); 11 | //println!("{:?}", deserialized); 12 | assert_eq!(m1, deserialized); 13 | } 14 | 15 | #[test] 16 | fn test_serde_genindex() { 17 | let mut m1 = GenIndex::<f64>::new(); 18 | let key = m1.insert(10.); 19 | m1.remove(&key).unwrap(); 20 | m1.insert(12.); 21 | 22 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 23 | let deserialized = serde_pickle::from_slice(&serialized).unwrap(); 24 | //println!("{:?}", deserialized); 25 | assert_eq!(m1, deserialized); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tensorboard-rs/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tensorboard-rs" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | Write data for Tensorboard from Rust.
8 | 9 | 10 | """ 11 | documentation = "https://docs.rs/tensorboard-rs" 12 | homepage = "https://github.com/pipehappy1/auto-diff" 13 | repository = "https://github.com/pipehappy1/auto-diff" 14 | readme = "README.md" 15 | license = "MIT" 16 | keywords = ["machine-learning", "neural-network", "deep-learning"] 17 | exclude = ["/dev/**"] 18 | 19 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 20 | 21 | [dependencies] 22 | protobuf = "2.14" 23 | 24 | #tensorboard-proto = { path = "../../tensorboard-proto", version = "0.5.9" } 25 | tensorboard-proto = { version = "0.5.7" } 26 | 27 | gethostname = "0.2.1" 28 | 29 | image = "0.23.4" 30 | 31 | [dev-dependencies] 32 | #protobuf-codegen = "2.14" 33 | 34 | 35 | [features] 36 | 37 | -------------------------------------------------------------------------------- /tensorboard-rs/src/record_writer.rs: -------------------------------------------------------------------------------- 1 | use crate::masked_crc32c::masked_crc32c; 2 | use std::io::Write; 3 | 4 | pub struct RecordWriter<W: Write> { 5 | _writer: W, 6 | } 7 | impl<W: Write> RecordWriter<W> { 8 | pub fn new(writer: W) -> RecordWriter<W> { 9 | RecordWriter { 10 | _writer: writer, 11 | } 12 | } 13 | pub fn write(&mut self, data: &[u8]) -> std::io::Result<()>{ 14 | let header = data.len() as u64; 15 | let header_crc = (masked_crc32c(&(header.to_le_bytes())) as u32).to_le_bytes(); 16 | let footer_crc = (masked_crc32c(data) as u32).to_le_bytes(); 17 | let header = header.to_le_bytes(); 18 | 19 | self._writer.write_all(&header)?; 20 | self._writer.write_all(&header_crc)?; 21 | self._writer.write_all(data)?; 22 | self._writer.write_all(&footer_crc) 23 | } 24 | pub fn flush(&mut self) -> std::io::Result<()> { 25 | self._writer.flush() 26 | } 27 | //pub fn close() {} 28 | //pub fn closed() {} 29 | } 30 | -------------------------------------------------------------------------------- /tensor-rs/benches/test_mm_benchmark.rs: -------------------------------------------------------------------------------- 1 | //use tensor_rs::tensor::*; 2 | //use tensor_rs::tensor::blas::*; 3 | 4 | 5 | #[cfg(test)] 6 | mod tests { 7 | 8 | extern crate openblas_src; 9 | use tensor_rs::tensor_impl::lapack_tensor::blas_api::BlasAPI; 10 | use tensor_rs::tensor::Tensor; 11 | 12 | #[test] 13 | fn test_gemm1() { 14 | 15 | for _i in 0..100000 { 16 | let v1: Vec<f32> = (0..128*256).map(|x| x as f32).collect(); 17 | let v2: Vec<f32> = (0..128*256).map(|x| x as f32).collect(); 18 | let mut v3: [f32; 65536] = [0.; 65536]; 19 | 20 | let trans = false; 21 | BlasAPI::<f32>::gemm(trans, trans, 256, 256, 128, 1., &v1, 256, &v2, 128, 1., &mut v3, 256); 22 | //println!("{:?}", v3); 23 | } 24 | 25 | } 26 | #[test] 27 | fn test_mm1() { 28 | 29 | for _i in 0..1000 { 30 | let v1 = Tensor::ones(&[256, 128]); 31 | let v2 = Tensor::ones(&[128, 256]); 32 | 33 | v1.mm(&v2); 34 | } 35 | } 36 | } 37 | 38 | -------------------------------------------------------------------------------- /auto-diff/src/err.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | 4 | #[derive(Debug)] 5 | pub struct AutoDiffError { 6 | details: String, 7 | } 8 | 9 | impl AutoDiffError { 10 | pub fn new(msg: &str) -> AutoDiffError { 11 | AutoDiffError { 12 | details: msg.to_string(), 13 | } 14 | } 15 | } 16 | 17 | impl fmt::Display for AutoDiffError { 18 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 19 | write!(f, "{}", self.details) 20 | } 21 | } 22 | 23 | impl
Error for AutoDiffError { 24 | fn description(&self) -> &str { 25 | &self.details 26 | } 27 | } 28 | 29 | impl From<AutoDiffError> for std::fmt::Error { 30 | fn from(_item: AutoDiffError) -> std::fmt::Error { 31 | std::fmt::Error::default() 32 | } 33 | } 34 | 35 | #[cfg(test)] 36 | mod tests { 37 | use super::*; 38 | 39 | #[test] 40 | fn test() { 41 | fn return_err() -> Result<(), AutoDiffError> { 42 | Err(AutoDiffError::new(&format!("{:?}", 12))) 43 | } 44 | 45 | let e = return_err(); 46 | assert!(e.is_err()); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tensor-rs/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A simple tensor implementation 2 | //! ============================================================= 3 | //! 4 | //! 5 | //! Introduction 6 | //! ------------ 7 | //! This is a typeless tensor library with the option to use 8 | //! built-in operators or a third-party acceleration library. 9 | //! The APIs a tensor needs to implement are listed in [tensor_trait]. 10 | //! [typed_tensor] is an enum to cover the tensor type information for [tensor]. 11 | //! 12 | //! Currently, there are over 80 methods for [tensor]. 13 | //! 14 | //! Install 15 | //! ------------ 16 | //! Add tensor-rs = "0.5" to the \[dependencies\] section of your project Cargo.toml file. 17 | //! 18 | //! Example 19 | //! ------------ 20 | //! The following example shows a quick dip into using the package. 21 | //! 22 | //! use tensor_rs::tensor_impl::gen_tensor::*; 23 | //! let m1 = GenTensor::<f64>::new_raw(&vec![0.; 3*5*2], &vec![3,5,2]); 24 | //! assert_eq!(m1.stride(), vec![10,2,1]); 25 | //! 26 | //! License 27 | //! ------------ 28 | 29 | 30 | pub mod tensor; 31 | pub mod quaternion; 32 | pub mod typed_tensor; 33 | pub mod tensor_trait; 34 | pub mod tensor_impl; 35 | #[cfg(feature = "use-serde")] 36 | pub mod serde; 37 | -------------------------------------------------------------------------------- /auto-diff/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright 2018 Developers of the auto-diff project 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE.
26 | -------------------------------------------------------------------------------- /tensor-rs/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright 2018 Developers of the auto-diff project 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /tensorboard-rs/LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright 2018 Developers of the auto-diff project 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 
26 | -------------------------------------------------------------------------------- /tensorboard-rs/examples/draw_graph.rs: -------------------------------------------------------------------------------- 1 | use tensorboard_rs::summary_writer::SummaryWriter; 2 | //use tensorboard_proto::event::{Event, TaggedRunMetadata}; 3 | //use tensorboard_proto::summary::{Summary}; 4 | //use tensorboard_proto::graph::{GraphDef, }; 5 | use tensorboard_proto::node_def::{NodeDef, }; 6 | //use tensorboard_proto::versions::{VersionDef, }; 7 | use tensorboard_proto::attr_value::{AttrValue, }; 8 | //use tensorboard_proto::tensor_shape::{TensorShapeProto, }; 9 | //use tensorboard_proto::step_stats::{RunMetadata, }; 10 | use protobuf::RepeatedField; 11 | use std::collections::HashMap; 12 | 13 | pub fn main() { 14 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 15 | 16 | let mut node1 = NodeDef::new(); 17 | node1.set_name("node1".to_string()); 18 | node1.set_op("op1".to_string()); 19 | 20 | let inputs = RepeatedField::from(vec![]); 21 | node1.set_input(inputs); 22 | 23 | let mut attrs = HashMap::new(); 24 | let mut v1 = AttrValue::new(); 25 | v1.set_i(16); 26 | attrs.insert("attr1".to_string(), v1); 27 | node1.set_attr(attrs); 28 | 29 | writer.add_graph(&[node1]); 30 | 31 | writer.flush(); 32 | } 33 | -------------------------------------------------------------------------------- /ann/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "auto-diff-ann" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | A neural network library in Rust. 8 | 9 | 10 | """ 11 | documentation = "https://docs.rs/auto-diff" 12 | homepage = "https://github.com/pipehappy1/auto-diff" 13 | repository = "https://github.com/pipehappy1/auto-diff" 14 | readme = "README.md" 15 | license = "MIT" 16 | keywords = ["machine-learning", "neural-network", "deep-learning"] 17 | exclude = ["/dev/**"] 18 | 19 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 20 | 21 | [dependencies] 22 | rand = "0.8" 23 | rand_distr = "0.4" 24 | 25 | tensor-rs = { path = "../tensor-rs", version = "0.5.9" } 26 | auto-diff = { path = "../auto-diff", version = "0.5.9" } 27 | 28 | auto-diff-data-pipe = { path = "../data-pipe", version = "0.5.9" } 29 | 30 | [dev-dependencies] 31 | # one backend 32 | openblas-src = "0.10" # or another backend of your choice 33 | 34 | auto-diff-data-pipe = { path = "../data-pipe", version = "0.5.9" } 35 | 36 | tensorboard-rs = { path = "../tensorboard-rs", version = "0.5.9" } 37 | 38 | bincode = {version = "1.3.3"} -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/reduction.rs: -------------------------------------------------------------------------------- 1 | pub trait ReduceTensor where Self: std::marker::Sized { 2 | 3 | fn argmax(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self; 4 | fn argmin(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self; 5 | fn dist(); 6 | /// log(sum(exp(x))), 7 | /// dim is the dimension along which sum is applied. 8 | /// if keep_dim, the dimension along which sum is applied will be kept and be 1. 
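/// For example, applying it to a `[2, 3]` tensor with `dim = Some(&[1])` and
/// `keep_dim = true` yields shape `[2, 1]`, where each row `[a, b, c]` maps to
/// `ln(exp(a) + exp(b) + exp(c))`.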
9 | fn logsumexp(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self; 10 | fn mean(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 11 | fn median(); 12 | fn mode(); 13 | fn prod(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 14 | fn std(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 15 | fn std_mean(); 16 | //fn sum(&self, dim: usize, keepdim: bool) -> Self::TensorType; 17 | fn sum(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 18 | fn unique(); 19 | fn unique_consecutive(); 20 | fn var(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 21 | fn var_mean(); 22 | 23 | fn max(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 24 | fn min(&self, dim: Option<&[usize]>, keepdim: bool) -> Self; 25 | } 26 | -------------------------------------------------------------------------------- /benches/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "benches" 3 | version = "0.0.0" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | benchmark for auto-diff 8 | 9 | 10 | """ 11 | documentation = "https://docs.rs/auto-diff" 12 | homepage = "https://github.com/pipehappy1/auto-diff" 13 | repository = "https://github.com/pipehappy1/auto-diff" 14 | readme = "README.md" 15 | license = "MIT" 16 | keywords = ["machine-learning", "neural-network", "deep-learning"] 17 | exclude = ["/dev/**"] 18 | 19 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 20 | 21 | [dependencies] 22 | tensor-rs = { path = "../tensor-rs" , features = ["use-blas-lapack"]} 23 | auto-diff = { path = "../auto-diff" } 24 | 25 | ndarray = "0.12" 26 | ndarray-linalg = "0.11" 27 | 28 | [dev-dependencies] 29 | criterion = "0.3" 30 | 31 | # one backend 32 | openblas-src = "0.10" # or another backend of your choice 33 | 34 | # for examples 35 | csv = "1.1" 36 | 37 | 38 | 39 | [features] 40 | 41 | [[bench]] 42 | name = "tensor_benchmark" 43 | harness = false 44 | 45 | [[bench]] 46 | name = "elemwise_benchmark" 47 | harness = false 48 | 49 | [[bench]] 50 | name = "convolution_benchmark" 51 | harness = false 52 | 53 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/convolution.rs: -------------------------------------------------------------------------------- 1 | use crate::tensor::PaddingMode; 2 | 3 | pub trait Convolution where Self: std::marker::Sized { 4 | 5 | fn conv2d(&self, filter: &Self, 6 | stride: (usize, usize), 7 | padding: (usize, usize), 8 | dilation: (usize, usize), 9 | padding_mode: PaddingMode 10 | ) -> Self; 11 | 12 | fn conv2d_grad(&self, filter: &Self, 13 | stride: (usize, usize), 14 | padding: (usize, usize), 15 | dilation: (usize, usize), 16 | padding_mode: PaddingMode, 17 | output_grad: &Self 18 | ) -> (Self, Self); 19 | 20 | fn conv_gen(&self, filter: &Self, 21 | stride: &[usize], 22 | padding: &[usize], 23 | dilation: &[usize], 24 | padding_mode: PaddingMode 25 | ) -> Self; 26 | 27 | fn conv_grad_gen(&self, filter: &Self, 28 | stride: &[usize], 29 | padding: &[usize], 30 | dilation: &[usize], 31 | padding_mode: PaddingMode, 32 | output_grad: &Self, 33 | ) -> (Self, Self); 34 | } 35 | -------------------------------------------------------------------------------- /data-pipe/src/dataloader/mod.rs: -------------------------------------------------------------------------------- 1 | use auto_diff::{Var, AutoDiffError}; 2 | 3 | #[derive(Copy, Clone)] 4 | pub enum DataSlice { 5 | Train, 6 | Test, 7 | 
Tune, 8 | Other, 9 | } 10 | 11 | pub trait DataLoader { 12 | /// The shape of the data if applicable. 13 | fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError>; 14 | /// Return one sample. 15 | fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>; 16 | /// Return a batch following original order. 17 | fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>; 18 | /// Return a batch given the index. 19 | fn get_indexed_batch(&self, index: &[usize], slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> { 20 | let mut data: Vec<Var> = vec![]; 21 | let mut label: Vec<Var> = vec![]; 22 | 23 | for elem_index in index { 24 | let (elem_data, elem_label) = self.get_item(*elem_index, slice)?; 25 | data.push(elem_data); 26 | label.push(elem_label); 27 | } 28 | let d1 = data[0].cat(&data[1..], 0)?; 29 | let d2 = label[0].cat(&label[1..], 0)?; 30 | d1.reset_net(); 31 | d2.reset_net(); 32 | Ok((d1, d2)) 33 | } 34 | } 35 | 36 | pub mod mnist; 37 | -------------------------------------------------------------------------------- /auto-diff/src/collection/undirected_graph.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{BTreeMap, BTreeSet}; 2 | //use crate::err::AutoDiffError; 3 | 4 | /// Graph 5 | pub struct UnDirectedGraph<TNode: Ord, TEdge: Ord> { 6 | node: BTreeSet<TNode>, 7 | edge: BTreeSet<TEdge>, 8 | edge2node: BTreeMap<TEdge, BTreeSet<TNode>>, 9 | node2edge: BTreeMap<TNode, BTreeSet<TEdge>>, 10 | } 11 | 12 | impl<TNode: Ord, TEdge: Ord> Default 13 | for UnDirectedGraph<TNode, TEdge> 14 | { 15 | fn default() -> UnDirectedGraph<TNode, TEdge> { 16 | UnDirectedGraph { 17 | node: BTreeSet::new(), 18 | edge: BTreeSet::new(), 19 | edge2node: BTreeMap::new(), 20 | node2edge: BTreeMap::new(), 21 | } 22 | } 23 | } 24 | 25 | impl<TNode: Ord, TEdge: Ord> UnDirectedGraph<TNode, TEdge> { 26 | pub fn new() -> UnDirectedGraph<TNode, TEdge> { 27 | UnDirectedGraph { 28 | node: BTreeSet::new(), 29 | edge: BTreeSet::new(), 30 | edge2node: BTreeMap::new(), 31 | node2edge: BTreeMap::new(), 32 | } 33 | } 34 | } 35 | 36 | #[cfg(test)] 37 | mod tests { 38 | use super::*; 39 | use crate::collection::generational_index::GenKey; 40 | 41 | #[test] 42 | fn new() { 43 | let _g = UnDirectedGraph::<GenKey, GenKey>::new(); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /auto-diff/Cargo.toml: -------------------------------------------------------------------------------- 1 | 2 | [package] 3 | name = "auto-diff" 4 | version = "0.5.9" 5 | authors = ["yguan "] 6 | edition = "2021" 7 | description = """ 8 | A neural network library in Rust.
9 | 10 | 11 | """ 12 | documentation = "https://docs.rs/auto-diff" 13 | homepage = "https://github.com/pipehappy1/auto-diff" 14 | repository = "https://github.com/pipehappy1/auto-diff" 15 | readme = "README.md" 16 | license = "MIT" 17 | keywords = ["machine-learning", "neural-network", "deep-learning"] 18 | exclude = ["/dev/**"] 19 | 20 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 21 | 22 | [dependencies] 23 | tensor-rs = { path = "../tensor-rs", version = "0.5.9" } 24 | auto-diff-macros = { path = "../macros", version = "0.5.9" } 25 | 26 | num-traits = "0.2" 27 | 28 | rand = "0.8" 29 | rand_distr = "0.4" 30 | 31 | serde = { version = "1.0", features = ["derive"], optional = true} 32 | 33 | #lazy_static = { version = "1.4.0", optional = true} 34 | 35 | 36 | 37 | [dev-dependencies] 38 | criterion = "0.3" 39 | 40 | # one backend 41 | openblas-src = "0.10" # or another backend of your choice 42 | 43 | # for examples 44 | csv = "1.1" 45 | 46 | #tensorboard-rs = { path = "../tensorboard-rs", version = "0.5.9"} 47 | 48 | serde-pickle = {version = "0.6"} 49 | 50 | #cargo-expand = "1" 51 | 52 | [features] 53 | 54 | default = ["use-f64", "use-serde"] 55 | 56 | use-f32 = [] 57 | use-f64 = [] 58 | 59 | use-serde = ["serde"] 60 | -------------------------------------------------------------------------------- /ann/examples/mlp_mnist_acc_on_test.rs: -------------------------------------------------------------------------------- 1 | use auto_diff_ann::minibatch::MiniBatch; 2 | use auto_diff::Var; 3 | use auto_diff_data_pipe::dataloader::{mnist::Mnist, DataSlice}; 4 | use std::path::Path; 5 | use rand::prelude::*; 6 | use ::rand::prelude::StdRng; 7 | use std::fs; 8 | 9 | extern crate openblas_src; 10 | 11 | fn main() { 12 | 13 | let rng = StdRng::seed_from_u64(671); 14 | 15 | let mnist = Mnist::load(&Path::new("../auto-diff/examples/data/mnist")); 16 | let minibatch = MiniBatch::new(rng, 16); 17 | 18 | let file_name = "./saved_model/net_9900"; 19 | let deserialized = fs::read(file_name).expect("unable to read file"); 20 | let loss: Var = bincode::deserialize(&deserialized).unwrap(); 21 | 22 | let (inputs, _outputs) = loss.get_io_var().expect(""); 23 | let predict = loss.predict().expect(""); 24 | 25 | let mut right = 0.; 26 | let mut total = 0.; 27 | for (data, label) in minibatch.iter_block(&mnist, &DataSlice::Test).expect("") { 28 | 29 | let input = data.reshape(&[minibatch.batch_size(), data.size()[1]*data.size()[2]]).unwrap(); 30 | input.reset_net(); 31 | 32 | inputs[0].set(&input); 33 | inputs[1].set(&label); 34 | loss.rerun().expect(""); 35 | 36 | let predict_max = predict.clone().argmax(Some(&[1]), false).unwrap(); 37 | right += f64::try_from(predict_max.eq_elem(&label).unwrap().sum(None, false).unwrap()).unwrap(); 38 | total += input.size()[0] as f64; 39 | 40 | println!("acc: {}", right/total); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /auto-diff/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | #![allow(unused_variables)] 3 | //#![no_std] 4 | //! An auto-difference library 5 | //! ============================================================= 6 | //! 7 | //! 8 | //! Introduction 9 | //! ------------ 10 | //! This is yet another auto-difference library for deep neural networks. 11 | //! The focus is ease of use and dynamic computation-graph building. 12 | //! 13 | //! Install 14 | //! ------------ 15 | //!
Add auto-diff = "0.5" to the \[dependencies\] section of your project Cargo.toml file. 16 | //! 17 | //! Features 18 | //! ------------ 19 | //! The forward operators support a commonly used set, including: 20 | //! 21 | //! 1. getter/setter, 22 | //! 2. index and slicing, 23 | //! 3. +, -, *, / and matmul, 24 | //! 4. special functions, 25 | //! 5. statistics, 26 | //! 6. linear algebra, 27 | //! 7. random number generator. 28 | //! 29 | //! The corresponding gradient is work-in-progress. 30 | //! 31 | //! One feature of auto-diff is that the auto-difference runs in the background 32 | //! and doesn't get in your way if only forward calculation is needed. 33 | //! Thus it can be used without syntax like variable placeholders. 34 | //! 35 | //! Example 36 | //! ------------ 37 | //! 38 | 39 | pub mod err; 40 | pub mod op; 41 | pub mod optim; 42 | pub mod var; 43 | 44 | pub use err::AutoDiffError; 45 | pub use var::Var; 46 | 47 | pub mod collection; 48 | pub mod compute_graph; 49 | #[cfg(feature = "use-serde")] 50 | pub mod serde; 51 | pub mod var_inner; 52 | -------------------------------------------------------------------------------- /auto-diff/examples/linear_regression.rs: -------------------------------------------------------------------------------- 1 | use rand::prelude::*; 2 | use auto_diff::var::Var; 3 | use auto_diff::optim::{SGD}; 4 | use auto_diff::op::Linear; 5 | use auto_diff::op::OpCall; 6 | use auto_diff::err::AutoDiffError; 7 | extern crate openblas_src; 8 | 9 | fn main() { 10 | 11 | fn func(input: &Var) -> Result<Var, AutoDiffError> { 12 | let input = input.clone(); 13 | input.set_grad(false); 14 | let result = input.matmul(&Var::new(&vec![2., 3.], &vec![2, 1]))? + Var::new(&vec![1.], &vec![1]); 15 | result.set_grad(true); 16 | Ok(result) 17 | } 18 | 19 | let n = 15; 20 | let mut rng = StdRng::seed_from_u64(671); 21 | let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.); 22 | let label = func(&data).unwrap(); 23 | 24 | let mut op1 = Linear::new(Some(2), Some(1), true); 25 | op1.set_weight(Var::normal(&mut rng, &[2, 1], 0., 2.)); 26 | op1.set_bias(Var::normal(&mut rng, &[1, ], 0., 2.)); 27 | 28 | let output = op1.call(&[&data]).unwrap().pop().unwrap(); 29 | 30 | let loss = output.mse_loss(&label).unwrap(); 31 | 32 | let mut opt = SGD::new(3.); 33 | 34 | for i in 0..200 { 35 | 36 | println!("index: {}, loss: {:?}", i, loss); 37 | loss.rerun().unwrap(); 38 | loss.bp().unwrap(); 39 | loss.step(&mut opt).unwrap(); 40 | 41 | let weight = op1.weight(); 42 | let bias = op1.bias(); 43 | println!("{:?}, {:?}", weight, bias); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /tensorboard-rs/examples/draw_histo.rs: -------------------------------------------------------------------------------- 1 | use tensorboard_rs::summary_writer::SummaryWriter; 2 | //use image::{open, }; 3 | 4 | pub fn main() { 5 | 6 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 7 | 8 | let min = 1.001; 9 | let max = 29.001; 10 | let num = 435.; 11 | let sum = 8555.435; 12 | let sum_squares = 189242.110435; 13 | let bucket_limits = [3.8009999999999997, 6.600999999999999, 9.400999999999998, 12.200999999999999, 15.001, 17.801, 20.601, 23.401, 26.201, 29.001]; 14 | let bucket_counts = [ 6., 15., 24., 33., 27., 48., 57., 66., 75., 84.]; 15 | 16 | writer.add_histogram_raw("run1/histo1", 17 | min, max, 18 | num, 19 | sum, sum_squares, 20 | &bucket_limits, &bucket_counts, 21 | 1 22 | ); 23 | 24 | writer.add_histogram_raw("run1/histo1", 25 | min, max, 26 | num, 27 | sum,
sum_squares, 28 | &bucket_limits, &bucket_counts, 29 | 2 30 | ); 31 | 32 | writer.add_histogram_raw("run1/histo1", 33 | min, max, 34 | num, 35 | sum, sum_squares, 36 | &bucket_limits, &bucket_counts, 37 | 3 38 | ); 39 | writer.flush(); 40 | } 41 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/elemwise.rs: -------------------------------------------------------------------------------- 1 | pub trait ElemwiseTensorOp { 2 | type TensorType; 3 | type ElementType; 4 | 5 | fn abs(&self) -> Self::TensorType; 6 | fn acos(&self) -> Self::TensorType; 7 | fn asin(&self) -> Self::TensorType; 8 | fn atan(&self) -> Self::TensorType; 9 | fn ceil(&self) -> Self::TensorType; 10 | fn clamp(&self, min: Self::ElementType, max: Self::ElementType) -> Self::TensorType; 11 | fn cos(&self) -> Self::TensorType; 12 | fn cosh(&self) -> Self::TensorType; 13 | fn exp(&self) -> Self::TensorType; 14 | fn expm1(&self) -> Self::TensorType; 15 | fn floor(&self) -> Self::TensorType; 16 | fn frac(&self) -> Self::TensorType ; 17 | fn log(&self) -> Self::TensorType; 18 | fn log10(&self) -> Self::TensorType; 19 | fn log1p(&self) -> Self::TensorType; 20 | fn log1pexp(&self) -> Self::TensorType; 21 | fn log2(&self) -> Self::TensorType; 22 | fn neg(&self) -> Self::TensorType; 23 | fn pow(&self, n: Self::ElementType) -> Self::TensorType; 24 | fn reciprocal(&self) -> Self::TensorType; 25 | fn round(&self) -> Self::TensorType; 26 | fn rsqrt(&self) -> Self::TensorType ; 27 | fn sigmoid(&self) -> Self::TensorType; 28 | fn sign(&self) -> Self::TensorType; 29 | fn sin(&self) -> Self::TensorType; 30 | fn sinh(&self) -> Self::TensorType; 31 | fn sqrt(&self) -> Self::TensorType; 32 | fn square(&self) -> Self::TensorType; 33 | fn tan(&self) -> Self::TensorType; 34 | fn tanh(&self) -> Self::TensorType; 35 | fn trunc(&self) -> Self::TensorType; 36 | 37 | } 38 | -------------------------------------------------------------------------------- /auto-diff/examples/alexnet.rs: -------------------------------------------------------------------------------- 1 | //use tensor_rs::tensor::Tensor; 2 | //use auto_diff::var::{Module, Var, bcewithlogitsloss}; 3 | 4 | //fn alexnet(x: Var) { 5 | // def __init__(self, num_classes=1000): 6 | // super(AlexNet, self).__init__() 7 | // self.features = nn.Sequential( 8 | // nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 9 | // nn.ReLU(inplace=True), 10 | // nn.MaxPool2d(kernel_size=3, stride=2), 11 | // nn.Conv2d(64, 192, kernel_size=5, padding=2), 12 | // nn.ReLU(inplace=True), 13 | // nn.MaxPool2d(kernel_size=3, stride=2), 14 | // nn.Conv2d(192, 384, kernel_size=3, padding=1), 15 | // nn.ReLU(inplace=True), 16 | // nn.Conv2d(384, 256, kernel_size=3, padding=1), 17 | // nn.ReLU(inplace=True), 18 | // nn.Conv2d(256, 256, kernel_size=3, padding=1), 19 | // nn.ReLU(inplace=True), 20 | // nn.MaxPool2d(kernel_size=3, stride=2), 21 | // ) 22 | // self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 23 | // self.classifier = nn.Sequential( 24 | // nn.Dropout(), 25 | // nn.Linear(256 * 6 * 6, 4096), 26 | // nn.ReLU(inplace=True), 27 | // nn.Dropout(), 28 | // nn.Linear(4096, 4096), 29 | // nn.ReLU(inplace=True), 30 | // nn.Linear(4096, num_classes), 31 | // ) 32 | // 33 | // def forward(self, x): 34 | // x = self.features(x) 35 | // x = self.avgpool(x) 36 | // x = torch.flatten(x, 1) 37 | // x = self.classifier(x) 38 | // return x 39 | //} 40 | 41 | fn main() { 42 | } 43 | -------------------------------------------------------------------------------- 
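The commented-out AlexNet above is kept as a PyTorch reference. As an illustration only (not a file in this repository), its fully-connected classifier tail could be written with the `Linear`/`OpCall` builders this repo already uses in `mlp_on_mnist.rs`; the 256*6*6 input width comes from the reference code, the dropout layers are omitted, and the convolutional feature extractor is left out because no `Conv2d` builder appears in this excerpt:

```rust,no_run
use auto_diff::Var;
use auto_diff::op::{Linear, OpCall};

// Hypothetical sketch: the AlexNet classifier head as three Linear ops
// with relu in between, mirroring the MLP examples elsewhere in this repo.
fn classifier(x: &Var, num_classes: usize) -> Var {
    let mut fc1 = Linear::new(Some(256 * 6 * 6), Some(4096), true);
    let mut fc2 = Linear::new(Some(4096), Some(4096), true);
    let mut fc3 = Linear::new(Some(4096), Some(num_classes), true);

    let h = fc1.call(&[x]).unwrap().pop().unwrap().relu().unwrap();
    let h = fc2.call(&[&h]).unwrap().pop().unwrap().relu().unwrap();
    fc3.call(&[&h]).unwrap().pop().unwrap()
}
```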
/benches/benches/tensor_benchmark.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput, BenchmarkId}; 2 | use std::iter; 3 | 4 | use auto_diff::Var; 5 | 6 | extern crate ndarray; 7 | extern crate ndarray_linalg; 8 | extern crate openblas_src; // or another backend of your choice 9 | 10 | //use ndarray; 11 | 12 | fn single_add_benchmark(c: &mut Criterion) { 13 | let m1 = Var::fill_f64(&vec![10,10], 1.); 14 | c.bench_function("single add", |b| b.iter(|| { 15 | let m3 = m1.add(&m1); 16 | })); 17 | } 18 | 19 | fn tensor_dim_benchmark(c: &mut Criterion) { 20 | let ss = vec![10, 20, 30, 50, 70, 128]; 21 | 22 | let mut group = c.benchmark_group("tensor_dim"); 23 | for size in ss.iter() { 24 | let m1 = Var::fill_f64(&vec![*size, *size], 1.); 25 | group.bench_with_input(BenchmarkId::new("add", size*size), size, |b, &size| { 26 | b.iter(|| { 27 | let m_result = m1.add(&m1); 28 | }); 29 | }); 30 | group.bench_with_input(BenchmarkId::new("mul", size*size), size, |b, &size| { 31 | b.iter(|| { 32 | let m_result = m1.mul(&m1); 33 | }); 34 | }); 35 | let md = &ndarray::Array2::<f64>::zeros(((*size) as usize, (*size) as usize)); 36 | group.bench_with_input(BenchmarkId::new("mul_ndarray", size*size), size, |b, &size| { 37 | b.iter(|| { 38 | let m_result = md * md; 39 | }); 40 | }); 41 | } 42 | group.finish(); 43 | 44 | } 45 | 46 | criterion_group!(benches, single_add_benchmark, tensor_dim_benchmark); 47 | criterion_main!(benches); 48 | -------------------------------------------------------------------------------- /tensor-rs/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tensor-rs" 3 | version = "0.5.9" 4 | authors = ["yguan "] 5 | edition = "2021" 6 | description = """ 7 | A typeless tensor library 8 | 9 | 10 | """ 11 | documentation = "https://docs.rs/tensor-rs" 12 | homepage = "https://github.com/pipehappy1/auto-diff" 13 | repository = "https://github.com/pipehappy1/auto-diff" 14 | readme = "README.md" 15 | license = "MIT" 16 | keywords = ["machine-learning", "neural-network", "deep-learning"] 17 | exclude = ["/dev/**"] 18 | 19 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 20 | 21 | [dependencies] 22 | num-traits = "0.2" 23 | 24 | rand = "0.8" 25 | rand_distr = "0.4" 26 | 27 | blas = { version = "0.22", optional = true } 28 | blas-src = { version = "0.8", optional = true } 29 | lapack = { version = "0.19", optional = true } 30 | lapack-src = { version = "0.8", optional = true } 31 | 32 | #rcublas = { version = "0.5", optional = true } 33 | cuda11-cudart-sys = { version = "0.3", optional = true } 34 | cuda11-cutensor-sys = { version = "0.3", optional = true } 35 | 36 | serde = { version = "1.0", features = ["derive"], optional = true} 37 | 38 | [dev-dependencies] 39 | criterion = "0.3" 40 | 41 | # for examples 42 | csv = "1.1" 43 | 44 | tensorboard-rs = { path = "../tensorboard-rs", version = "0.5.9"} 45 | 46 | openblas-src = { version = "0.10" } 47 | 48 | serde-pickle = {version = "0.6"} 49 | 50 | [features] 51 | 52 | default = ["use-f64", "use-blas-lapack", "use-serde"] # 53 | 54 | use-f32 = [] 55 | use-f64 = [] 56 | use-usize = [] 57 | use-u8 = [] 58 | 59 | use-serde = ["serde"] 60 | 61 | use-blas-lapack = ["blas", "blas-src", "lapack", "lapack-src"] 62 | 63 | use-cuda = ["cuda11-cudart-sys", "cuda11-cutensor-sys"]
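As a usage note (an illustration, not a file from this repository): a downstream crate picks exactly one element type and, optionally, the BLAS/LAPACK backend by overriding the defaults above, e.g.:

```toml
# Hypothetical consumer Cargo.toml: switch from the f64 default to f32
# while keeping the blas/lapack-accelerated kernels and serde support.
[dependencies]
tensor-rs = { version = "0.5.9", default-features = false, features = ["use-f32", "use-blas-lapack", "use-serde"] }
```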
-------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/index_slicing.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "use-cuda")] 2 | use super::cuda_tensor::CudaTensor; 3 | use crate::tensor_trait::index_slicing::IndexSlicing; 4 | 5 | 6 | 7 | 8 | /****************/ 9 | // Cuda tensor ops 10 | /****************/ 11 | #[cfg(feature = "use-cuda")] 12 | impl IndexSlicing for CudaTensor { 13 | fn cat(&self, tensors: &[&Self], dim: usize) -> Self { 14 | todo!(); 15 | } 16 | fn chunk(&self, chunks: usize, dim: usize) -> Vec<Self> { 17 | todo!(); 18 | } 19 | fn gather(&self, dim: usize, index: &Self) -> Self { 20 | todo!(); 21 | } 22 | fn index_select(&self, dim: usize, index: &Self) -> Self 23 | { 24 | todo!(); 25 | } 26 | // fn masked_select(); 27 | //pub fn narrow() {} 28 | //pub fn nonzero() {} 29 | fn reshape(&self, new_shape: &[usize]) -> Self { 30 | todo!(); 31 | } 32 | fn split(&self, sections: &[usize], dim: usize) -> Vec<Self> { 33 | todo!(); 34 | } 35 | fn squeeze(&self, dim: Option<usize>) -> Self { 36 | todo!(); 37 | } 38 | fn stack(tensors: &[&Self], dim: usize) -> Self { 39 | todo!(); 40 | } 41 | //pub fn t() {} 42 | fn take(&self, index: &[usize]) -> Self { 43 | todo!(); 44 | } 45 | //pub fn transpose() {} 46 | //pub fn unbind() {} 47 | fn permute(&self, dims: &[usize]) -> Self { 48 | todo!(); 49 | } 50 | fn unsqueeze(&self, dim: usize) -> Self { 51 | todo!(); 52 | } 53 | //pub fn condition() {} // this is pytorch where 54 | fn conditional_select(&self, x: &Self, y: &Self) -> Self { 55 | todo!(); 56 | } 57 | fn repeat(&self, sizes: &[usize]) -> Self { 58 | todo!(); 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/reduction.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "use-cuda")] 2 | use crate::tensor::cuda_tensor::CudaTensor; 3 | use crate::tensor_trait::reduction::ReduceTensor; 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | ////////////// 14 | // cuda tensor 15 | ////////////// 16 | #[cfg(feature = "use-cuda")] 17 | impl ReduceTensor for CudaTensor { 18 | 19 | fn argmax(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self { 20 | todo!(); 21 | } 22 | fn argmin(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self { 23 | todo!(); 24 | } 25 | fn dist() { 26 | todo!(); 27 | } 28 | fn logsumexp(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self { 29 | todo!(); 30 | } 31 | fn mean(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 32 | todo!(); 33 | } 34 | fn median() { 35 | todo!(); 36 | } 37 | fn mode() { 38 | todo!(); 39 | } 40 | fn prod(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 41 | todo!(); 42 | } 43 | fn std(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 44 | todo!(); 45 | } 46 | fn std_mean() { 47 | todo!(); 48 | } 49 | //fn sum(&self, dim: usize, keepdim: bool) -> Self::TensorType; 50 | fn sum(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 51 | todo!(); 52 | } 53 | fn unique() { 54 | todo!(); 55 | } 56 | fn unique_consecutive() { 57 | todo!(); 58 | } 59 | fn var(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 60 | todo!(); 61 | } 62 | fn var_mean() { 63 | todo!(); 64 | } 65 | 66 | fn max(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 67 | todo!(); 68 | } 69 | fn min(&self, dim: Option<&[usize]>, keepdim: bool) -> Self { 70 | todo!(); 71 | } 72 | } 73 |
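For contrast with these CUDA stubs, the same `ReduceTensor` surface can be exercised on the CPU side. A sketch, assuming `GenTensor` implements the trait with the signatures above (its reduction impl is not part of this excerpt):

```rust
use tensor_rs::tensor_impl::gen_tensor::GenTensor;
use tensor_rs::tensor_trait::reduction::ReduceTensor;

fn main() {
    // A 2x3 matrix; reduce along dimension 1 (the columns).
    let m = GenTensor::<f64>::new_raw(&[1., 2., 3., 4., 5., 6.], &[2, 3]);
    let row_sum = m.sum(Some(&[1]), false);   // expected values [6., 15.]
    let row_mean = m.mean(Some(&[1]), false); // expected values [2., 5.]
    let total = m.sum(None, false);           // reduce over every dimension
    println!("{:?} {:?} {:?}", row_sum, row_mean, total);
}
```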
-------------------------------------------------------------------------------- /benches/benches/convolution_benchmark.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput, BenchmarkId}; 2 | use std::iter; 3 | 4 | use tensor_rs::tensor_impl::gen_tensor::*; 5 | use tensor_rs::tensor_impl::lapack_tensor::*; 6 | use tensor_rs::tensor::PaddingMode; 7 | use tensor_rs::tensor_trait::index_slicing::IndexSlicing; 8 | use tensor_rs::tensor_trait::convolution::Convolution; 9 | use tensor_rs::tensor_impl::lapack_tensor::convolution::gemm_conv_f32; 10 | 11 | 12 | extern crate ndarray; 13 | extern crate ndarray_linalg; 14 | extern crate openblas_src; // or another backend of your choice 15 | 16 | //use ndarray; 17 | 18 | fn varying_input_size_benchmark(c: &mut Criterion) { 19 | let ss = [10, 15, 20, 25, 30, 35, 50, 70, 100]; 20 | 21 | let mut group = c.benchmark_group("varying_input_size"); 22 | for size in &ss { 23 | let data = GenTensor::<f32>::fill(1., &vec![*size, *size]).reshape(&[1, 1, *size, *size]); 24 | let filter = GenTensor::<f32>::arange(9).reshape(&vec![1, 1, 3, 3]); 25 | let stride = vec![1, 1]; 26 | let padding = vec![1, 1]; 27 | let dilation = vec![1, 1]; 28 | let padding_mode = PaddingMode::Zeros; 29 | group.bench_with_input(BenchmarkId::new("naive", size*size), size, |b, &size| { 30 | b.iter(|| { 31 | let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode); 32 | }); 33 | }); 34 | 35 | group.bench_with_input(BenchmarkId::new("dot_product", size*size), size, |b, &size| { 36 | b.iter(|| { 37 | let result = gemm_conv_f32(&data, &filter, &stride, &padding, &dilation, padding_mode); 38 | }); 39 | }); 40 | } 41 | group.finish(); 42 | 43 | } 44 | 45 | criterion_group!(benches, varying_input_size_benchmark); 46 | criterion_main!(benches); 47 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/convolution.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeMap; 2 | use super::gen_tensor::*; 3 | use crate::tensor::PaddingMode; 4 | 5 | #[cfg(feature = "use-cuda")] 6 | use crate::tensor::cuda_tensor::CudaTensor; 7 | use crate::tensor_trait::convolution::Convolution; 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | #[cfg(test)] 16 | mod tests { 17 | use crate::tensor::gen_tensor::GenTensor; 18 | use crate::tensor_trait::index_slicing::IndexSlicing; 19 | use super::*; 20 | 21 | 22 | 23 | 24 | 25 | } 26 | 27 | 28 | ////////////// 29 | // for cuda tensor 30 | ///////// 31 | #[cfg(feature = "use-cuda")] 32 | impl Convolution for CudaTensor { 33 | 34 | fn conv2d(&self, filter: &Self, 35 | stride: (usize, usize), 36 | padding: (usize, usize), 37 | dilation: (usize, usize), 38 | padding_mode: PaddingMode 39 | ) -> Self { 40 | todo!(); 41 | } 42 | 43 | fn conv2d_grad(&self, filter: &Self, 44 | stride: (usize, usize), 45 | padding: (usize, usize), 46 | dilation: (usize, usize), 47 | padding_mode: PaddingMode, 48 | output_grad: &Self 49 | ) -> (Self, Self) { 50 | todo!(); 51 | } 52 | 53 | fn conv_gen(&self, filter: &Self, 54 | stride: &[usize], 55 | padding: &[usize], 56 | dilation: &[usize], 57 | padding_mode: PaddingMode 58 | ) -> Self { 59 | todo!(); 60 | } 61 | 62 | fn conv_grad_gen(&self, filter: &Self, 63 | stride: &[usize], 64 | padding: &[usize], 65 | dilation: &[usize], 66 | padding_mode: PaddingMode, 67 | output_grad: &Self, 68 | ) -> (Self, Self) {
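// Presumably returns the pair (input gradient, filter gradient) for the
// n-d convolution, mirroring conv2d_grad above; not implemented for CUDA yet.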
69 | todo!(); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/linalg.rs: -------------------------------------------------------------------------------- 1 | // 2 | use crate::tensor_impl::gen_tensor::GenTensor; 3 | use crate::tensor_trait::index_slicing::IndexSlicing; 4 | #[cfg(feature = "use-blas-lapack")] 5 | use super::lapack_api::LapackAPI; 6 | use std::cmp; 7 | 8 | #[cfg(feature = "use-blas-lapack")] 9 | macro_rules! lapack_svd { 10 | ($a:ty, $b: ident) => { 11 | pub fn $b( 12 | x: &GenTensor<$a>, 13 | ) -> (GenTensor<$a>, GenTensor<$a>, GenTensor<$a>) { 14 | if x.size().len() != 2 { 15 | panic!("lapack_svd expects a 2d matrix."); 16 | } 17 | let n = x.size()[0]; 18 | let m = x.size()[1]; 19 | let mmn = cmp::min(m, n); 20 | 21 | let mut ma = x.get_data().clone(); 22 | let mut s: Vec<$a> = vec![0.; mmn]; 23 | let mut u: Vec<$a> = vec![0.; mmn*m]; 24 | let mut vt: Vec<$a> = vec![0.; mmn*n]; 25 | let mut info: i32 = 0; 26 | LapackAPI::<$a>::gesdd(&'S', m, n, 27 | &mut ma, m, 28 | &mut s, 29 | &mut u, m, 30 | &mut vt, mmn, 31 | &mut info); 32 | let ret_u = GenTensor::<$a>::new_move(vt, vec![n, mmn]); 33 | let ret_s = GenTensor::<$a>::new_move(s, vec![mmn]); 34 | let ret_v = GenTensor::<$a>::new_move(u, vec![mmn, m]).t(); 35 | if info != 0 { 36 | panic!("svd returned nonzero info!"); 37 | } 38 | (ret_u, ret_s, ret_v) 39 | } 40 | } 41 | } 42 | 43 | #[cfg(feature = "use-blas-lapack")] 44 | lapack_svd!(f32, svd_f32); 45 | 46 | #[cfg(feature = "use-blas-lapack")] 47 | lapack_svd!(f64, svd_f64); 48 | 49 | 50 | #[cfg(test)] 51 | mod tests { 52 | use crate::tensor_impl::gen_tensor::GenTensor; 53 | use super::*; 54 | 55 | #[test] 56 | #[cfg(feature = "use-blas-lapack")] 57 | fn test_svd() { 58 | let m = GenTensor::<f64>::new_raw(&[4., 12., -16., 12., 37., -43., -16., -43., 98.], &[3, 3]); 59 | let (u, s, v) = svd_f64(&m); 60 | println!("{:?}, {:?}, {:?}", u, s, v); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /bump_version.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cargo_file="Cargo.toml" 4 | 5 | function validate_semversion() { 6 | version=$1 7 | if [[ "${version}" =~ ^v.+$ ]]; then 8 | version="${version:1}" 9 | else 10 | echo "bad version: ${version}" 11 | exit 1 12 | fi 13 | 14 | if [[ "${version}" =~ ^(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)(-((0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*)(\.(0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*))*))?(\+([0-9a-zA-Z-]+(\.[0-9a-zA-Z-]+)*))?$ ]]; then 15 | echo "${version}" 16 | else 17 | echo "bad version: ${version}" 18 | exit 1 19 | fi 20 | } 21 | 22 | function bump_version_usage() { 23 | echo 24 | echo "Usage: bump_version.sh [OPTIONS] VERSION_OLD VERSION_NEW" 25 | echo 26 | echo "bump the cargo version through the package." 27 | echo 28 | echo "Options:" 29 | echo " -v Verbose output, prints errors and echos the raw version on success" 30 | echo " -t Run tests for this script" 31 | echo " -h Print usage" 32 | echo 33 | echo "Run like: bump_version.sh -v v0.5.5 v0.5.6" 34 | echo 35 | echo 36 | } 37 | 38 | function bump_version() { 39 | test=0 40 | verbose=0 41 | 42 | while getopts ":vt" opt; do 43 | case $opt in 44 | t) test=1 45 | ;; 46 | v) verbose=1 47 | ;; 48 | \?)
echo "Invalid option -$OPTARG" >&2; echo; bump_version_usage; exit 1 49 | ;; 50 | esac 51 | done 52 | 53 | shift $(($OPTIND - 1)) 54 | version_old=$1 55 | version_new=$2 56 | 57 | semver_old=$(validate_semversion "${version_old}") 58 | semver_new=$(validate_semversion "${version_new}") 59 | 60 | line_old="version = \"${semver_old}\"" 61 | line_new="version = \"${semver_new}\"" 62 | 63 | find . -type f -name "${cargo_file}" -exec perl -pi \ 64 | -e "s|${line_old}|${line_new}|g;" \ 65 | {} + 66 | 67 | echo "bump version from ${semver_old} to ${semver_new} ..." 68 | 69 | } 70 | 71 | bump_version "$@" 72 | -------------------------------------------------------------------------------- /benches/benches/elemwise_benchmark.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput, BenchmarkId}; 2 | use std::iter; 3 | 4 | use auto_diff::Var; 5 | use tensor_rs::tensor_impl::gen_tensor::GenTensor; 6 | use tensor_rs::tensor::Tensor; 7 | 8 | extern crate ndarray; 9 | extern crate ndarray_linalg; 10 | extern crate openblas_src; // or another backend of your choice 11 | 12 | //use ndarray; 13 | 14 | fn elemwise_benchmark(c: &mut Criterion) { 15 | let ss = vec![10, 20, 30, 50,]; 16 | 17 | let mut group = c.benchmark_group("elemwise"); 18 | for size in ss.iter() { 19 | 20 | let m1 = Var::fill_f64(&vec![*size, *size], 1.); 21 | let m2 = Var::fill_f64(&vec![*size, *size], 2.); 22 | group.bench_with_input(BenchmarkId::new("var", size*size), size, |b, &size| { 23 | b.iter(|| { 24 | let tmp = m1.sub(&m2).unwrap(); 25 | let tmp2 = tmp.mul(&tmp).unwrap(); 26 | }); 27 | }); 28 | 29 | let m1 = GenTensor::::fill(1., &vec![*size, *size]); 30 | let m2 = GenTensor::::fill(2., &vec![*size, *size]); 31 | group.bench_with_input(BenchmarkId::new("gentensor", size*size), size, |b, &size| { 32 | b.iter(|| { 33 | let m_result = GenTensor::::squared_error(&m1, &m2); 34 | }); 35 | }); 36 | 37 | let m1 = Tensor::fill_f64(&vec![*size, *size], 1.); 38 | let m2 = Tensor::fill_f64(&vec![*size, *size], 2.); 39 | group.bench_with_input(BenchmarkId::new("tensor", size*size), size, |b, &size| { 40 | b.iter(|| { 41 | let tmp = m1.sub(&m2); 42 | let tmp2 = tmp.mul(&tmp); 43 | }); 44 | }); 45 | 46 | let md1 = &ndarray::Array2::::zeros(((*size) as usize, (*size) as usize)); 47 | let md2 = &ndarray::Array2::::ones(((*size) as usize, (*size) as usize)); 48 | group.bench_with_input(BenchmarkId::new("ndarray", size*size), size, |b, &size| { 49 | b.iter(|| { 50 | let tmp = md1 - md2; 51 | let tmp2 = &tmp*&tmp; 52 | }); 53 | }); 54 | } 55 | group.finish(); 56 | 57 | } 58 | 59 | criterion_group!(benches, elemwise_benchmark); 60 | criterion_main!(benches); 61 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/gen_tensor/compare_tensor.rs: -------------------------------------------------------------------------------- 1 | use super::GenTensor; 2 | use crate::tensor_trait::compare_tensor::CompareTensor; 3 | 4 | impl CompareTensor for GenTensor where T: num_traits::Float { 5 | type TensorType = GenTensor; 6 | type ElementType = T; 7 | 8 | fn max_pair(&self, o: &GenTensor) -> GenTensor { 9 | if self.size() != o.size() { 10 | panic!("max needs two tensor have the same size, {:?}, {:?}", self.size(), o.size()); 11 | } 12 | let mut ret = GenTensor::zeros(self.size()); 13 | 14 | for ((a, b), c) in self.get_data().iter().zip(o.get_data().iter()).zip(ret.get_data_mut().iter_mut()) { 15 
| if a >= b { 16 | *c = *a; 17 | } else { 18 | *c = *b; 19 | } 20 | } 21 | ret 22 | } 23 | // min, 24 | fn min_pair(&self, o: &GenTensor<T>) -> GenTensor<T> { 25 | if self.size() != o.size() { 26 | panic!("min_pair needs two tensors of the same size, {:?}, {:?}", self.size(), o.size()); 27 | } 28 | let mut ret = GenTensor::zeros(self.size()); 29 | 30 | for ((a, b), c) in self.get_data().iter().zip(o.get_data().iter()).zip(ret.get_data_mut().iter_mut()) { 31 | if a >= b { 32 | *c = *b; 33 | } else { 34 | *c = *a; 35 | } 36 | } 37 | ret 38 | } 39 | 40 | fn all(&self, f: &dyn Fn(Self::ElementType) -> bool) -> bool { 41 | self.get_data().iter().all(|x| f(*x)) 42 | } 43 | fn any(&self, f: &dyn Fn(Self::ElementType) -> bool) -> bool { 44 | self.get_data().iter().any(|x| f(*x)) 45 | } 46 | } 47 | 48 | #[cfg(test)] 49 | mod tests { 50 | use crate::tensor_impl::gen_tensor::GenTensor; 51 | use super::*; 52 | 53 | #[test] 54 | fn max_pair() { 55 | let a = GenTensor::<f32>::new_raw(&vec![1., 3., 10., 11.], &vec![2,2]); 56 | let b = GenTensor::<f32>::new_raw(&vec![2., 4., 5., 6.], &vec![2,2]); 57 | let c = a.max_pair(&b); 58 | assert_eq!(c, GenTensor::<f32>::new_raw(&vec![2., 4., 10., 11.], &vec![2,2])); 59 | } 60 | 61 | #[test] 62 | fn min_pair() { 63 | let a = GenTensor::<f32>::new_raw(&vec![1., 3., 10., 11.], &vec![2,2]); 64 | let b = GenTensor::<f32>::new_raw(&vec![2., 4., 5., 6.], &vec![2,2]); 65 | let c = a.min_pair(&b); 66 | assert_eq!(c, GenTensor::<f32>::new_raw(&vec![1., 3., 5., 6.], &vec![2,2])); 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod compare_tensor; 2 | pub mod convolution; 3 | pub mod elemwise; 4 | pub mod index_slicing; 5 | pub mod linalg; 6 | pub mod reduction; 7 | pub mod blas_api; 8 | pub mod lapack_api; 9 | 10 | use crate::tensor_impl::gen_tensor::GenTensor; 11 | use crate::tensor_impl::lapack_tensor::blas_api::BlasAPI; 12 | 13 | macro_rules! blas_matmul { 14 | ($a:ty, $b: ident) => { 15 | pub fn $b( 16 | x: &GenTensor<$a>, 17 | y: &GenTensor<$a>, 18 | ) -> GenTensor<$a> { 19 | if x.size()[x.size().len()-1] != y.size()[0] { 20 | panic!("matmul expects matched sizes {:?}, {:?}", x.size(), y.size()); 21 | } 22 | if x.size().len() == 1 && y.size().len() == 1 { 23 | panic!("Two vectors have unmatched sizes for matmul!
{:?}, {:?}", x.numel(), y.numel()); 24 | } 25 | let inner = y.size()[0]; 26 | let mut cap = 1; 27 | let mut odim = Vec::new(); 28 | let mut lloop = 1; 29 | let mut rloop = 1; 30 | for i in 0..x.size().len()-1 { 31 | cap *= x.size()[i]; 32 | odim.push(x.size()[i]); 33 | lloop *= x.size()[i]; 34 | } 35 | for i in 1..y.size().len() { 36 | cap *= y.size()[i]; 37 | odim.push(y.size()[i]); 38 | rloop *= y.size()[i]; 39 | } 40 | 41 | let mut ret = GenTensor::<$a>::new_move( 42 | vec![0.; cap], odim); 43 | 44 | BlasAPI::<$a>::gemm(false, false, 45 | rloop, lloop, inner, 46 | 1., y.get_data(), rloop, 47 | x.get_data(), inner, 48 | 1., ret.get_data_mut(), rloop,); 49 | ret 50 | } 51 | } 52 | } 53 | 54 | blas_matmul!(f32, matmul_f32); 55 | blas_matmul!(f64, matmul_f64); 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | use crate::tensor_impl::gen_tensor::GenTensor; 60 | use super::*; 61 | 62 | #[test] 63 | fn test_matmul() { 64 | let v1 = GenTensor::::new_raw(&[1., 2., 3., 4., 5., 6.], &[2, 3]); 65 | let v2 = GenTensor::::new_raw(&[11., 12., 13., 14., 15., 16., 17., 18., 19.], &[3, 3]); 66 | let v3 = matmul_f32(&v1, &v2); 67 | let em = GenTensor::::new_raw(&[90.0, 96.0, 102.0, 216.0, 231.0, 246.0], &[2, 3]); 68 | assert_eq!(v3, em); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /auto-diff/README.md: -------------------------------------------------------------------------------- 1 | # A simple machine learning toolset 2 | 3 | [![crates.io version](https://img.shields.io/crates/v/auto-diff.svg)](https://crates.io/crates/auto-diff) 4 | [![License](https://img.shields.io/crates/l/auto-diff.svg)](https://github.com/pipehappy1/auto-diff/blob/master/LICENSE.txt) 5 | [![example workflow](https://github.com/pipehappy1/auto-diff/actions/workflows/rust.yml/badge.svg)](https://github.com/pipehappy1/auto-diff/actions) 6 | [![doc badge](https://docs.rs/auto-diff/badge.svg)](https://docs.rs/auto-diff) 7 | 8 | ## Introduction 9 | 10 | This is an auto-difference based learning library. 11 | 12 | ## Features 13 | 14 | - A type-less tensor. 15 | - Variable over tensor with support for back propagation. 16 | - Support for common operators, including convolution. 
17 | 18 | ## Example 19 | 20 | ```rust,no_run 21 | use tensor_rs::tensor::Tensor; 22 | use auto_diff::rand::RNG; 23 | use auto_diff::var::{Module}; 24 | use auto_diff::optim::{SGD, Optimizer}; 25 | 26 | fn main() { 27 | 28 | fn func(input: &Tensor) -> Tensor { 29 | input.matmul(&Tensor::from_vec_f32(&vec![2., 3.], &vec![2, 1])).add(&Tensor::from_vec_f32(&vec![1.], &vec![1])) 30 | } 31 | 32 | let N = 100; 33 | let mut rng = RNG::new(); 34 | rng.set_seed(123); 35 | let data = rng.normal(&vec![N, 2], 0., 2.); 36 | let label = func(&data); 37 | 38 | 39 | let mut m = Module::new(); 40 | 41 | let op1 = m.linear(Some(2), Some(1), true); 42 | let weights = op1.get_values().unwrap(); 43 | rng.normal_(&weights[0], 0., 1.); 44 | rng.normal_(&weights[1], 0., 1.); 45 | op1.set_values(&weights); 46 | 47 | let op2 = op1.clone(); 48 | let block = m.func( 49 | move |x| { 50 | op2.call(x) 51 | } 52 | ); 53 | 54 | let loss_func = m.mse_loss(); 55 | 56 | let mut opt = SGD::new(3.); 57 | 58 | for i in 0..200 { 59 | let input = m.var_value(data.clone()); 60 | 61 | let y = block.call(&[&input]); 62 | 63 | let loss = loss_func.call(&[&y, &m.var_value(label.clone())]); 64 | println!("index: {}, loss: {}", i, loss.get().get_scale_f32()); 65 | 66 | loss.backward(-1.); 67 | opt.step2(&block); 68 | 69 | } 70 | 71 | let weights = op1.get_values().expect(""); 72 | println!("{:?}, {:?}", weights[0], weights[1]); 73 | } 74 | ``` 75 | 76 | 77 | ## Dependence 78 | 79 | Install gfortran if openblas-src = "0.9" is used. 80 | 81 | ## Contributing 82 | 83 | Any contribution is welcome. Please open an issue or create a pull request. 84 | 85 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/gen_tensor/rand.rs: -------------------------------------------------------------------------------- 1 | use super::GenTensor; 2 | use crate::tensor_trait::rand::Random; 3 | 4 | use rand::prelude::*; 5 | use rand::Rng; 6 | use rand_distr::{Normal, Uniform, Distribution, StandardNormal}; 7 | 8 | impl<T> Random for GenTensor<T> 9 | where T: num_traits::Float + rand_distr::uniform::SampleUniform, 10 | StandardNormal: Distribution<T> { 11 | type TensorType = GenTensor<T>; 12 | type ElementType = T; 13 | 14 | fn rand_usize(rng: &mut StdRng, 15 | dim: &[usize], 16 | left: usize, right: usize) -> Self::TensorType { 17 | let elem = dim.iter().product(); 18 | 19 | let mut dta = Vec::<T>::with_capacity(elem); 20 | for _i in 0..elem { 21 | let v: usize = rng.gen_range(left..right); 22 | dta.push(T::from(v).unwrap()); 23 | } 24 | GenTensor::new_raw(&dta, dim) 25 | } 26 | 27 | fn bernoulli() -> Self::TensorType { 28 | unimplemented!(); 29 | } 30 | fn cauchy() -> Self::TensorType { 31 | unimplemented!(); 32 | } 33 | fn exponential() -> Self::TensorType { 34 | unimplemented!(); 35 | } 36 | fn geometric() -> Self::TensorType { 37 | unimplemented!(); 38 | } 39 | fn log_normal() -> Self::TensorType { 40 | unimplemented!(); 41 | } 42 | fn normal(rng: &mut StdRng, 43 | dim: &[usize], 44 | mean: Self::ElementType, 45 | std: Self::ElementType) -> Self::TensorType { 46 | let elem = dim.iter().product(); 47 | 48 | let mut dta = Vec::<T>::with_capacity(elem); 49 | let normal = Normal::<T>::new(mean, std).expect(""); 50 | for _i in 0..elem { 51 | dta.push(normal.sample(rng)); 52 | } 53 | GenTensor::new_raw(&dta, dim) 54 | } 55 | fn uniform(rng: &mut StdRng, 56 | dim: &[usize], 57 | from: Self::ElementType, 58 | to: Self::ElementType) -> Self::TensorType { 59 | let elem: usize = dim.iter().product(); 60 | 61 | let mut dta =
Vec::<T>::with_capacity(elem); 62 | let normal = Uniform::<T>::new(from, to); 63 | for _i in 0..elem { 64 | dta.push(normal.sample(rng)); 65 | } 66 | GenTensor::new_raw(&dta, dim) 67 | } 68 | } 69 | 70 | 71 | #[cfg(test)] 72 | mod tests { 73 | use super::*; 74 | use crate::tensor_trait::compare_tensor::CompareTensor; 75 | 76 | #[test] 77 | fn normalize_unit() { 78 | let mut rng = StdRng::seed_from_u64(671); 79 | let m = GenTensor::<f32>::uniform(&mut rng, &[2,2], 0., 10.); 80 | assert!(GenTensor::<f32>::fill(10., &[2,2]).sub(&m).all(&|x| x > 0. && x < 10.)); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /ann/src/minibatch.rs: -------------------------------------------------------------------------------- 1 | use ::rand::prelude::StdRng; 2 | use auto_diff::{Var, AutoDiffError}; 3 | use auto_diff_data_pipe::dataloader::{DataLoader, DataSlice}; 4 | 5 | pub struct MiniBatch { 6 | rng: StdRng, 7 | size: usize, 8 | } 9 | impl MiniBatch { 10 | pub fn new(rng: StdRng, size: usize) -> MiniBatch { 11 | MiniBatch { 12 | rng, 13 | size, 14 | } 15 | } 16 | 17 | pub fn batch_size(&self) -> usize { 18 | self.size 19 | } 20 | 21 | /// Get a random set of samples from the data loader. 22 | pub fn next(&mut self, loader: &dyn DataLoader, part: &DataSlice) -> Result<(Var, Var), AutoDiffError> { 23 | let sample_size = loader.get_size(Some(*part))?[0]; 24 | let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size); 25 | loader.get_indexed_batch(&(Vec::<usize>::try_from(index_t)?), Some(*part)) 26 | } 27 | /// Get a random set of samples given the data and label. 28 | pub fn next_data_slice(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> { 29 | let sample_size = data.size()[0]; 30 | let sample_size2 = label.size()[0]; 31 | 32 | if sample_size != sample_size2 { 33 | return Err(AutoDiffError::new(&format!("minibatch needs data and label has the same N {}, {}", 34 | sample_size, sample_size2))); 35 | } 36 | let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size); 37 | 38 | let mdata = data.index_select(0, index_t.clone())?; 39 | let mlabel = label.index_select(0, index_t)?; 40 | mdata.reset_net(); 41 | mlabel.reset_net(); 42 | Ok((mdata, mlabel)) 43 | } 44 | 45 | pub fn iter_block<'a>(&self, loader: &'a dyn DataLoader, part: &DataSlice) -> Result<BlockIterator<'a>, AutoDiffError> { 46 | Ok(BlockIterator { 47 | loader, 48 | part: *part, 49 | block_size: self.size, 50 | block_index: 0, 51 | }) 52 | } 53 | } 54 | 55 | pub struct BlockIterator<'a> { 56 | loader: &'a dyn DataLoader, 57 | part: DataSlice, 58 | block_size: usize, 59 | block_index: usize, 60 | } 61 | impl<'a> Iterator for BlockIterator<'a> { 62 | type Item = (Var, Var); 63 | fn next(&mut self) -> Option<Self::Item> { 64 | let n = if let Ok(size) = self.loader.get_size(Some(self.part)) { 65 | size[0] 66 | } else { 67 | return None; 68 | }; 69 | 70 | if self.block_index >= n { 71 | return None; 72 | } 73 | let mut end_index = self.block_index + self.block_size; 74 | if end_index > n { 75 | end_index = n; 76 | } 77 | 78 | let result = self.loader.get_batch(self.block_index, 79 | end_index, 80 | Some(self.part)); 81 | self.block_index += self.block_size; 82 | result.ok() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A simple machine learning toolset 2 | 3 | [![crates.io
version](https://img.shields.io/crates/v/auto-diff.svg)](https://crates.io/crates/auto-diff) 4 | [![License](https://img.shields.io/crates/l/auto-diff.svg)](https://github.com/pipehappy1/auto-diff/blob/master/LICENSE.txt) 5 | [![example workflow](https://github.com/pipehappy1/auto-diff/actions/workflows/rust.yml/badge.svg)](https://github.com/pipehappy1/auto-diff/actions) 6 | [![doc badge](https://docs.rs/auto-diff/badge.svg)](https://docs.rs/auto-diff) 7 | 8 | 9 | ## Introduction 10 | 11 | This is an auto-difference library for deep neural networks. 12 | 13 | Try **auto-diff** by adding the following to your Cargo.toml: 14 | ``` 15 | [dependencies] 16 | auto-diff = "0.5" 17 | ``` 18 | 19 | ## Features 20 | 21 | - A typeless tensor. 22 | - Variable over tensor with support for back propagation. 23 | - Support for common operators, including convolution. 24 | 25 | ## Example 26 | 27 | ```rust,no_run 28 | use tensor_rs::tensor::Tensor; 29 | use auto_diff::rand::RNG; 30 | use auto_diff::var::{Module}; 31 | use auto_diff::optim::{SGD, Optimizer}; 32 | 33 | fn main() { 34 | 35 | fn func(input: &Tensor) -> Tensor { 36 | input.matmul(&Tensor::from_vec_f32(&vec![2., 3.], &vec![2, 1])).add(&Tensor::from_vec_f32(&vec![1.], &vec![1])) 37 | } 38 | 39 | let N = 100; 40 | let mut rng = RNG::new(); 41 | rng.set_seed(123); 42 | let data = rng.normal(&vec![N, 2], 0., 2.); 43 | let label = func(&data); 44 | 45 | 46 | let mut m = Module::new(); 47 | 48 | let op1 = m.linear(Some(2), Some(1), true); 49 | let weights = op1.get_values().unwrap(); 50 | rng.normal_(&weights[0], 0., 1.); 51 | rng.normal_(&weights[1], 0., 1.); 52 | op1.set_values(&weights); 53 | 54 | let op2 = op1.clone(); 55 | let block = m.func( 56 | move |x| { 57 | op2.call(x) 58 | } 59 | ); 60 | 61 | let loss_func = m.mse_loss(); 62 | 63 | let mut opt = SGD::new(3.); 64 | 65 | for i in 0..200 { 66 | let input = m.var_value(data.clone()); 67 | 68 | let y = block.call(&[&input]); 69 | 70 | let loss = loss_func.call(&[&y, &m.var_value(label.clone())]); 71 | println!("index: {}, loss: {}", i, loss.get().get_scale_f32()); 72 | 73 | loss.backward(-1.); 74 | opt.step2(&block); 75 | 76 | } 77 | 78 | let weights = op1.get_values().expect(""); 79 | println!("{:?}, {:?}", weights[0], weights[1]); 80 | } 81 | ``` 82 | 83 | ## TODO 84 | 85 | - Use cudnn and cutensor 86 | - Stride based tensor 87 | - Block components inspection by func call 88 | - serde 89 | 90 | ## Dependence 91 | 92 | - Install gfortran if openblas-src = "0.9" is used. 93 | - To use Rust's bindgen feature (e.g., for CUDA) on Ubuntu, run apt install llvm-dev libclang-dev clang. 94 | 95 | ## Contributing 96 | 97 | Any contribution is welcome. Please open an issue or create a pull request.
98 | -------------------------------------------------------------------------------- /ann/examples/mlp_on_mnist.rs: -------------------------------------------------------------------------------- 1 | use auto_diff::op::{Linear, OpCall}; 2 | use auto_diff::optim::{SGD}; 3 | use auto_diff_ann::minibatch::MiniBatch; 4 | //use auto_diff::Var; 5 | use auto_diff_ann::init::normal; 6 | use auto_diff_data_pipe::dataloader::{mnist::Mnist, DataSlice}; 7 | use tensorboard_rs::summary_writer::SummaryWriter; 8 | use std::path::Path; 9 | use rand::prelude::*; 10 | use ::rand::prelude::StdRng; 11 | use auto_diff_data_pipe::dataloader::DataLoader; 12 | use std::fs; 13 | 14 | extern crate openblas_src; 15 | 16 | 17 | fn main() { 18 | 19 | let mut rng = StdRng::seed_from_u64(671); 20 | 21 | let mnist = Mnist::load(&Path::new("../auto-diff/examples/data/mnist")); 22 | 23 | let train_size = mnist.get_size(Some(DataSlice::Train)).unwrap(); 24 | let h = train_size[1]; 25 | let w = train_size[2]; 26 | 27 | // init 28 | let mut op1 = Linear::new(Some(h*w), Some(120), true); 29 | normal(op1.weight(), None, None, &mut rng).unwrap(); 30 | normal(op1.bias(), None, None, &mut rng).unwrap(); 31 | 32 | let mut op2 = Linear::new(Some(120), Some(84), true); 33 | normal(op2.weight(), None, None, &mut rng).unwrap(); 34 | normal(op2.bias(), None, None, &mut rng).unwrap(); 35 | 36 | let mut op3 = Linear::new(Some(84), Some(10), true); 37 | normal(op3.weight(), None, None, &mut rng).unwrap(); 38 | normal(op3.bias(), None, None, &mut rng).unwrap(); 39 | 40 | 41 | let mut minibatch = MiniBatch::new(rng, 16); 42 | let mut writer = SummaryWriter::new(&("./logdir".to_string())); 43 | 44 | // get data 45 | let (input, label) = minibatch.next(&mnist, &DataSlice::Train).unwrap(); 46 | let input = input.reshape(&[16, h*w]).unwrap(); 47 | input.reset_net(); 48 | 49 | // the network 50 | let output1 = op1.call(&[&input]).unwrap().pop().unwrap(); 51 | let output2 = output1.relu().unwrap(); 52 | let output3 = op2.call(&[&output2]).unwrap().pop().unwrap(); 53 | let output4 = output3.relu().unwrap(); 54 | let output = op3.call(&[&output4]).unwrap().pop().unwrap(); 55 | 56 | // label the predict var. 
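// (After the whole graph is serialized below, a later run can recover this
// output through loss.predict() -- see ann/examples/mlp_mnist_acc_on_test.rs.)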
57 | output.set_predict().unwrap(); 58 | 59 | let loss = output.cross_entropy_loss(&label).unwrap(); 60 | 61 | let lr = 0.001; 62 | let mut opt = SGD::new(lr); 63 | 64 | for i in 0..100000 { 65 | let (input_next, label_next) = minibatch.next(&mnist, &DataSlice::Train).unwrap(); 66 | let input_next = input_next.reshape(&[16, h*w]).unwrap(); 67 | input_next.reset_net(); 68 | 69 | // set data and label 70 | input.set(&input_next); 71 | label.set(&label_next); 72 | 73 | loss.rerun().unwrap(); 74 | loss.bp().unwrap(); 75 | loss.step(&mut opt).unwrap(); 76 | 77 | if i % 1000 == 0 { 78 | println!("i: {:?}, loss: {:?}", i, loss); 79 | writer.add_scalar(&"mlp_mnist/train_loss".to_string(), f64::try_from(loss.clone()).unwrap() as f32, i); 80 | 81 | let encoded: Vec<u8> = bincode::serialize(&loss).unwrap(); 82 | fs::write(format!("saved_model/net_{}", i), encoded).expect("Unable to write file"); 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /auto-diff/src/op/local.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::redundant_closure_call)] 2 | use super::macros::new_binary_op; 3 | use super::{OpHandle, OpTrait}; 4 | use tensor_rs::tensor::Tensor; 5 | 6 | #[cfg(feature = "use-serde")] 7 | use serde::{Deserialize, Serialize}; 8 | #[cfg(feature = "use-serde")] 9 | use std::any::Any; 10 | 11 | new_binary_op!( 12 | Add, 13 | "Add", 14 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].add(&a[1]))), 15 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 16 | let x = input[0].ones_like().mul(&output_grad[0]); 17 | let y = input[1].ones_like().mul(&output_grad[0]); 18 | input_grad[0].swap(&x); 19 | input_grad[1].swap(&y); 20 | }) 21 | ); 22 | new_binary_op!( 23 | Sub, 24 | "Sub", 25 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].sub(&a[1]))), 26 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 27 | let x = input[0].ones_like().mul(&output_grad[0]); 28 | let y = input[1].ones_like().neg().mul(&output_grad[0]); 29 | input_grad[0].swap(&x); 30 | input_grad[1].swap(&y); 31 | }) 32 | ); 33 | new_binary_op!( 34 | Mul, 35 | "Mul", 36 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].mul(&a[1]))), 37 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 38 | let x = input[1].mul(&output_grad[0]); 39 | let y = input[0].mul(&output_grad[0]); 40 | input_grad[0].swap(&x); 41 | input_grad[1].swap(&y); 42 | }) 43 | ); 44 | new_binary_op!( 45 | Div, 46 | "Div", 47 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].div(&a[1]))), 48 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 49 | let x = input[1].reciprocal().mul(&output_grad[0]); 50 | let y = input[0] 51 | .neg() 52 | .div(&input[1]) 53 | .div(&input[1]) 54 | .mul(&output_grad[0]); 55 | input_grad[0].swap(&x); 56 | input_grad[1].swap(&y); 57 | }) 58 | ); 59 | 60 | new_binary_op!( 61 | Matmul, 62 | "Matmul", 63 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].matmul(&a[1]))), 64 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 65 | input_grad[0].swap(&input[1].outer(&output_grad[0], Some(true))); 66 | input_grad[1].swap(&input[0].outer(&output_grad[0], Some(true))); 67 | }) 68 | ); 69 | 70 | new_binary_op!( 71 | Outer, 72 | "Outer", 73 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].outer(&a[1], None))), 74 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 75 | unimplemented!(); 76 | }) 77 | ); 78 | 79 | #[cfg(test)] 80 | mod tests { 81 | use
super::*; 82 | use crate::op::_gradient_checker; 83 | 84 | #[test] 85 | fn mul() { 86 | let mut op = Mul::new(); 87 | 88 | for i in 0..10 { 89 | let zero = Tensor::from_vec_f64(&vec![(i - 5) as f64], &vec![1]); 90 | let zero2 = zero.clone(); 91 | let good_grad = _gradient_checker(&mut op, &[zero, zero2], None, None, None); 92 | assert_eq!(good_grad, true); 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /auto-diff/src/op/pooling.rs: -------------------------------------------------------------------------------- 1 | use super::{OpHandle, OpTrait}; 2 | use tensor_rs::tensor::Tensor; 3 | 4 | #[cfg(feature = "use-serde")] 5 | use serde::{Deserialize, Serialize}; 6 | #[cfg(feature = "use-serde")] 7 | use std::any::Any; 8 | 9 | // MaxPool1d 10 | // Maxpool2d 11 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 12 | pub struct MaxPool2d { 13 | #[cfg_attr(feature = "use-serde", serde(skip))] 14 | handle: OpHandle, 15 | kernel_size: (usize, usize), 16 | stride: (usize, usize), 17 | padding: Tensor, 18 | dilation: (usize, usize), 19 | return_indices: bool, 20 | ceil_mode: bool, 21 | } 22 | impl MaxPool2d { 23 | pub fn new( 24 | kernel_size: Option<(usize, usize)>, 25 | stride: Option<(usize, usize)>, 26 | padding: Option<Tensor>, 27 | dilation: Option<(usize, usize)>, 28 | return_indices: Option<bool>, 29 | ceil_mode: Option<bool>, 30 | ) -> MaxPool2d { 31 | let kernel_size = if let Some(v) = kernel_size { v } else { (2, 2) }; 32 | let stride = if let Some(v) = stride { v } else { (2, 2) }; 33 | let padding = if let Some(v) = padding { 34 | v 35 | } else { 36 | Tensor::zeros(&[1]) 37 | }; 38 | let dilation = if let Some(v) = dilation { v } else { (2, 2) }; 39 | let return_indices = if let Some(v) = return_indices { 40 | v 41 | } else { 42 | false 43 | }; 44 | let ceil_mode = if let Some(v) = ceil_mode { v } else { false }; 45 | MaxPool2d { 46 | handle: OpHandle::new(), 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | return_indices, 52 | ceil_mode, 53 | } 54 | } 55 | fn get_handle(&self) -> &OpHandle { 56 | &self.handle 57 | } 58 | fn get_handle_mut(&mut self) -> &mut OpHandle { 59 | &mut self.handle 60 | } 61 | } 62 | impl OpTrait for MaxPool2d { 63 | fn get_name(&self) -> &'static str { 64 | "MaxPool2d" 65 | } 66 | fn get_input_size(&self) -> usize { 67 | 1 68 | } 69 | fn get_output_size(&self) -> usize { 70 | 1 71 | } 72 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 73 | unimplemented!(); 74 | } 75 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 76 | unimplemented!(); 77 | } 78 | fn get_values(&self) -> Vec<Tensor> { 79 | Vec::new() 80 | } 81 | fn get_grads(&self) -> Vec<Tensor> { 82 | Vec::new() 83 | } 84 | fn set_values(&self, _v: &[Tensor]) {} 85 | #[cfg(feature = "use-serde")] 86 | fn as_any(&self) -> &dyn Any { 87 | self 88 | } 89 | } 90 | 91 | // MaxPool3d 92 | // MaxUnpool1d 93 | // MaxUnpool2d 94 | // MaxUnpool3d 95 | // AvgPool1d 96 | // AvgPool2d 97 | // AvgPool3d 98 | // FractionalMaxPool2d 99 | // LPPool1d 100 | // LPPool2d 101 | // AdaptiveMaxPool1d 102 | // AdaptiveMaxPool2d 103 | // AdaptiveMaxPool3d 104 | // AdaptiveAvgPool1d 105 | // AdaptiveAvgPool2d 106 | // AdaptiveAvgPool3d 107 | // 108 | -------------------------------------------------------------------------------- /tensorboard-rs/src/summary.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::too_many_arguments)] 2 | use tensorboard_proto::summary::{Summary,
Summary_Value, Summary_Image, SummaryMetadata, SummaryMetadata_PluginData, HistogramProto}; 3 | use tensorboard_proto::layout::{Layout, Category}; 4 | use protobuf::RepeatedField; 5 | 6 | use image::{RgbImage, DynamicImage, ImageOutputFormat}; 7 | 8 | pub fn scalar(name: &str, scalar_value: f32) -> Summary { 9 | 10 | let mut value = Summary_Value::new(); 11 | value.set_tag(name.to_string()); 12 | value.set_simple_value(scalar_value); 13 | 14 | let values = RepeatedField::from(vec![value]); 15 | let mut summary = Summary::new(); 16 | summary.set_value(values); 17 | 18 | summary 19 | } 20 | 21 | pub fn histogram_raw(name: &str, 22 | min: f64, max: f64, 23 | num: f64, 24 | sum: f64, sum_squares: f64, 25 | bucket_limits: &[f64], 26 | bucket_counts: &[f64], 27 | ) -> Summary { 28 | let mut hist = HistogramProto::new(); 29 | hist.set_min(min); 30 | hist.set_max(max); 31 | hist.set_num(num); 32 | hist.set_sum(sum); 33 | hist.set_sum_squares(sum_squares); 34 | hist.set_bucket_limit(bucket_limits.to_vec()); 35 | hist.set_bucket(bucket_counts.to_vec()); 36 | 37 | let mut value = Summary_Value::new(); 38 | value.set_tag(name.to_string()); 39 | value.set_histo(hist); 40 | 41 | let values = RepeatedField::from(vec![value]); 42 | let mut summary = Summary::new(); 43 | summary.set_value(values); 44 | 45 | summary 46 | } 47 | 48 | /// dim is in CHW 49 | pub fn image(tag: &str, data: &[u8], dim: &[usize]) -> Summary { 50 | if dim.len() != 3 { 51 | panic!("format:CHW"); 52 | } 53 | if dim[0] != 3 { 54 | panic!("needs rgb"); 55 | } 56 | if data.len() != dim[0]*dim[1]*dim[2] { 57 | panic!("length of data should matches with dim."); 58 | } 59 | 60 | let mut img = RgbImage::new(dim[1] as u32, dim[2] as u32); 61 | img.clone_from_slice(data); 62 | let dimg = DynamicImage::ImageRgb8(img); 63 | let mut output_buf = Vec::<u8>::new(); 64 | dimg.write_to(&mut output_buf, ImageOutputFormat::Png).expect(""); 65 | 66 | let mut output_image = Summary_Image::new(); 67 | output_image.set_height(dim[1] as i32); 68 | output_image.set_width(dim[2] as i32); 69 | output_image.set_colorspace(3); 70 | output_image.set_encoded_image_string(output_buf); 71 | let mut value = Summary_Value::new(); 72 | value.set_tag(tag.to_string()); 73 | value.set_image(output_image); 74 | let values = RepeatedField::from(vec![value]); 75 | let mut summary = Summary::new(); 76 | summary.set_value(values); 77 | 78 | summary 79 | } 80 | 81 | pub fn custom_scalars(_layout: f32) { 82 | let mut layout = Layout::new(); 83 | let value = Category::new(); 84 | let values = RepeatedField::from(vec![value]); 85 | layout.set_category(values); 86 | 87 | let mut plugin_data = SummaryMetadata_PluginData::new(); 88 | plugin_data.set_plugin_name("custom_scalars".to_string()); 89 | let mut smd = SummaryMetadata::new(); 90 | smd.set_plugin_data(plugin_data); 91 | 92 | 93 | } 94 | -------------------------------------------------------------------------------- /auto-diff/src/optim.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Gradient based optimization. 3 | //! 4 | use super::compute_graph::Net; 5 | use crate::err::AutoDiffError; 6 | use crate::var::Var; 7 | use rand::prelude::StdRng; 8 | use std::cell::RefCell; 9 | use std::rc::Rc; 10 | use tensor_rs::tensor::Tensor; 11 | 12 | /// Create a random batch view from a large batch.
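/// A sketch of typical use, mirroring the `mini_batch` test at the bottom of
/// this file: draw one random 4-row batch from a 10-sample dataset.
///
/// ```ignore
/// let data = Var::ones(&[10, 3]);
/// let label = Var::zeros(&[10]);
/// let mut minibatch = MiniBatch::new(StdRng::seed_from_u64(671), 4);
/// let (mdata, mlabel) = minibatch.next(&data, &label).unwrap();
/// assert_eq!(mdata.size(), [4, 3]);
/// ```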
13 | pub struct MiniBatch { 14 | rng: StdRng, 15 | size: usize, 16 | } 17 | impl MiniBatch { 18 | pub fn new(rng: StdRng, size: usize) -> MiniBatch { 19 | MiniBatch { rng, size } 20 | } 21 | 22 | pub fn next(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> { 23 | let sample_size = data.size()[0]; 24 | let sample_size2 = label.size()[0]; 25 | 26 | if sample_size != sample_size2 { 27 | return Err(AutoDiffError::new(&format!( 28 | "minibatch needs data and label has the same N {}, {}", 29 | sample_size, sample_size2 30 | ))); 31 | } 32 | let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size); 33 | 34 | let mdata = data.index_select(0, index_t.clone())?; 35 | let mlabel = label.index_select(0, index_t)?; 36 | mdata.reset_net(); 37 | mlabel.reset_net(); 38 | Ok((mdata, mlabel)) 39 | } 40 | } 41 | 42 | pub trait Optimizer { 43 | fn step(&mut self, net: Rc<RefCell<Net>>); 44 | } 45 | 46 | // actually it's GD 47 | pub struct SGD { 48 | lr: Tensor, 49 | } 50 | impl SGD { 51 | #[cfg(feature = "use-f64")] 52 | pub fn new(lr: f64) -> SGD { 53 | Self::new_f64(lr) 54 | } 55 | #[cfg(feature = "use-f32")] 56 | pub fn new(lr: f32) -> SGD { 57 | Self::new_f32(lr) 58 | } 59 | 60 | pub fn new_f64(lr: f64) -> SGD { 61 | SGD { 62 | lr: Tensor::from_vec_f64(&[lr], &[1]), 63 | } 64 | } 65 | pub fn new_f32(lr: f32) -> SGD { 66 | SGD { 67 | lr: Tensor::from_vec_f32(&[lr], &[1]), 68 | } 69 | } 70 | } 71 | impl Optimizer for SGD { 72 | fn step(&mut self, net: Rc<RefCell<Net>>) { 73 | net.borrow_mut().visit_op( 74 | |x| { 75 | let weights = x.get_values(); 76 | let grads = x.get_grads(); 77 | //println!("name: {}, {}, {}", x.get_name(), weights.len(), grads.len()); 78 | 79 | let mut new_weight = Vec::new(); 80 | for (i, j) in weights.iter().zip(grads.iter()) { 81 | //println!("{:?}, {:?}, {:?}", i.size(), j.size(), self.lr.size()); 82 | new_weight.push(i.sub(&j.mul(&self.lr))); 83 | } 84 | x.set_values(&new_weight); 85 | }, 86 | None, 87 | None, 88 | ); 89 | } 90 | } 91 | 92 | #[cfg(test)] 93 | mod tests { 94 | use super::*; 95 | use crate::var::Var; 96 | use rand::prelude::*; 97 | 98 | #[test] 99 | fn mini_batch() { 100 | let data = Var::ones(&[10, 3]); 101 | let label = Var::zeros(&[10]); 102 | 103 | let rng = StdRng::seed_from_u64(671); 104 | let mut minibatch = MiniBatch::new(rng, 4); 105 | let (mdata, mlabel) = minibatch.next(&data, &label).unwrap(); 106 | 107 | assert_eq!(mdata.size(), [4, 3]); 108 | assert_eq!(mlabel.size(), [4]); 109 | println!("{:?}, {:?}", mdata, mlabel); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /tensorboard-rs/src/event_file_writer.rs: -------------------------------------------------------------------------------- 1 | use std::path::{PathBuf, Path}; 2 | use std::fs; 3 | use std::time::SystemTime; 4 | use gethostname::gethostname; 5 | use std::process::id; 6 | use std::fs::File; 7 | use protobuf::Message; 8 | use std::thread::{spawn, JoinHandle}; 9 | use std::sync::mpsc::{channel, Sender}; 10 | 11 | use tensorboard_proto::event::Event; 12 | use crate::record_writer::RecordWriter; 13 | 14 | enum EventSignal { 15 | Data(Vec<u8>), 16 | Flush, 17 | Stop, 18 | } 19 | 20 | pub struct EventFileWriter { 21 | logdir: PathBuf, 22 | writer: Sender<EventSignal>, 23 | child: Option<JoinHandle<()>>, 24 | } 25 | impl EventFileWriter { 26 | //pub fn new<P: AsRef<Path>>(logdir: P) -> EventFileWriter { 27 | pub fn new<P: AsRef<Path>>(logdir: P) -> EventFileWriter { 28 | let logdir = logdir.as_ref().to_path_buf(); 29 | 30 | fs::create_dir_all(&logdir).expect(""); 31 | 32 | let mut time = 0;
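// (Whole seconds go into the event file name below; the fractional
// wall time stamps the first "brain.Event:2" record.)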
33 | let mut time_full = 0.0; 34 | if let Ok(n) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { 35 | time = n.as_secs(); 36 | time_full = n.as_secs_f64(); 37 | } 38 | let hostname = gethostname().into_string().expect(""); 39 | let pid = id(); 40 | 41 | let file_name = format!("events.out.tfevents.{:010}.{}.{}.{}", time, hostname, pid, 0); 42 | //let file_writer = File::create(logdir.join(file_name)).expect(""); 43 | //let writer = RecordWriter::new(file_writer); 44 | 45 | let logdir_move = logdir.clone(); 46 | let (tx, rx) = channel(); 47 | let child = spawn(move || { 48 | let file_writer = File::create(logdir_move.join(file_name)).expect(""); 49 | let mut writer = RecordWriter::new(file_writer); 50 | 51 | loop { 52 | let result: EventSignal = rx.recv().unwrap(); 53 | match result { 54 | EventSignal::Data(d) => { 55 | writer.write(&d).expect("write error"); 56 | }, 57 | EventSignal::Flush => {writer.flush().expect("flush error");}, 58 | EventSignal::Stop => {break;}, 59 | } 60 | }; 61 | writer.flush().expect("flush error"); 62 | }); 63 | 64 | let mut ret = EventFileWriter { 65 | logdir, 66 | writer: tx, 67 | child: Some(child), 68 | }; 69 | 70 | let mut evn = Event::new(); 71 | evn.set_wall_time(time_full); 72 | evn.set_file_version("brain.Event:2".to_string()); 73 | ret.add_event(&evn); 74 | ret.flush(); 75 | 76 | ret 77 | } 78 | } 79 | 80 | impl EventFileWriter { 81 | pub fn get_logdir(&self) -> PathBuf { 82 | self.logdir.to_path_buf() 83 | } 84 | 85 | pub fn add_event(&mut self, event: &Event) { 86 | let mut data: Vec<u8> = Vec::new(); 87 | event.write_to_vec(&mut data).expect(""); 88 | self.writer.send(EventSignal::Data(data)).expect(""); 89 | } 90 | 91 | pub fn flush(&mut self) { 92 | self.writer.send(EventSignal::Flush).expect(""); 93 | } 94 | } 95 | 96 | impl Drop for EventFileWriter { 97 | fn drop(&mut self) { 98 | self.writer.send(EventSignal::Stop).expect(""); 99 | self.child.take().unwrap().join().expect(""); 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_trait/index_slicing.rs: -------------------------------------------------------------------------------- 1 | use std::marker::Sized; 2 | 3 | pub trait IndexSlicing where Self: Sized { 4 | 5 | /// Concatenates the given sequence of tensors 6 | /// in the given dimension. 7 | /// The input tensors should all have the same size except 8 | /// on the given dimension. 9 | /// The output tensor will have all the same size as the input 10 | /// except the given dimension, which will be the sum of 11 | /// the inputs on the given dimension. 12 | /// Applying cat on [tensor(5, 3, 2), tensor(5, 7, 2)] 13 | /// along dimension 1 14 | /// gives a tensor(5, 10, 2). 15 | fn cat(&self, tensors: &[Self], dim: usize) -> Self; 16 | 17 | /// Splits a tensor into a specific number of chunks. 18 | fn chunk(&self, chunks: usize, dim: usize) -> Vec<Self>; 19 | 20 | /// Pick elements on the given dimension by the index, 21 | /// and gather them in the output. 22 | /// A restriction is that self.size() and index.size() 23 | /// should be the same on other dimensions. 24 | fn gather(&self, dim: usize, index: &Self) -> Self; 25 | /// The opposite of gather. 26 | /// Self will be replaced with value along dim by index. 27 | fn spread(&self, dim: usize, index: &Self, value: &Self) -> Self; 28 | 29 | /// Select on dim and collect those subtensors by index.
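/// For example (a sketch of the intended semantics): with self of shape
/// [5, 3] and index [0, 0, 2] along dim 0, the result has shape [3, 3]
/// and holds rows 0, 0 and 2 of self.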
30 | fn index_select(&self, dim: usize, index: &Self) -> Self; 31 | 32 | /// Inverse of index_select, remove those subtensors by index along dim. 33 | fn index_exclude(&self, dim: usize, index: &Self) -> Self; 34 | // fn masked_select(); 35 | //pub fn narrow() {} 36 | //pub fn nonzero() {} 37 | 38 | /// Just change the index boundary. 39 | fn reshape(&self, new_shape: &[usize]) -> Self; 40 | 41 | /// Inverse of cat(), split tensor along dim dimension, 42 | /// the length of each section on dim is specified by sections. 43 | fn split(&self, sections: &[usize], dim: usize) -> Vec<Self>; 44 | 45 | /// Remove dimension with length of 1. 46 | /// Only apply on specific dim if dim is supplied. 47 | fn squeeze(&self, dim: Option<usize>) -> Self; 48 | 49 | /// Stack tensors with the same size along a new dimension 50 | /// specified by dim. 51 | /// The difference from cat is that cat doesn't create a new dimension. 52 | fn stack(&self, tensors: &[Self], dim: usize) -> Self; 53 | 54 | /// Transpose 55 | fn t(&self) -> Self; 56 | 57 | /// Returns a new tensor with the elements of input at the given indices. 58 | /// The input tensor is treated as if it were viewed as a 1-D tensor. 59 | /// The result takes the same shape as the indices. 60 | fn take(&self, index: &[usize]) -> Self; 61 | //pub fn transpose() {} 62 | //pub fn unbind() {} 63 | 64 | /// Permute the dimensions of this tensor according to dims. 65 | fn permute(&self, dims: &[usize]) -> Self; 66 | 67 | /// Add size 1 dimension at dim. 68 | fn unsqueeze(&self, dim: usize) -> Self; 69 | //pub fn condition() {} // this is pytorch where 70 | 71 | /// Self is the bool condition; at each position of self, 72 | /// select from x if self at the position is positive or zero; 73 | /// otherwise, use the value from y if self at the position is negative. 74 | /// The restriction is that self, x, and y all have the same size. 75 | fn conditional_select(&self, x: &Self, y: &Self) -> Self; 76 | /// Repeat the tensor along all dimensions; 77 | /// the number of repeats is specified in sizes. 78 | /// Thus the restriction is that self.size().len() is 79 | /// equal to sizes.len(). 80 | fn repeat(&self, sizes: &[usize]) -> Self; 81 | } 82 | -------------------------------------------------------------------------------- /tensor-rs/src/serde/tensor.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "use-serde")] 2 | use serde::{Serialize, Deserialize, Serializer, Deserializer, 3 | ser::SerializeStruct, 4 | de, de::Visitor, de::SeqAccess, de::MapAccess}; 5 | use crate::tensor::Tensor; 6 | use std::fmt; 7 | 8 | impl Serialize for Tensor { 9 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 10 | where S: Serializer, { 11 | // 1 is the number of fields in the struct.
12 | let mut state = serializer.serialize_struct("Tensor", 1)?;
13 | state.serialize_field("v", &self.inner().borrow().clone())?;
14 | state.end()
15 | }
16 | }
17 | 
18 | impl<'de> Deserialize<'de> for Tensor {
19 | fn deserialize<D>(deserializer: D) -> Result<Tensor, D::Error>
20 | where D: Deserializer<'de>, {
21 | 
22 | enum Field { V }
23 | 
24 | impl<'de> Deserialize<'de> for Field {
25 | fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
26 | where D: Deserializer<'de>, {
27 | struct FieldVisitor;
28 | 
29 | impl<'de> Visitor<'de> for FieldVisitor {
30 | type Value = Field;
31 | 
32 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
33 | formatter.write_str("v")
34 | }
35 | 
36 | fn visit_str<E>(self, value: &str) -> Result<Field, E>
37 | where E: de::Error, {
38 | match value {
39 | "v" => Ok(Field::V),
40 | _ => Err(de::Error::unknown_field(value, &FIELDS)),
41 | }
42 | }
43 | }
44 | 
45 | deserializer.deserialize_identifier(FieldVisitor)
46 | }
47 | }
48 | 
49 | struct TensorVisitor;
50 | 
51 | impl<'de> Visitor<'de> for TensorVisitor {
52 | type Value = Tensor;
53 | 
54 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
55 | formatter.write_str("struct Tensor")
56 | }
57 | 
58 | fn visit_map<V>(self, mut map: V) -> Result<Tensor, V::Error>
59 | where V: MapAccess<'de>, {
60 | let mut v = None;
61 | while let Some(key) = map.next_key()? {
62 | match key {
63 | Field::V => {
64 | if v.is_some() {
65 | return Err(de::Error::duplicate_field("v"));
66 | }
67 | v = Some(map.next_value()?);
68 | }
69 | }
70 | }
71 | let v = v.ok_or_else(|| de::Error::missing_field("v"))?;
72 | Ok(Tensor::set_inner(v))
73 | }
74 | 
75 | fn visit_seq<V>(self, mut seq: V) -> Result<Tensor, V::Error>
76 | where V: SeqAccess<'de>, {
77 | let tt = seq.next_element()?
78 | .ok_or_else(|| de::Error::invalid_length(0, &self))?;
79 | Ok(Tensor::set_inner(tt))
80 | }
81 | }
82 | 
83 | const FIELDS: [&str; 1] = ["v"];
84 | deserializer.deserialize_struct("Tensor", &FIELDS, TensorVisitor)
85 | }
86 | }
87 | 
88 | 
89 | #[cfg(all(test, feature = "use-serde"))]
90 | mod tests {
91 | use crate::tensor::Tensor;
92 | 
93 | #[test]
94 | fn test_serde() {
95 | let m1 = Tensor::eye(3,3);
96 | 
97 | let serialized = serde_pickle::to_vec(&m1, true).unwrap();
98 | let deserialized: Tensor = serde_pickle::from_slice(&serialized).unwrap();
99 | //println!("{:?}", deserialized);
100 | assert_eq!(m1, deserialized);
101 | }
102 | }
103 | 
--------------------------------------------------------------------------------
/auto-diff/examples/mnist.rs:
--------------------------------------------------------------------------------
1 | use std::fs::File;
2 | use std::path::Path;
3 | use std::io;
4 | use std::io::Read;
5 | 
6 | use auto_diff::Var;
7 | 
8 | //use tensorboard_rs::summary_writer::SummaryWriter;
9 | 
10 | pub fn load_images<P: AsRef<Path>>(path: P) -> Var {
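// The IDX image file begins with four big-endian u32 values:
// [magic = 2051][image count][rows][cols], followed by one u8 per pixel.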
11 | let reader = &mut io::BufReader::new(File::open(path).expect("cannot open the image file"));
12 | let magic = read_as_u32(reader);
13 | if magic != 2051 {
14 | panic!("Invalid magic number. Expected 2051, got {}", magic)
15 | }
16 | let num_image = read_as_u32(reader) as usize;
17 | let rows = read_as_u32(reader) as usize;
18 | let cols = read_as_u32(reader) as usize;
19 | assert!(rows == 28 && cols == 28);
20 | 
21 | // read images
22 | let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
23 | let _ = reader.read_exact(buf.as_mut());
24 | let ret: Vec<f64> = buf.into_iter().map(|x| (x as f64) / 255.).collect();
25 | let ret = Var::new(&ret[..], &vec![num_image, rows, cols]);
26 | ret
27 | }
28 | 
29 | pub fn load_labels<P: AsRef<Path>>(path: P) -> Var {
30 | let reader = &mut io::BufReader::new(File::open(path).expect("cannot open the label file"));
31 | let magic = read_as_u32(reader);
32 | if magic != 2049 {
33 | panic!("Invalid magic number. Expected 2049, got {}", magic);
34 | }
35 | let num_label = read_as_u32(reader) as usize;
36 | // read labels
37 | let mut buf: Vec<u8> = vec![0u8; num_label];
38 | let _ = reader.read_exact(buf.as_mut());
39 | let ret: Vec<f64> = buf.into_iter().map(|x| x as f64).collect();
40 | let ret = Var::new(&ret[..], &vec![num_label]);
41 | ret
42 | }
43 | 
44 | fn read_as_u32<T: Read>(reader: &mut T) -> u32 {
45 | let mut buf: [u8; 4] = [0, 0, 0, 0];
46 | let _ = reader.read_exact(&mut buf);
47 | u32::from_be_bytes(buf)
48 | }
49 | 
50 | #[allow(dead_code)]
51 | pub fn main() {
52 | let t = load_images("examples/data/mnist/train-images-idx3-ubyte");
53 | 
54 | //let mut writer = SummaryWriter::new(&("./logdir".to_string()));
55 | 
56 | for i in 0..10 {
57 | let first_image = t.get_patch(&vec![(i,i+1),(0,28),(0,28)], None).unwrap();
58 | //println!("{:?}, {}, {}", first_image.size(), first_image.max(None, None, None), first_image.min(None, None, None));
59 | let rgb_img = first_image.cat(&vec![first_image.clone(), first_image.clone()], 0).unwrap();
60 | let rgb_img = rgb_img.permute(&vec![1, 2, 0]).unwrap();
61 | let _rgb_img = rgb_img * Var::fill(&vec![1], &Var::new(&[255.], &[1]));
62 | // writer.add_image(&"test_image".to_string(), &rgb_img.get_u8().expect("u8")[..], &vec![3, 28, 28][..], i+32);
63 | }
64 | 
65 | let first_image = t.get_patch(&vec![(0,1),(0,28),(0,28)], None).unwrap();
66 | //println!("{:?}, {}, {}", first_image.size(), first_image.max(None, None, None), first_image.min(None, None, None));
67 | let rgb_img = first_image.cat(&vec![first_image.clone(), first_image.clone()], 0).unwrap();
68 | let rgb_img = rgb_img.permute(&vec![1, 2, 0]).unwrap();
69 | let _rgb_img = rgb_img * Var::fill(&vec![1], &Var::new(&[255.], &[1]));
70 | //writer.add_image(&"test_image".to_string(), &rgb_img.get_u8().expect("u8")[..], &vec![3, 28, 28][..], 12);
71 | //writer.flush();
72 | 
73 | 
74 | let first_image = t.get_patch(&vec![(10,11),(0,28),(0,28)], None).unwrap();
75 | //println!("{:?}, {}, {}", first_image.size(), first_image.max(None, None, None), first_image.min(None, None, None));
76 | let rgb_img = first_image.cat(&vec![first_image.clone(), first_image.clone()], 0).unwrap();
77 | let rgb_img = rgb_img.permute(&vec![1, 2, 0]).unwrap();
78 | let _rgb_img = rgb_img * Var::fill(&vec![1], &Var::new(&[255.], &[1]));
79 | //writer.add_image(&"test_image".to_string(), &rgb_img.get_u8().expect("u8")[..], &vec![3, 28, 28][..], 13);
80 | //writer.flush();
81 | 
82 | let l = load_labels("examples/data/mnist/train-labels-idx1-ubyte");
83 | println!("{}, {}", l.get_f32(&vec![0]).unwrap(), l.get_f32(&vec![10]).unwrap());
84 | }
85 | 
--------------------------------------------------------------------------------
/tensorboard-rs/src/masked_crc32c.rs:
-------------------------------------------------------------------------------- 1 | pub fn masked_crc32c(data: &[u8]) -> u32 { 2 | let x = crc32c(data); 3 | ((x >> 15) | (x << 17)).overflowing_add(0xa282ead8).0 4 | } 5 | 6 | //pub fn u32(data: &[u8]) -> u32{ 7 | // 8 | //} 9 | 10 | const CRC_TABLE: [u32; 256] = [ 11 | 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 12 | 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, 13 | 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 14 | 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 15 | 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 16 | 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, 17 | 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 18 | 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, 19 | 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 20 | 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 21 | 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 22 | 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, 23 | 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 24 | 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, 25 | 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 26 | 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 27 | 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 28 | 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, 29 | 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 30 | 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, 31 | 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 32 | 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 33 | 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 34 | 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, 35 | 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 36 | 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, 37 | 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 38 | 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 39 | 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 40 | 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, 41 | 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 42 | 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, 43 | 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 44 | 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 45 | 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 46 | 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, 47 | 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 48 | 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, 49 | 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 50 | 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 51 | 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 52 | 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, 53 | 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 54 | 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, 55 | 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 56 | 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 57 | 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 58 | 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, 59 | 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 60 | 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, 61 | 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 62 | 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 63 | 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 64 | 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, 65 | 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 66 | 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, 67 | 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 68 | 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 69 | 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 70 | 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, 71 | 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 72 | 
0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,
73 | 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
74 | 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
75 | ];
76 | 
77 | const CRC_INIT: u32 = 0;
78 | 
79 | const _MASK: u32 = 0xFFFFFFFF;
80 | 
81 | pub fn crc_update(crc: u32, data: &[u8]) -> u32 {
82 | let mut crc = crc ^ _MASK;
83 | for b in data {
84 | let table_index = ((crc & 0xff) as u8) ^ b;
85 | crc = (CRC_TABLE[table_index as usize] ^ (crc >> 8)) & _MASK;
86 | }
87 | crc ^ _MASK
88 | }
89 | 
90 | pub fn crc_finalize(crc: u32) -> u32 {
91 | crc & _MASK
92 | }
93 | 
94 | pub fn crc32c(data: &[u8]) -> u32 {
95 | crc_finalize(crc_update(CRC_INIT, data))
96 | }
97 | 
--------------------------------------------------------------------------------
/auto-diff/src/serde/var.rs:
--------------------------------------------------------------------------------
1 | #[cfg(feature = "use-serde")]
2 | use serde::{
3 | de, de::MapAccess, de::SeqAccess, de::Visitor, ser::SerializeStruct, Deserialize, Deserializer,
4 | Serialize, Serializer,
5 | };
6 | use std::fmt;
7 | 
8 | use crate::var::Var;
9 | 
10 | impl Serialize for Var {
11 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
12 | where
13 | S: Serializer,
14 | {
15 | // 1 is the number of fields in the struct.
16 | let mut state = serializer.serialize_struct("Var", 1)?;
17 | state.serialize_field("var", &*self.inner().borrow())?;
18 | state.end()
19 | }
20 | }
21 | 
22 | impl<'de> Deserialize<'de> for Var {
23 | fn deserialize<D>(deserializer: D) -> Result<Var, D::Error>
24 | where
25 | D: Deserializer<'de>,
26 | {
27 | enum Field {
28 | Var,
29 | }
30 | 
31 | impl<'de> Deserialize<'de> for Field {
32 | fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
33 | where
34 | D: Deserializer<'de>,
35 | {
36 | struct FieldVisitor;
37 | 
38 | impl<'de> Visitor<'de> for FieldVisitor {
39 | type Value = Field;
40 | 
41 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42 | formatter.write_str("var")
43 | }
44 | 
45 | fn visit_str<E>(self, value: &str) -> Result<Field, E>
46 | where
47 | E: de::Error,
48 | {
49 | match value {
50 | "var" => Ok(Field::Var),
51 | _ => Err(de::Error::unknown_field(value, &FIELDS)),
52 | }
53 | }
54 | }
55 | 
56 | deserializer.deserialize_identifier(FieldVisitor)
57 | }
58 | }
59 | 
60 | struct VarVisitor;
61 | 
62 | impl<'de> Visitor<'de> for VarVisitor {
63 | type Value = Var;
64 | 
65 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
66 | formatter.write_str("struct Var")
67 | }
68 | 
69 | fn visit_map<V>(self, mut map: V) -> Result<Var, V::Error>
70 | where
71 | V: MapAccess<'de>,
72 | {
73 | let mut var = None;
74 | while let Some(key) = map.next_key()? {
75 | match key {
76 | Field::Var => {
77 | if var.is_some() {
78 | return Err(de::Error::duplicate_field("var"));
79 | }
80 | var = Some(map.next_value()?);
81 | }
82 | }
83 | }
84 | let var = var.ok_or_else(|| de::Error::missing_field("var"))?;
85 | Ok(Var::set_inner(var))
86 | }
87 | 
88 | fn visit_seq<V>(self, mut seq: V) -> Result<Var, V::Error>
89 | where
90 | V: SeqAccess<'de>,
91 | {
92 | let var = seq
93 | .next_element()?
94 | .ok_or_else(|| de::Error::invalid_length(0, &self))?;
95 | Ok(Var::set_inner(var))
96 | }
97 | }
98 | 
99 | const FIELDS: [&str; 1] = ["var"];
100 | deserializer.deserialize_struct("Var", &FIELDS, VarVisitor)
101 | }
102 | }
103 | 
104 | #[cfg(all(test, feature = "use-serde"))]
105 | mod tests {
106 | use crate::var::Var;
107 | use rand::prelude::*;
108 | 
109 | #[test]
110 | fn test_serde_var_inner() {
111 | let mut rng = StdRng::seed_from_u64(671);
112 | let n = 10;
113 | let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.);
114 | let result = data.matmul(&Var::new(&vec![2., 3.], &vec![2, 1])).unwrap()
115 | + Var::new(&vec![1.], &vec![1]);
116 | 
117 | let serialized = serde_pickle::to_vec(&result, true).unwrap();
118 | let deserialized: Var = serde_pickle::from_slice(&serialized).unwrap();
119 | println!("{:?}", deserialized.dump_net());
120 | assert_eq!(result, deserialized);
121 | }
122 | }
123 | 
--------------------------------------------------------------------------------
/macros/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Procedural macros for auto-diff.
2 | 
3 | use proc_macro::TokenStream;
4 | use syn::{parse_macro_input, ItemStruct, parse, parse::Parser};
5 | use syn::punctuated::Punctuated;
6 | use syn::{Expr, Token};
7 | use quote::quote;
8 | 
9 | 
10 | #[proc_macro_attribute]
11 | pub fn add_op_handle(args: TokenStream, input: TokenStream) -> TokenStream {
12 | let mut item_struct = parse_macro_input!(input as ItemStruct);
13 | let _ = parse_macro_input!(args as parse::Nothing);
14 | 
15 | if let syn::Fields::Named(ref mut fields) = item_struct.fields {
16 | 
17 | fields.named.push(
18 | syn::Field::parse_named
19 | .parse2(quote! {
20 | #[cfg_attr(feature = "use-serde", serde(skip))]
21 | handle: OpHandle
22 | })
23 | .unwrap(),
24 | );
25 | }
26 | 
27 | return quote! {
28 | #item_struct
29 | }
30 | .into();
31 | }
32 | 
33 | #[proc_macro_attribute]
34 | pub fn extend_op_impl(args: TokenStream, input: TokenStream) -> TokenStream {
35 | let mut item_struct = parse_macro_input!(input as ItemStruct);
36 | let _ = parse_macro_input!(args as parse::Nothing);
37 | 
38 | if let syn::Fields::Named(ref mut fields) = item_struct.fields {
39 | 
40 | fields.named.push(
41 | syn::Field::parse_named
42 | .parse2(quote! {
43 | #[cfg_attr(feature = "use-serde", serde(skip))]
44 | handle: OpHandle
45 | })
46 | .unwrap(),
47 | );
48 | }
49 | 
50 | return quote! {
51 | #item_struct
52 | }
53 | .into();
54 | }
55 | 
56 | 
57 | #[proc_macro]
58 | pub fn gen_serde_funcs(input: TokenStream) -> TokenStream {
59 | 
60 | let input_tokens = input.clone();
61 | let parser = Punctuated::<Expr, Token![,]>::parse_separated_nonempty;
62 | let input_result = parser.parse(input_tokens).expect("need a list of idents");
63 | let mut strs = vec![]; // The op struct names, as strings.
64 | for item in input_result {
65 | match item {
66 | Expr::Path(expr) => {
67 | strs.push(expr.path.get_ident().expect("need an ident").to_string());
68 | },
69 | _ => {panic!("need an ident (Expr::Path).");}
70 | }
71 | }
72 | 
73 | // The same names, as identifiers.
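// As an illustration, a hypothetical invocation `gen_serde_funcs!(Linear, Sigmoid)`
// would emit serialize_box/deserialize_map/deserialize_seq functions with one
// match arm per listed op, dispatching on the op's get_name() string and
// downcasting to the concrete struct.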
74 | let names: Vec<_> = strs.iter().map(|x| quote::format_ident!("{}", x)).collect(); 75 | 76 | let serialize_box = quote!{ 77 | pub fn serialize_box(op: &Box, serializer: S) -> Result 78 | where S: Serializer { 79 | match op.get_name() { 80 | #( #strs => { 81 | let op = op.as_any().downcast_ref::<#names>().unwrap(); 82 | op.serialize(serializer) 83 | }, )* 84 | other => { 85 | return Err(ser::Error::custom(format!("unknown op {:?}", other))); 86 | } 87 | } 88 | } 89 | }; 90 | 91 | let deserialize_map = quote!{ 92 | pub fn deserialize_map<'de, V>(op_name: String, mut map: V) -> Result 93 | where V: MapAccess<'de>, { 94 | match op_name.as_str() { 95 | #( #strs => { 96 | let op_obj: #names = Some(map.next_value::<#names>()?).ok_or_else(|| de::Error::missing_field("op_obj"))?; 97 | return Ok(Op::new(Rc::new(RefCell::new(Box::new(op_obj))))); 98 | }, )* 99 | _ => { 100 | return Err(de::Error::missing_field("op_obj")); 101 | } 102 | } 103 | } 104 | }; 105 | 106 | let deserialize_seq = quote!{ 107 | pub fn deserialize_seq<'de, V>(op_name: String, mut seq: V) -> Result 108 | where V: SeqAccess<'de>, { 109 | match op_name.as_str() { 110 | #( #strs => { 111 | let op_obj: #names = seq.next_element()?.ok_or_else(|| de::Error::missing_field("op_obj"))?; 112 | return Ok(Op::new(Rc::new(RefCell::new(Box::new(op_obj))))); 113 | }, )* 114 | _ => { 115 | return Err(de::Error::missing_field("op_obj")); 116 | } 117 | } 118 | } 119 | }; 120 | 121 | let tokens = quote! { 122 | #serialize_box 123 | #deserialize_map 124 | #deserialize_seq 125 | }; 126 | 127 | tokens.into() 128 | } 129 | 130 | 131 | #[cfg(test)] 132 | mod tests { 133 | 134 | #[test] 135 | fn test() { 136 | 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/lapack_tensor/lapack_api.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::too_many_arguments)] 2 | #[cfg(feature = "use-blas-lapack")] 3 | use lapack::*; 4 | use std::marker::PhantomData; 5 | use std::cmp; 6 | 7 | 8 | pub struct LapackAPI { 9 | d: PhantomData, 10 | } 11 | 12 | #[cfg(feature = "use-blas-lapack")] 13 | impl LapackAPI { 14 | /// = 'A': all M columns of U and all N rows of V**T are 15 | /// returned in the arrays U and VT; 16 | /// = 'S': the first min(M,N) columns of U and the first 17 | /// min(M,N) rows of V**T are returned in the arrays U 18 | /// and VT; 19 | /// = 'O': If M >= N, the first N columns of U are overwritten 20 | /// on the array A and all rows of V**T are returned in 21 | /// the array VT; 22 | /// otherwise, all columns of U are returned in the 23 | /// array U and the first M rows of V**T are overwritten 24 | /// in the array A; 25 | /// = 'N': no columns of U or rows of V**T are computed. 
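/// gesdd computes the singular value decomposition A = U * SIGMA * V**T
/// of an m-by-n matrix via LAPACK's divide-and-conquer routine (sgesdd
/// for f32, dgesdd for f64); the minimum workspace sizes below follow
/// the LAPACK documentation for each jobz mode.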
26 | pub fn gesdd(jobz: &char, m: usize, n: usize, 27 | a: &mut [f32], lda: usize, 28 | s: &mut [f32], 29 | u: &mut [f32], ldu: usize, 30 | vt: &mut [f32], ldvt: usize, 31 | info: &mut i32) { 32 | let (mx, mn) = if m > n {(m, n)} else {(n, m)}; 33 | let (jobz, mini_work): (u8, usize) = match jobz { 34 | 'A' => { 35 | (b'A', 4*mn*mn + 6*mn + mx) 36 | }, 37 | 'S' => { 38 | (b'S', 4*mn*mn + 7*mn) 39 | }, 40 | 'O' => { 41 | (b'O', 3*mn + cmp::max( mx, 5*mn*mn + 4*mn )) 42 | }, 43 | 'N' => { 44 | (b'N', 3*mn + cmp::max( mx, 7*mn )) 45 | }, 46 | _ => panic!("unknown jobz: {}", jobz), 47 | }; 48 | let mut work: Vec = vec![0.; mini_work]; 49 | let lwork = mini_work as i32; 50 | let mut iwork: Vec = vec![0; 8*mn]; 51 | unsafe { 52 | sgesdd(jobz, m as i32, n as i32, 53 | a, lda as i32, 54 | s, 55 | u, ldu as i32, 56 | vt, ldvt as i32, 57 | &mut work, lwork, &mut iwork, 58 | info); 59 | } 60 | } 61 | } 62 | 63 | #[cfg(feature = "use-blas-lapack")] 64 | impl LapackAPI { 65 | pub fn gesdd(jobz: &char, m: usize, n: usize, 66 | a: &mut [f64], lda: usize, 67 | s: &mut [f64], 68 | u: &mut [f64], ldu: usize, 69 | vt: &mut [f64], ldvt: usize, 70 | info: &mut i32) { 71 | let (mx, mn) = if m > n {(m, n)} else {(n, m)}; 72 | let (jobz, mini_work): (u8, usize) = match jobz { 73 | 'A' => { 74 | (b'A', 4*mn*mn + 6*mn + mx) 75 | }, 76 | 'S' => { 77 | (b'S', 4*mn*mn + 7*mn) 78 | }, 79 | 'O' => { 80 | (b'O', 3*mn + cmp::max( mx, 5*mn*mn + 4*mn )) 81 | }, 82 | 'N' => { 83 | (b'N', 3*mn + cmp::max( mx, 7*mn )) 84 | }, 85 | _ => panic!("unknown jobz: {}", jobz), 86 | }; 87 | let mut work: Vec = vec![0.; mini_work]; 88 | let lwork = mini_work as i32; 89 | let mut iwork: Vec = vec![0; 8*mn]; 90 | unsafe { 91 | dgesdd(jobz, m as i32, n as i32, 92 | a, lda as i32, 93 | s, 94 | u, ldu as i32, 95 | vt, ldvt as i32, 96 | &mut work, lwork, &mut iwork, 97 | info); 98 | } 99 | } 100 | } 101 | 102 | #[cfg(all(test, feature = "use-blas-lapack"))] 103 | mod tests { 104 | use super::*; 105 | 106 | #[test] 107 | fn test_svd() { 108 | let mut m: Vec = vec![4., 12., -16., 12., 37., -43., -16., -43., 98.]; 109 | let mut s = vec![0. ; 3]; 110 | let mut u = vec![0. ; 9]; 111 | let mut vt = vec![0. ; 9]; 112 | let mut info: i32 = 0; 113 | LapackAPI::::gesdd(&'S', 3, 3, 114 | &mut m, 3, 115 | &mut s, 116 | &mut u, 3, 117 | &mut vt, 3, 118 | &mut info); 119 | println!("{:?}, {:?}, {:?}", u, s, vt); 120 | let es: Vec = vec![123.47723179013161, 15.503963229407585, 0.018804980460810704]; 121 | assert!((s[0] - es[0]).abs() < 1e-6); 122 | assert!((s[1] - es[1]).abs() < 1e-6); 123 | assert!((s[2] - es[2]).abs() < 1e-1); 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /tensor-rs/src/tensor_impl/cuda_tensor/cuda_helper.rs: -------------------------------------------------------------------------------- 1 | // wrap cuda stream and other session status. 
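// CudaStream owns the raw cudaStream_t and destroys it on Drop (RAII);
// StreamCell defers stream creation until the first get_stream() call and
// hands the stream out through an RAII guard that puts it back on drop.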
2 | //
3 | //
4 | use std::rc::Rc;
5 | use std::cell::RefCell;
6 | use cuda11_cudart_sys::{self, cudaStreamCreate, cudaStreamSynchronize, cudaStreamDestroy, check_cuda_status, cudaStream_t};
7 | use cuda11_cutensor_sys::{self, cutensorHandle_t, check_cutensor_status, cutensorInit};
8 | 
9 | ///
10 | /// Raw Cuda stream.
11 | ///
12 | pub struct CudaStream {
13 | stream: cudaStream_t,
14 | }
15 | 
16 | impl CudaStream {
17 | pub fn new() -> CudaStream {
18 | let mut stream = std::ptr::null_mut();
19 | unsafe {
20 | check_cuda_status(cudaStreamCreate(&mut stream as *mut _ as _));
21 | }
22 | CudaStream {
23 | stream,
24 | }
25 | }
26 | pub fn empty() -> CudaStream {
27 | CudaStream {
28 | stream: std::ptr::null_mut(),
29 | }
30 | }
31 | pub fn raw_stream(&self) -> cudaStream_t {
32 | self.stream
33 | }
34 | }
35 | 
36 | impl Drop for CudaStream {
37 | fn drop(&mut self) {
38 | if !self.stream.is_null() {
39 | unsafe {
40 | check_cuda_status(cudaStreamDestroy(self.stream as _));
41 | //println!("cudaFree");
42 | }
43 | }
44 | }
45 | }
46 | 
47 | /// Lazily initialized cuda stream.
48 | /// The stream is only created when the get_stream method is called.
49 | ///
50 | pub struct StreamCell {
51 | stream: RefCell<Option<CudaStream>>,
52 | }
53 | 
54 | impl StreamCell {
55 | pub fn new() -> StreamCell {
56 | StreamCell {
57 | stream: RefCell::new(None),
58 | }
59 | }
60 | pub fn get_stream(&self) -> StreamCellGuard {
61 | let stream = match self.stream.borrow_mut().take() {
62 | None => {CudaStream::new()},
63 | Some(strm) => {strm}
64 | };
65 | StreamCellGuard {
66 | stream_cell: self,
67 | stream,
68 | }
69 | }
70 | }
71 | 
72 | /// Moves the stream out of the cell and puts it back in on drop.
73 | pub struct StreamCellGuard<'a> {
74 | stream_cell: &'a StreamCell,
75 | stream: CudaStream,
76 | }
77 | impl<'a> Drop for StreamCellGuard<'a> {
78 | fn drop(&mut self) {
79 | let stream = std::mem::replace(&mut self.stream, CudaStream::empty());
80 | *self.stream_cell.stream.borrow_mut() = Some(stream);
81 | }
82 | }
83 | impl<'a> std::ops::Deref for StreamCellGuard<'a> {
84 | type Target = CudaStream;
85 | 
86 | fn deref(&self) -> &CudaStream {
87 | // This increases the ergonomics of a `StreamCellGuard`. Because
88 | // of this impl, most uses of `StreamCellGuard` can be as if
89 | // it were just a `&CudaStream`.
90 | &self.stream 91 | } 92 | } 93 | 94 | #[cfg(all(test, feature = "use-cuda"))] 95 | mod tests { 96 | use super::*; 97 | 98 | #[test] 99 | fn cuda_stream() { 100 | let mut stream = CudaStream::new(); 101 | let raw_stream = stream.raw_stream(); 102 | assert!((raw_stream as *const _) != std::ptr::null()); 103 | } 104 | 105 | #[test] 106 | fn cuda_stream_cell() { 107 | { 108 | let stream = StreamCell::new(); 109 | 110 | let mut str1: cudaStream_t = std::ptr::null_mut(); 111 | let mut str2: cudaStream_t = std::ptr::null_mut(); 112 | str1 = stream.get_stream().raw_stream(); 113 | str2 = stream.get_stream().raw_stream(); 114 | 115 | //println!("{:?}, {:?}", str1, str2); 116 | assert_eq!(str1, str2); 117 | } 118 | 119 | { 120 | let s1 = Rc::new(StreamCell::new()); 121 | let s2 = s1.clone(); 122 | let str1 = s1.get_stream().raw_stream(); 123 | let str2 = s2.get_stream().raw_stream(); 124 | //println!("{:?}, {:?}", str1, str2); 125 | } 126 | } 127 | 128 | 129 | } 130 | 131 | // 132 | // Cuda cutensor 133 | // 134 | pub struct CudaCutensor { 135 | handle: cutensorHandle_t, 136 | } 137 | 138 | impl CudaCutensor { 139 | pub fn new() -> CudaCutensor { 140 | unsafe { 141 | let mut handle:cutensorHandle_t = std::mem::uninitialized(); 142 | check_cutensor_status(cutensorInit(&mut handle as *mut _)); 143 | 144 | CudaCutensor { 145 | handle: handle, 146 | } 147 | } 148 | } 149 | } 150 | 151 | impl Drop for CudaCutensor { 152 | fn drop(&mut self) { 153 | unsafe { 154 | 155 | } 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /auto-diff/examples/logistic_regression.rs: -------------------------------------------------------------------------------- 1 | //! Logistic regression example on Breast Cancer Wisconsin (Diagnostic) Data Set 2 | //! 3 | //! 
The dataset is from http://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+%28diagnostic%29 4 | 5 | 6 | use auto_diff::var::Var; 7 | use auto_diff::op::Linear; 8 | use auto_diff::op::OpCall; 9 | use auto_diff::optim::{SGD}; 10 | use csv; 11 | use std::collections::{BTreeSet}; 12 | use rand::prelude::*; 13 | extern crate openblas_src; 14 | 15 | fn main() { 16 | let mut reader = csv::ReaderBuilder::new() 17 | .has_headers(false) 18 | .from_path("examples/data/wdbc.data") 19 | .expect("Cannot read wdbc.data"); 20 | 21 | let mut id; 22 | let mut ill; 23 | let mut ids = BTreeSet::::new(); 24 | let head = reader.position().clone(); 25 | 26 | for record in reader.records() { 27 | let line = record.expect(""); 28 | id = line[0].trim().parse::().expect(""); 29 | //ill = line[1].trim().parse::().expect(""); 30 | //println!("{}, {}", id, ill); 31 | 32 | if !ids.contains(&id) { 33 | ids.insert(id); 34 | } else { 35 | println!("duplicate {}", id); 36 | } 37 | } 38 | let size = ids.len(); 39 | println!("total size: {}", size); 40 | 41 | let data = Var::empty(&vec![size, 31]); 42 | //println!("{:?} \n {}", data.size(), data); 43 | reader.seek(head).expect(""); 44 | for (record, index) in reader.records().zip(0..size) { 45 | let line = record.expect(""); 46 | let mut tmp = Vec::::with_capacity(31); 47 | 48 | ill = line[1].trim().parse::().expect(""); 49 | if ill == "M" { 50 | tmp.push(1.); 51 | } else { 52 | tmp.push(0.); 53 | } 54 | 55 | for i in 2..32 { 56 | let value = line[i].trim().parse::().expect(""); 57 | //println!("{}", value); 58 | tmp.push(value); 59 | } 60 | //println!("{:?}", tmp); 61 | data.from_record_f64(index, &tmp); 62 | } 63 | 64 | 65 | //println!("{:?} \n {}", data.size(), data); 66 | let train_size = ((size as f32)*0.7) as usize; 67 | let test_size = size - train_size; 68 | //let splited_data = data.split(&vec![train_size, test_size], 0); 69 | let data_label_split = data.split(&vec![1, 30], 1).unwrap(); 70 | let label = &data_label_split[0]; 71 | let data = &data_label_split[1]; 72 | let data = data.normalize_unit().unwrap(); 73 | let label_split = label.split(&vec![train_size, test_size], 0).unwrap(); 74 | let data_split = data.split(&vec![train_size, test_size], 0).unwrap(); 75 | let train_data = &data_split[0]; 76 | let train_label = &label_split[0]; 77 | let test_data = &data_split[1]; 78 | let test_label = &label_split[1]; 79 | 80 | 81 | train_data.reset_net(); 82 | train_label.reset_net(); 83 | test_data.reset_net(); 84 | test_label.reset_net(); 85 | 86 | println!("{:?}", train_data.size()); 87 | println!("{:?}", train_label.size()); 88 | println!("{:?}", test_data.size()); 89 | println!("{:?}", test_label.size()); 90 | 91 | 92 | // build the model 93 | let mut rng = StdRng::seed_from_u64(671); 94 | 95 | let mut op1 = Linear::new(Some(30), Some(1), true); 96 | op1.set_weight(Var::normal(&mut rng, &[30, 1], 0., 2.)); 97 | op1.set_bias(Var::normal(&mut rng, &[1, ], 0., 2.)); 98 | // let weights = op1.get_values().unwrap(); 99 | // rng.normal_(&weights[0], 0., 1.); 100 | // rng.normal_(&weights[1], 0., 1.); 101 | // op1.set_values(&weights); 102 | 103 | let input = train_data.clone(); 104 | let label = train_label.clone(); 105 | 106 | let output = op1.call(&[&input]).unwrap().pop().unwrap(); 107 | 108 | //let loss = m.bce_with_logits_loss(); 109 | println!("o: {:?}", output.size()); 110 | println!("l: {:?}", train_label.size()); 111 | let loss = output.bce_with_logits_loss(&label).unwrap(); 112 | 113 | 114 | let mut opt = SGD::new(1.); 115 | 116 | for i in 0..100 { 117 | 
118 | println!("{:?}", i); 119 | input.set(train_data); 120 | label.set(train_label); 121 | loss.rerun().unwrap(); 122 | loss.bp().unwrap(); 123 | loss.step(&mut opt).unwrap(); 124 | 125 | input.set(test_data); 126 | label.set(test_label); 127 | loss.rerun().unwrap(); 128 | println!("{:?}", loss); 129 | 130 | } 131 | let weight = op1.weight(); 132 | let bias = op1.bias(); 133 | println!("{:?}, {:?}", weight, bias); 134 | } 135 | -------------------------------------------------------------------------------- /auto-diff/src/op/vision.rs: -------------------------------------------------------------------------------- 1 | use super::{Op, OpCall, OpHandle, OpTrait}; 2 | use tensor_rs::tensor::Tensor; 3 | 4 | use std::cell::RefCell; 5 | use std::rc::Rc; 6 | 7 | use crate::err::AutoDiffError; 8 | use crate::var::Var; 9 | 10 | #[cfg(feature = "use-serde")] 11 | use serde::{Deserialize, Serialize}; 12 | #[cfg(feature = "use-serde")] 13 | use std::any::Any; 14 | 15 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 16 | pub struct GetPatch { 17 | #[cfg_attr(feature = "use-serde", serde(skip))] 18 | handle: OpHandle, 19 | range: Vec<(usize, usize)>, 20 | step: Option>, 21 | } 22 | impl GetPatch { 23 | pub fn new(range: &[(usize, usize)], step: Option<&[usize]>) -> GetPatch { 24 | let new_range = range.to_vec(); 25 | let new_step = step.map(|v| v.to_vec()); 26 | GetPatch { 27 | handle: OpHandle::new(), 28 | range: new_range, 29 | step: new_step, 30 | } 31 | } 32 | fn get_handle(&self) -> &OpHandle { 33 | &self.handle 34 | } 35 | fn get_handle_mut(&mut self) -> &mut OpHandle { 36 | &mut self.handle 37 | } 38 | } 39 | impl OpCall for GetPatch { 40 | fn call(&mut self, inputs: &[&Var]) -> Result, AutoDiffError> { 41 | let new_one = GetPatch { 42 | handle: OpHandle::new(), 43 | range: self.range.clone(), 44 | step: self.step.clone(), 45 | }; 46 | 47 | let op = Op::new(Rc::new(RefCell::new(Box::new(new_one)))); 48 | 49 | inputs[0].called_with(op, &inputs[1..inputs.len()]) 50 | } 51 | } 52 | impl OpTrait for GetPatch { 53 | fn get_name(&self) -> &'static str { 54 | "GetPatch" 55 | } 56 | fn get_input_size(&self) -> usize { 57 | 1 58 | } 59 | fn get_output_size(&self) -> usize { 60 | 1 61 | } 62 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 63 | let step = self.step.as_ref().map(|v| &v[..]); 64 | output[0].swap(&input[0].get_patch(&self.range, step)); 65 | } 66 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 67 | unimplemented!(); 68 | } 69 | fn get_values(&self) -> Vec { 70 | Vec::new() 71 | } 72 | fn get_grads(&self) -> Vec { 73 | Vec::new() 74 | } 75 | fn set_values(&self, _v: &[Tensor]) {} 76 | #[cfg(feature = "use-serde")] 77 | fn as_any(&self) -> &dyn Any { 78 | self 79 | } 80 | } 81 | 82 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 83 | pub struct SetPatch { 84 | #[cfg_attr(feature = "use-serde", serde(skip))] 85 | handle: OpHandle, 86 | range: Vec<(usize, usize)>, 87 | step: Option>, 88 | } 89 | impl SetPatch { 90 | pub fn new(range: &[(usize, usize)], step: Option<&[usize]>) -> SetPatch { 91 | let new_range = range.to_vec(); 92 | let new_step = step.map(|v| v.to_vec()); 93 | SetPatch { 94 | handle: OpHandle::new(), 95 | range: new_range, 96 | step: new_step, 97 | } 98 | } 99 | fn get_handle(&self) -> &OpHandle { 100 | &self.handle 101 | } 102 | fn get_handle_mut(&mut self) -> &mut OpHandle { 103 | &mut self.handle 104 | } 105 | } 106 | impl OpCall for SetPatch { 107 | fn call(&mut self, inputs: &[&Var]) -> 
Result, AutoDiffError> { 108 | let new_one = SetPatch { 109 | handle: OpHandle::new(), 110 | range: self.range.clone(), 111 | step: self.step.clone(), 112 | }; 113 | 114 | let op = Op::new(Rc::new(RefCell::new(Box::new(new_one)))); 115 | 116 | inputs[0].called_with(op, &inputs[1..inputs.len()]) 117 | } 118 | } 119 | impl OpTrait for SetPatch { 120 | fn get_name(&self) -> &'static str { 121 | "SetPatch" 122 | } 123 | fn get_input_size(&self) -> usize { 124 | 2 125 | } 126 | fn get_output_size(&self) -> usize { 127 | 1 128 | } 129 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 130 | let step = self.step.as_ref().map(|v| &v[..]); 131 | output[0].swap(&input[0].set_patch(&input[1], &self.range, step)); 132 | } 133 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 134 | unimplemented!(); 135 | } 136 | fn get_values(&self) -> Vec { 137 | Vec::new() 138 | } 139 | fn get_grads(&self) -> Vec { 140 | Vec::new() 141 | } 142 | fn set_values(&self, _v: &[Tensor]) {} 143 | #[cfg(feature = "use-serde")] 144 | fn as_any(&self) -> &dyn Any { 145 | self 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /auto-diff/src/op/comparison.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::redundant_closure_call)] 2 | use super::macros::new_binary_op; 3 | use super::{Op, OpCall, OpHandle, OpTrait}; 4 | use tensor_rs::tensor::Tensor; 5 | 6 | use std::cell::RefCell; 7 | use std::rc::Rc; 8 | 9 | use crate::err::AutoDiffError; 10 | use crate::var::Var; 11 | 12 | #[cfg(feature = "use-serde")] 13 | use serde::{Deserialize, Serialize}; 14 | #[cfg(feature = "use-serde")] 15 | use std::any::Any; 16 | 17 | // max_pair 18 | new_binary_op!( 19 | MaxPair, 20 | "Max_pair", 21 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].max_pair(&a[1]))), 22 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 23 | unimplemented!(); 24 | }) 25 | ); 26 | // max, in reduction 27 | // min_pair 28 | new_binary_op!( 29 | MinPair, 30 | "Min_pair", 31 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].min_pair(&a[1]))), 32 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 33 | unimplemented!(); 34 | }) 35 | ); 36 | // min, in reduction 37 | // arg_sort 38 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 39 | pub struct ArgSort { 40 | #[cfg_attr(feature = "use-serde", serde(skip))] 41 | handle: OpHandle, 42 | dim: usize, 43 | descending: bool, 44 | } 45 | impl ArgSort { 46 | pub fn new(dim: usize, descending: bool) -> ArgSort { 47 | ArgSort { 48 | handle: OpHandle::new(), 49 | dim, 50 | descending, 51 | } 52 | } 53 | fn get_handle(&self) -> &OpHandle { 54 | &self.handle 55 | } 56 | fn get_handle_mut(&mut self) -> &mut OpHandle { 57 | &mut self.handle 58 | } 59 | } 60 | impl OpCall for ArgSort { 61 | fn call(&mut self, inputs: &[&Var]) -> Result, AutoDiffError> { 62 | let new_one = ArgSort { 63 | handle: OpHandle::new(), 64 | dim: self.dim, 65 | descending: self.descending, 66 | }; 67 | 68 | let op = Op::new(Rc::new(RefCell::new(Box::new(new_one)))); 69 | 70 | inputs[0].called_with(op, &inputs[1..inputs.len()]) 71 | } 72 | } 73 | impl OpTrait for ArgSort { 74 | fn get_name(&self) -> &'static str { 75 | "Arg_sort" 76 | } 77 | fn get_input_size(&self) -> usize { 78 | 1 79 | } 80 | fn get_output_size(&self) -> usize { 81 | 1 82 | } 83 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 84 | output[0].swap(&input[0].arg_sort(self.dim, 
self.descending)) 85 | } 86 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 87 | unimplemented!(); 88 | } 89 | fn get_values(&self) -> Vec { 90 | Vec::new() 91 | } 92 | fn get_grads(&self) -> Vec { 93 | Vec::new() 94 | } 95 | fn set_values(&self, _v: &[Tensor]) {} 96 | #[cfg(feature = "use-serde")] 97 | fn as_any(&self) -> &dyn Any { 98 | self 99 | } 100 | } 101 | // eq_t (use eq_elem) 102 | new_binary_op!( 103 | EqElem, 104 | "Eq_t", 105 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].eq_t(&a[1]))), 106 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 107 | unimplemented!(); 108 | }) 109 | ); 110 | // equal, 0 is == 1 is != 111 | new_binary_op!( 112 | Equal, 113 | "Equal", 114 | (|a: &[Tensor], b: &[Tensor]| if a[0].equal(&a[1]) { 115 | b[0].swap(&Tensor::zeros(&[1])) 116 | } else { 117 | b[0].swap(&Tensor::ones(&[1])) 118 | }), 119 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 120 | unimplemented!(); 121 | }) 122 | ); 123 | // ge 124 | new_binary_op!( 125 | Ge, 126 | "Ge", 127 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].ge(&a[1]))), 128 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 129 | unimplemented!(); 130 | }) 131 | ); 132 | // gt 133 | new_binary_op!( 134 | Gt, 135 | "Gt", 136 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].gt(&a[1]))), 137 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 138 | unimplemented!(); 139 | }) 140 | ); 141 | // le 142 | new_binary_op!( 143 | Le, 144 | "Le", 145 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].le(&a[1]))), 146 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 147 | unimplemented!(); 148 | }) 149 | ); 150 | // lt 151 | new_binary_op!( 152 | Lt, 153 | "Lt", 154 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].lt(&a[1]))), 155 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 156 | unimplemented!(); 157 | }) 158 | ); 159 | // ne 160 | new_binary_op!( 161 | Ne, 162 | "Ne", 163 | (|a: &[Tensor], b: &[Tensor]| b[0].swap(&a[0].ne(&a[1]))), 164 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 165 | unimplemented!(); 166 | }) 167 | ); 168 | -------------------------------------------------------------------------------- /auto-diff/examples/mlp.rs: -------------------------------------------------------------------------------- 1 | //! 1 hidden layer MLP example on Breast Cancer Wisconsin (Diagnostic) Data Set 2 | //! 3 | //! 
The dataset is from http://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+%28diagnostic%29 4 | 5 | 6 | use auto_diff::var::Var; 7 | use auto_diff::op::{Linear, OpCall}; 8 | use auto_diff::optim::{SGD}; 9 | use csv; 10 | use std::collections::{BTreeSet}; 11 | use rand::prelude::*; 12 | extern crate openblas_src; 13 | 14 | //use tensorboard_rs::summary_writer::SummaryWriter; 15 | 16 | fn main() { 17 | let mut reader = csv::ReaderBuilder::new() 18 | .has_headers(false) 19 | .from_path("examples/data/wdbc.data") 20 | .expect("Cannot read wdbc.data"); 21 | 22 | let mut id; 23 | let mut ill; 24 | let mut ids = BTreeSet::::new(); 25 | let head = reader.position().clone(); 26 | 27 | for record in reader.records() { 28 | let line = record.expect(""); 29 | id = line[0].trim().parse::().expect(""); 30 | //ill = line[1].trim().parse::().expect(""); 31 | //println!("{}, {}", id, ill); 32 | 33 | if !ids.contains(&id) { 34 | ids.insert(id); 35 | } else { 36 | println!("duplicate {}", id); 37 | } 38 | } 39 | let size = ids.len(); 40 | println!("total size: {}", size); 41 | 42 | let data = Var::empty(&vec![size, 31]); 43 | //println!("{:?} \n {}", data.size(), data); 44 | reader.seek(head).expect(""); 45 | for (record, index) in reader.records().zip(0..size) { 46 | let line = record.expect(""); 47 | let mut tmp = Vec::::with_capacity(31); 48 | 49 | ill = line[1].trim().parse::().expect(""); 50 | if ill == "M" { 51 | tmp.push(1.); 52 | } else { 53 | tmp.push(0.); 54 | } 55 | 56 | for i in 2..32 { 57 | let value = line[i].trim().parse::().expect(""); 58 | //println!("{}", value); 59 | tmp.push(value); 60 | } 61 | //println!("{:?}", tmp); 62 | data.from_record_f64(index, &tmp); 63 | } 64 | 65 | 66 | //println!("{:?} \n {}", data.size(), data); 67 | let train_size = ((size as f32)*0.7) as usize; 68 | let test_size = size - train_size; 69 | //let splited_data = data.split(&vec![train_size, test_size], 0); 70 | let data_label_split = data.split(&vec![1, 30], 1).unwrap(); 71 | let label = &data_label_split[0]; 72 | let data = &data_label_split[1]; 73 | let data = data.normalize_unit().unwrap(); 74 | let label_split = label.split(&vec![train_size, test_size], 0).unwrap(); 75 | let data_split = data.split(&vec![train_size, test_size], 0).unwrap(); 76 | let train_data = &data_split[0]; 77 | let train_label = &label_split[0]; 78 | let test_data = &data_split[1]; 79 | let test_label = &label_split[1]; 80 | 81 | 82 | train_data.reset_net(); 83 | train_label.reset_net(); 84 | test_data.reset_net(); 85 | test_label.reset_net(); 86 | 87 | println!("{:?}", train_data.size()); 88 | println!("{:?}", train_label.size()); 89 | println!("{:?}", test_data.size()); 90 | println!("{:?}", test_label.size()); 91 | 92 | 93 | // build the model 94 | let mut rng = StdRng::seed_from_u64(671); 95 | 96 | let mut op1 = Linear::new(Some(30), Some(10), true); 97 | op1.set_weight(Var::normal(&mut rng, &[30, 10], 0., 1.)); 98 | op1.set_bias(Var::normal(&mut rng, &[10, ], 0., 1.)); 99 | 100 | let mut op2 = Linear::new(Some(10), Some(1), true); 101 | op2.set_weight(Var::normal(&mut rng, &[10, 1], 0., 1.)); 102 | op2.set_bias(Var::normal(&mut rng, &[1, ], 0., 1.)); 103 | 104 | // let mut writer = SummaryWriter::new(&("./logdir".to_string())); 105 | let input = train_data.clone(); 106 | let label = train_label.clone(); 107 | 108 | let output1 = op1.call(&[&input]).unwrap().pop().unwrap(); 109 | let output2 = output1.sigmoid().unwrap(); 110 | let output = op2.call(&[&output2]).unwrap().pop().unwrap(); 111 | 112 | let loss = 
output.bce_with_logits_loss(&label).unwrap(); 113 | 114 | let mut opt = SGD::new(1.); 115 | 116 | for i in 0..500 { 117 | 118 | println!("i: {:?}", i); 119 | input.set(train_data); 120 | label.set(train_label); 121 | loss.rerun().unwrap(); 122 | loss.bp().unwrap(); 123 | loss.step(&mut opt).unwrap(); 124 | 125 | input.set(test_data); 126 | label.set(test_label); 127 | loss.rerun().unwrap(); 128 | println!("loss: {:?}", loss); 129 | 130 | //writer.add_scalar("run1/loss", loss.get().get_scale_f32(), i); 131 | //writer.add_scalar("run1/accuracy", 1.-tsum.get_scale_f32()/(test_size as f32), i); 132 | //writer.flush(); 133 | 134 | let output1 = output.clone(); 135 | let err = (output1.sigmoid().unwrap() - test_label.clone()).abs().unwrap().sum(None, false).unwrap(); 136 | println!("err: {:?}", err); 137 | } 138 | 139 | //println!("{:?}, {:?}", test_label, output.sigmoid()); 140 | } 141 | -------------------------------------------------------------------------------- /auto-diff/src/op/reduction.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::redundant_closure_call)] 2 | use std::cell::RefCell; 3 | use std::rc::Rc; 4 | 5 | use super::{Op, OpCall, OpHandle, OpTrait}; 6 | use crate::err::AutoDiffError; 7 | use tensor_rs::tensor::Tensor; 8 | 9 | #[cfg(feature = "use-serde")] 10 | use serde::{Deserialize, Serialize}; 11 | #[cfg(feature = "use-serde")] 12 | use std::any::Any; 13 | 14 | macro_rules! reduce_macro { 15 | ($a:ident, $b:expr, $c:ident, $d: tt) => { 16 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 17 | pub struct $a { 18 | #[cfg_attr(feature = "use-serde", serde(skip))] 19 | handle: OpHandle, 20 | dim: Option>, 21 | keepdim: bool, 22 | } 23 | impl $a { 24 | pub fn new(dim: Option<&[usize]>, keepdim: bool) -> $a { 25 | $a { 26 | handle: OpHandle::new(), 27 | dim: dim.map(|v| v.to_vec()), 28 | keepdim, 29 | } 30 | } 31 | fn get_handle(&self) -> &OpHandle { 32 | &self.handle 33 | } 34 | fn get_handle_mut(&mut self) -> &mut OpHandle { 35 | &mut self.handle 36 | } 37 | } 38 | impl OpCall for $a { 39 | fn call( 40 | &mut self, 41 | inputs: &[&crate::var::Var], 42 | ) -> Result, AutoDiffError> { 43 | let new_one = $a { 44 | handle: OpHandle::new(), 45 | dim: self.dim.as_ref().map(|v| v.to_vec()), 46 | keepdim: self.keepdim, 47 | }; 48 | 49 | let op = Op::new(Rc::new(RefCell::new(Box::new(new_one)))); 50 | 51 | inputs[0].called_with(op, &inputs[1..inputs.len()]) 52 | } 53 | } 54 | impl OpTrait for $a { 55 | fn get_name(&self) -> &'static str { 56 | ($b) 57 | } 58 | fn get_input_size(&self) -> usize { 59 | 1 60 | } 61 | fn get_output_size(&self) -> usize { 62 | 1 63 | } 64 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 65 | match &self.dim { 66 | Some(v) => { 67 | let v1 = v.clone(); 68 | output[0].swap(&input[0].$c(Some(&v1), self.keepdim)); 69 | } 70 | None => { 71 | output[0].swap(&input[0].$c(None, self.keepdim)); 72 | } 73 | } 74 | } 75 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 76 | $d(input, output_grad, input_grad) 77 | } 78 | fn get_values(&self) -> Vec { 79 | Vec::new() 80 | } 81 | fn get_grads(&self) -> Vec { 82 | Vec::new() 83 | } 84 | fn set_values(&self, _v: &[Tensor]) {} 85 | #[cfg(feature = "use-serde")] 86 | fn as_any(&self) -> &dyn Any { 87 | self 88 | } 89 | } 90 | }; 91 | } 92 | 93 | reduce_macro!( 94 | Argmax, 95 | "Argmax", 96 | argmax, 97 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 98 | unimplemented!(); 99 | }) 100 | 
); 101 | 102 | reduce_macro!( 103 | Argmin, 104 | "Argmin", 105 | argmin, 106 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 107 | unimplemented!(); 108 | }) 109 | ); 110 | 111 | reduce_macro!( 112 | Logsumexp, 113 | "Logsumexp", 114 | logsumexp, 115 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 116 | unimplemented!(); 117 | }) 118 | ); 119 | 120 | reduce_macro!( 121 | Mean, 122 | "Mean", 123 | mean, 124 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 125 | unimplemented!(); 126 | }) 127 | ); 128 | 129 | reduce_macro!( 130 | Prod, 131 | "Prod", 132 | prod, 133 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 134 | unimplemented!(); 135 | }) 136 | ); 137 | 138 | reduce_macro!( 139 | Std, 140 | "Std", 141 | std, 142 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 143 | unimplemented!(); 144 | }) 145 | ); 146 | 147 | reduce_macro!( 148 | Sum, 149 | "Sum", 150 | sum, 151 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 152 | unimplemented!(); 153 | }) 154 | ); 155 | 156 | reduce_macro!( 157 | Variance, 158 | "Var", 159 | var, 160 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 161 | unimplemented!(); 162 | }) 163 | ); 164 | 165 | reduce_macro!( 166 | Max, 167 | "Max", 168 | max, 169 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 170 | unimplemented!(); 171 | }) 172 | ); 173 | 174 | reduce_macro!( 175 | Min, 176 | "Min", 177 | min, 178 | (|input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]| { 179 | unimplemented!(); 180 | }) 181 | ); 182 | -------------------------------------------------------------------------------- /auto-diff/examples/mlp_mnist.rs: -------------------------------------------------------------------------------- 1 | use auto_diff::op::{Linear, OpCall}; 2 | use auto_diff::optim::{SGD, MiniBatch}; 3 | use auto_diff::Var; 4 | use rand::prelude::*; 5 | use ::rand::prelude::StdRng; 6 | extern crate openblas_src; 7 | 8 | 9 | //use tensorboard_rs::summary_writer::SummaryWriter; 10 | 11 | mod mnist; 12 | use mnist::{load_images, load_labels}; 13 | 14 | fn main() { 15 | 16 | let train_img = load_images("examples/data/mnist/train-images-idx3-ubyte"); 17 | let test_img = load_images("examples/data/mnist/t10k-images-idx3-ubyte"); 18 | let train_label = load_labels("examples/data/mnist/train-labels-idx1-ubyte"); 19 | let test_label = load_labels("examples/data/mnist/t10k-labels-idx1-ubyte"); 20 | 21 | let train_size = train_img.size(); 22 | let n = train_size[0]; 23 | let h = train_size[1]; 24 | let w = train_size[2]; 25 | let train_data = train_img.reshape(&vec![n, h*w]).unwrap(); 26 | 27 | let test_size = test_img.size(); 28 | let n = test_size[0]; 29 | let h = test_size[1]; 30 | let w = test_size[2]; 31 | let test_data = test_img.reshape(&vec![n, h*w]).unwrap(); 32 | 33 | train_data.reset_net(); 34 | train_label.reset_net(); 35 | test_data.reset_net(); 36 | test_label.reset_net(); 37 | 38 | assert_eq!(train_data.size(), [60000, 784]); 39 | assert_eq!(train_label.size(), [60000]); 40 | assert_eq!(test_data.size(), [10000, 784]); 41 | assert_eq!(test_label.size(), [10000]); 42 | 43 | 44 | // build the model 45 | // let mut m = Module::new(); 46 | // let mut rng = RNG::new(); 47 | // rng.set_seed(123); 48 | // 49 | // let op1 = Linear::new(Some(h*w), Some(h*w*2), true); 50 | // rng.normal_(op1.weight(), 0., 1.); 51 | // rng.normal_(op1.bias(), 0., 1.); 52 | // 53 | // let linear1 = 
Op::new(Box::new(op1)); 54 | // 55 | // let op2 = Linear::new(Some(h*w*2), Some(10), true); 56 | // rng.normal_(op2.weight(), 0., 1.); 57 | // rng.normal_(op2.bias(), 0., 1.); 58 | // 59 | // let linear2 = Op::new(Box::new(op2)); 60 | // 61 | // let activator = Op::new(Box::new(Sigmoid::new())); 62 | // 63 | // let input = m.var(); 64 | // let output = input 65 | // .to(&linear1) 66 | // .to(&activator) 67 | // .to(&linear2); 68 | // let label = m.var(); 69 | // 70 | // let loss = crossentropyloss(&output, &label); 71 | 72 | let mut rng = StdRng::seed_from_u64(671); 73 | 74 | let mut op1 = Linear::new(Some(h*w), Some(h*w*2), true); 75 | op1.set_weight(Var::normal(&mut rng, &[h*w, h*w*2], 0., 1.)); 76 | op1.set_bias(Var::normal(&mut rng, &[h*w*2, ], 0., 1.)); 77 | 78 | let mut op2 = Linear::new(Some(h*w*2), Some(10), true); 79 | op2.set_weight(Var::normal(&mut rng, &[h*w*2, 10], 0., 1.)); 80 | op2.set_bias(Var::normal(&mut rng, &[10, ], 0., 1.)); 81 | 82 | // //println!("{}, {}", &train_data, &train_label); 83 | let rng = StdRng::seed_from_u64(671); 84 | let mut minibatch = MiniBatch::new(rng, 16); 85 | 86 | // let mut writer = SummaryWriter::new(&("./logdir".to_string())); 87 | let (input, label) = minibatch.next(&train_data, &train_label).unwrap(); 88 | 89 | let output1 = op1.call(&[&input]).unwrap().pop().unwrap(); 90 | let output2 = output1.sigmoid().unwrap(); 91 | let output = op2.call(&[&output2]).unwrap().pop().unwrap(); 92 | 93 | let loss = output.cross_entropy_loss(&label).unwrap(); 94 | 95 | let lr = 0.1; 96 | let mut opt = SGD::new(lr); 97 | 98 | 99 | for i in 0..900 { 100 | println!("index: {}", i); 101 | let (input_next, label_next) = minibatch.next(&train_data, &train_label).unwrap(); 102 | input.set(&input_next); 103 | label.set(&label_next); 104 | println!("load data done"); 105 | 106 | //println!("dump net: {:?}", loss.dump_net().borrow()); 107 | loss.rerun().unwrap(); 108 | loss.bp().unwrap(); 109 | loss.step(&mut opt).unwrap(); 110 | 111 | println!("loss: {:?}", loss); 112 | // writer.add_scalar(&"mlp_mnist/train_loss".to_string(), f64::try_from(loss.clone()).unwrap() as f32, i); 113 | 114 | 115 | if i % 10 == 0 { 116 | let (input_next, label_next) = minibatch.next(&test_data, &test_label).unwrap(); 117 | input.set(&input_next); 118 | label.set(&label_next); 119 | loss.rerun().unwrap(); 120 | 121 | println!("test loss: {:?}", loss); 122 | 123 | //let loss_value = loss.get().get_scale_f32(); 124 | 125 | let tsum = output.clone().argmax(Some(&[1]), false).unwrap().eq_elem(&test_label).unwrap().mean(None, false); 126 | //let accuracy = tsum.get_scale_f32(); 127 | //println!("{}, loss: {}, accuracy: {}", i, loss_value, accuracy); 128 | println!("test accuracy: {:?}", tsum); 129 | 130 | //writer.add_scalar(&"run3/accuracy".to_string(), accuracy, i); 131 | //writer.flush(); 132 | } 133 | // 134 | // //println!("{}, loss: {}", i, loss.get().get_scale_f32()); 135 | // writer.add_scalar(&"run3/test_loss".to_string(), loss.get().get_scale_f32(), i); 136 | // writer.flush(); 137 | // 138 | // if i != 0 && i % 300 == 0 { 139 | // lr = lr / 3.; 140 | // opt = SGD::new(lr); 141 | // } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /auto-diff/src/serde/op.rs: -------------------------------------------------------------------------------- 1 | use crate::op::{Op, OpTrait}; 2 | #[cfg(feature = "use-serde")] 3 | use serde::{ 4 | de, de::MapAccess, de::SeqAccess, de::Visitor, ser::SerializeStruct, Deserialize, Deserializer, 5 | 
Serialize, Serializer,
6 | };
7 | use std::fmt;
8 | use std::ops::Deref;
9 | 
10 | impl Serialize for Box<dyn OpTrait> {
11 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
12 | where
13 | S: Serializer,
14 | {
15 | // Dispatch on the concrete op type behind the trait object.
16 | //let mut state = serializer.serialize_struct("OpTrait", 1)?;
17 | //state.serialize_field("op_name", &self.get_name())?;
18 | //state.end()
19 | crate::op::serialize_box::<S>(self, serializer)
20 | }
21 | }
22 | 
23 | impl Serialize for Op {
24 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25 | where
26 | S: Serializer,
27 | {
28 | // 2 is the number of fields in the struct.
29 | let mut state = serializer.serialize_struct("Op", 2)?;
30 | state.serialize_field("op_name", &self.get_name())?;
31 | state.serialize_field("op_obj", &self.inner().borrow().deref())?;
32 | state.end()
33 | }
34 | }
35 | 
36 | impl<'de> Deserialize<'de> for Op {
37 | fn deserialize<D>(deserializer: D) -> Result<Op, D::Error>
38 | where
39 | D: Deserializer<'de>,
40 | {
41 | enum Field {
42 | OpName,
43 | OpObj,
44 | }
45 | 
46 | impl<'de> Deserialize<'de> for Field {
47 | fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
48 | where
49 | D: Deserializer<'de>,
50 | {
51 | struct FieldVisitor;
52 | 
53 | impl<'de> Visitor<'de> for FieldVisitor {
54 | type Value = Field;
55 | 
56 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
57 | formatter.write_str("op_name or op_obj")
58 | }
59 | 
60 | fn visit_str<E>(self, value: &str) -> Result<Field, E>
61 | where
62 | E: de::Error,
63 | {
64 | match value {
65 | "op_name" => Ok(Field::OpName),
66 | "op_obj" => Ok(Field::OpObj),
67 | _ => Err(de::Error::unknown_field(value, &FIELDS)),
68 | }
69 | }
70 | }
71 | 
72 | deserializer.deserialize_identifier(FieldVisitor)
73 | }
74 | }
75 | 
76 | struct OpVisitor;
77 | 
78 | impl<'de> Visitor<'de> for OpVisitor {
79 | type Value = Op;
80 | 
81 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
82 | formatter.write_str("struct Op")
83 | }
84 | 
85 | fn visit_map<V>(self, mut map: V) -> Result<Op, V::Error>
86 | where
87 | V: MapAccess<'de>,
88 | {
89 | let mut op_name = None;
90 | while let Some(key) = map.next_key()? {
91 | match key {
92 | Field::OpName => {
93 | if op_name.is_some() {
94 | return Err(de::Error::duplicate_field("op_name"));
95 | }
96 | op_name = Some(map.next_value()?);
97 | }
98 | Field::OpObj => {
99 | //if op_obj.is_some() {
100 | // return Err(de::Error::duplicate_field("op_obj"));
101 | //}
102 | //op_obj = Some(map.next_value()?);
103 | let op_name: String =
104 | op_name.ok_or_else(|| de::Error::missing_field("op_name"))?;
105 | 
106 | return crate::op::deserialize_map(op_name, map);
107 | }
108 | }
109 | }
110 | Err(de::Error::missing_field("op_obj"))
111 | }
112 | 
113 | fn visit_seq<V>(self, mut seq: V) -> Result<Op, V::Error>
114 | where
115 | V: SeqAccess<'de>,
116 | {
117 | let op_name: String = seq
118 | .next_element()?
119 | .ok_or_else(|| de::Error::invalid_length(0, &self))?; 120 | return crate::op::deserialize_seq(op_name, seq); 121 | } 122 | } 123 | 124 | const FIELDS: [&str; 2] = ["op_name", "op_obj"]; 125 | deserializer.deserialize_struct("Op", &FIELDS, OpVisitor) 126 | } 127 | } 128 | 129 | #[cfg(all(test, feature = "use-serde"))] 130 | mod tests { 131 | use super::*; 132 | use crate::op::linear::Linear; 133 | use std::cell::RefCell; 134 | use std::rc::Rc; 135 | 136 | #[test] 137 | fn test_serde_op() { 138 | let m1 = Linear::new(None, None, true); 139 | let m1 = Op::new(Rc::new(RefCell::new(Box::new(m1)))); 140 | 141 | let serialized = serde_pickle::to_vec(&m1, true).unwrap(); 142 | let deserialized: Op = serde_pickle::from_slice(&serialized).unwrap(); 143 | //println!("{:?}", deserialized); 144 | //assert_eq!(m1, deserialized); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /auto-diff/src/op/linalg.rs: -------------------------------------------------------------------------------- 1 | use super::{OpHandle, OpTrait}; 2 | use tensor_rs::tensor::Tensor; 3 | 4 | #[cfg(feature = "use-serde")] 5 | use serde::{Deserialize, Serialize}; 6 | #[cfg(feature = "use-serde")] 7 | use std::any::Any; 8 | 9 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 10 | pub struct NormalizeUnit { 11 | #[cfg_attr(feature = "use-serde", serde(skip))] 12 | handle: OpHandle, 13 | } 14 | impl NormalizeUnit { 15 | pub fn new() -> NormalizeUnit { 16 | NormalizeUnit { 17 | handle: OpHandle::new(), 18 | } 19 | } 20 | fn get_handle(&self) -> &OpHandle { 21 | &self.handle 22 | } 23 | fn get_handle_mut(&mut self) -> &mut OpHandle { 24 | &mut self.handle 25 | } 26 | } 27 | impl OpTrait for NormalizeUnit { 28 | fn get_name(&self) -> &'static str { 29 | "NormalizeUnit" 30 | } 31 | fn get_input_size(&self) -> usize { 32 | 1 33 | } 34 | fn get_output_size(&self) -> usize { 35 | 1 36 | } 37 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 38 | output[0].swap(&input[0].normalize_unit()); 39 | } 40 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 41 | unimplemented!(); 42 | } 43 | fn get_values(&self) -> Vec { 44 | Vec::new() 45 | } 46 | fn get_grads(&self) -> Vec { 47 | Vec::new() 48 | } 49 | fn set_values(&self, _v: &[Tensor]) {} 50 | #[cfg(feature = "use-serde")] 51 | fn as_any(&self) -> &dyn Any { 52 | self 53 | } 54 | } 55 | impl Default for NormalizeUnit { 56 | fn default() -> Self { 57 | Self::new() 58 | } 59 | } 60 | 61 | #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))] 62 | pub struct Det { 63 | #[cfg_attr(feature = "use-serde", serde(skip))] 64 | handle: OpHandle, 65 | } 66 | impl Det { 67 | pub fn new() -> Det { 68 | Det { 69 | handle: OpHandle::new(), 70 | } 71 | } 72 | fn get_handle(&self) -> &OpHandle { 73 | &self.handle 74 | } 75 | fn get_handle_mut(&mut self) -> &mut OpHandle { 76 | &mut self.handle 77 | } 78 | } 79 | impl OpTrait for Det { 80 | fn get_name(&self) -> &'static str { 81 | "Det" 82 | } 83 | fn get_input_size(&self) -> usize { 84 | 1 85 | } 86 | fn get_output_size(&self) -> usize { 87 | 1 88 | } 89 | fn apply(&self, input: &[Tensor], output: &[Tensor]) { 90 | output[0].swap(&input[0].det().expect("det() does not get a result.")); 91 | } 92 | fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) { 93 | unimplemented!(); 94 | } 95 | fn get_values(&self) -> Vec { 96 | Vec::new() 97 | } 98 | fn get_grads(&self) -> Vec { 99 | Vec::new() 100 | } 101 | fn 
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct NormalizeUnit {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl NormalizeUnit {
    pub fn new() -> NormalizeUnit {
        NormalizeUnit {
            handle: OpHandle::new(),
        }
    }
    fn get_handle(&self) -> &OpHandle {
        &self.handle
    }
    fn get_handle_mut(&mut self) -> &mut OpHandle {
        &mut self.handle
    }
}
impl OpTrait for NormalizeUnit {
    fn get_name(&self) -> &'static str {
        "NormalizeUnit"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        output[0].swap(&input[0].normalize_unit());
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        unimplemented!();
    }
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {}
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl Default for NormalizeUnit {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Det {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl Det {
    pub fn new() -> Det {
        Det {
            handle: OpHandle::new(),
        }
    }
    fn get_handle(&self) -> &OpHandle {
        &self.handle
    }
    fn get_handle_mut(&mut self) -> &mut OpHandle {
        &mut self.handle
    }
}
impl OpTrait for Det {
    fn get_name(&self) -> &'static str {
        "Det"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        output[0].swap(&input[0].det().expect("det() failed"));
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        unimplemented!();
    }
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {}
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl Default for Det {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Inv {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl Inv {
    pub fn new() -> Inv {
        Inv {
            handle: OpHandle::new(),
        }
    }
    fn get_handle(&self) -> &OpHandle {
        &self.handle
    }
    fn get_handle_mut(&mut self) -> &mut OpHandle {
        &mut self.handle
    }
}
impl OpTrait for Inv {
    fn get_name(&self) -> &'static str {
        "Inv"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        output[0].swap(&input[0].inv().expect("inv() failed"));
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        unimplemented!();
    }
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {}
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl Default for Inv {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Tr {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl Tr {
    pub fn new() -> Tr {
        Tr {
            handle: OpHandle::new(),
        }
    }
    fn get_handle(&self) -> &OpHandle {
        &self.handle
    }
    fn get_handle_mut(&mut self) -> &mut OpHandle {
        &mut self.handle
    }
}
impl OpTrait for Tr {
    fn get_name(&self) -> &'static str {
        "Tr"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        output[0].swap(&input[0].tr());
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        unimplemented!();
    }
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {}
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl Default for Tr {
    fn default() -> Self {
        Self::new()
    }
}
--------------------------------------------------------------------------------
/auto-diff/src/serde/var_inner.rs:
--------------------------------------------------------------------------------
#[cfg(feature = "use-serde")]
use serde::{
    de, de::MapAccess, de::SeqAccess, de::Visitor, ser::SerializeStruct, Deserialize, Deserializer,
    Serialize, Serializer,
};
use std::fmt;

use crate::var_inner::VarInner;
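// VarInner is serialized as a three-field struct: its id, the need_grad flag,
// and the compute network it belongs to.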
impl Serialize for VarInner {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // 3 is the number of fields in the struct.
        let mut state = serializer.serialize_struct("VarInner", 3)?;
        state.serialize_field("id", &self.get_id())?;
        state.serialize_field("need_grad", &self.get_need_grad())?;
        state.serialize_field("net", &*self.get_net().borrow())?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for VarInner {
    fn deserialize<D>(deserializer: D) -> Result<VarInner, D::Error>
    where
        D: Deserializer<'de>,
    {
        enum Field {
            Id,
            NeedGrad,
            Net,
        }

        impl<'de> Deserialize<'de> for Field {
            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
            where
                D: Deserializer<'de>,
            {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                        formatter.write_str("id, need_grad, or net")
                    }

                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
                    where
                        E: de::Error,
                    {
                        match value {
                            "id" => Ok(Field::Id),
                            "need_grad" => Ok(Field::NeedGrad),
                            "net" => Ok(Field::Net),
                            _ => Err(de::Error::unknown_field(value, &FIELDS)),
                        }
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct VarInnerVisitor;

        impl<'de> Visitor<'de> for VarInnerVisitor {
            type Value = VarInner;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct VarInner")
            }

            fn visit_map<V>(self, mut map: V) -> Result<VarInner, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut id = None;
                let mut need_grad = None;
                let mut net = None;
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Id => {
                            if id.is_some() {
                                return Err(de::Error::duplicate_field("id"));
                            }
                            id = Some(map.next_value()?);
                        }
                        Field::NeedGrad => {
                            if need_grad.is_some() {
                                return Err(de::Error::duplicate_field("need_grad"));
                            }
                            need_grad = Some(map.next_value()?);
                        }
                        Field::Net => {
                            if net.is_some() {
                                return Err(de::Error::duplicate_field("net"));
                            }
                            net = Some(map.next_value()?);
                        }
                    }
                }
                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
                let need_grad = need_grad.ok_or_else(|| de::Error::missing_field("need_grad"))?;
                let net = net.ok_or_else(|| de::Error::missing_field("net"))?;
                Ok(VarInner::set_inner(id, need_grad, net))
            }
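            // Sequence form mirrors the field order: (id, need_grad, net).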
            fn visit_seq<V>(self, mut seq: V) -> Result<VarInner, V::Error>
            where
                V: SeqAccess<'de>,
            {
                let id = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let need_grad = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
                let net = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
                Ok(VarInner::set_inner(id, need_grad, net))
            }
        }

        const FIELDS: [&str; 3] = ["id", "need_grad", "net"];
        deserializer.deserialize_struct("VarInner", &FIELDS, VarInnerVisitor)
    }
}

#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::var::Var;
    use crate::var_inner::VarInner;
    use rand::prelude::*;

    #[test]
    fn test_serde_var_inner() {
        let mut rng = StdRng::seed_from_u64(671);
        let n = 10;
        let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.);
        let result = data.matmul(&Var::new(&vec![2., 3.], &vec![2, 1])).unwrap()
            + Var::new(&vec![1.], &vec![1]);

        let serialized = serde_pickle::to_vec(&*result.inner().borrow(), true).unwrap();
        let deserialized: VarInner = serde_pickle::from_slice(&serialized).unwrap();
        println!("{:?}", deserialized.dump_net());
        //assert_eq!(*result.dump_net().borrow(), deserialized);
    }
}
--------------------------------------------------------------------------------
/tensorboard-rs/src/summary_writer.rs:
--------------------------------------------------------------------------------
#![allow(clippy::too_many_arguments)]
use std::path::{PathBuf, Path};
use std::time::SystemTime;
use std::collections::HashMap;
use protobuf::Message;
use protobuf::RepeatedField;
use crate::event_file_writer::EventFileWriter;
use tensorboard_proto::event::{Event, TaggedRunMetadata};
use tensorboard_proto::summary::Summary;
use tensorboard_proto::graph::GraphDef;
use tensorboard_proto::node_def::NodeDef;
use tensorboard_proto::versions::VersionDef;
//use tensorboard_proto::attr_value::AttrValue;
//use tensorboard_proto::tensor_shape::TensorShapeProto;
use tensorboard_proto::step_stats::RunMetadata;
use crate::summary::{scalar, image, histogram_raw};


pub struct FileWriter {
    writer: EventFileWriter,
}
impl FileWriter {
    pub fn new<P: AsRef<Path>>(logdir: P) -> FileWriter {
        FileWriter {
            writer: EventFileWriter::new(logdir),
        }
    }
    pub fn get_logdir(&self) -> PathBuf {
        self.writer.get_logdir()
    }
    pub fn add_event(&mut self, event: &Event, step: usize) {
        let mut event = event.clone();

        // Stamp the event with seconds since the Unix epoch.
        let mut time_full = 0.0;
        if let Ok(n) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
            time_full = n.as_secs_f64();
        }
        event.set_wall_time(time_full);

        event.set_step(step as i64);

        self.writer.add_event(&event)
    }
    pub fn add_summary(&mut self, summary: Summary, step: usize) {
        let mut evn = Event::new();
        evn.set_summary(summary);
        self.add_event(&evn, step)
    }
    pub fn add_graph(&mut self, graph: GraphDef, meta: RunMetadata) {
        let mut graph_vec: Vec<u8> = Vec::new();
        graph.write_to_vec(&mut graph_vec).expect("failed to encode GraphDef");
        let mut graph_evn = Event::new();
        graph_evn.set_graph_def(graph_vec);
        self.writer.add_event(&graph_evn);

        let mut meta_vec: Vec<u8> = Vec::new();
        meta.write_to_vec(&mut meta_vec).expect("failed to encode RunMetadata");
        let mut tagged_meta = TaggedRunMetadata::new();
        tagged_meta.set_tag("profiler".to_string());
        tagged_meta.set_run_metadata(meta_vec);
        let mut meta_evn = Event::new();
        meta_evn.set_tagged_run_metadata(tagged_meta);
        self.writer.add_event(&meta_evn);
    }
    pub fn flush(&mut self) {
        self.writer.flush()
    }
}
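// Typical usage (illustrative):
//     let mut writer = SummaryWriter::new("./logdir");
//     writer.add_scalar("run1/loss", loss_value, step);
//     writer.flush();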
pub struct SummaryWriter {
    writer: FileWriter,
    all_writers: HashMap<PathBuf, FileWriter>,
}
impl SummaryWriter {
    pub fn new<P: AsRef<Path>>(logdir: P) -> SummaryWriter {
        SummaryWriter {
            writer: FileWriter::new(logdir),
            all_writers: HashMap::new(),
        }
    }
    pub fn add_hparams(&mut self) {unimplemented!();}
    pub fn add_scalar(&mut self, tag: &str, scalar_value: f32, step: usize) {
        self.writer.add_summary(scalar(tag, scalar_value), step);
    }
    pub fn add_scalars(&mut self, main_tag: &str, tag_scalar: &HashMap<String, f32>, step: usize) {
        // Each tag gets its own run directory (and FileWriter) under the base logdir.
        let base_logdir = self.writer.get_logdir();
        for (tag, scalar_value) in tag_scalar.iter() {
            let fw_tag = base_logdir.join(main_tag).join(tag);
            if !self.all_writers.contains_key(&fw_tag) {
                let new_writer = FileWriter::new(fw_tag.clone());
                self.all_writers.insert(fw_tag.clone(), new_writer);
            }
            let fw = self.all_writers.get_mut(&fw_tag).expect("writer was just inserted");
            fw.add_summary(scalar(main_tag, *scalar_value), step);
        }
    }

    pub fn export_scalars_to_json(&self) {unimplemented!();}
    pub fn add_histogram(&mut self) {unimplemented!();}
    pub fn add_histogram_raw(&mut self,
                             tag: &str,
                             min: f64, max: f64,
                             num: f64,
                             sum: f64, sum_squares: f64,
                             bucket_limits: &[f64], bucket_counts: &[f64],
                             step: usize
    ) {
        if bucket_limits.len() != bucket_counts.len() {
            panic!("bucket_limits.len() != bucket_counts.len()");
        }

        self.writer.add_summary(histogram_raw(tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts), step);
    }
    pub fn add_image(&mut self, tag: &str, data: &[u8], dim: &[usize], step: usize) {
        self.writer.add_summary(image(tag, data, dim), step);
    }
    pub fn add_images(&mut self) {unimplemented!();}
    pub fn add_image_with_boxes(&mut self) {unimplemented!();}
    pub fn add_figure(&mut self) {unimplemented!();}
    pub fn add_video(&mut self) {unimplemented!();}
    pub fn add_audio(&mut self) {unimplemented!();}
    pub fn add_text(&mut self) {unimplemented!();}
    pub fn add_onnx_graph(&mut self) {unimplemented!();}
    pub fn add_openvino_graph(&mut self) {unimplemented!();}
    pub fn add_graph(&mut self, node_list: &[NodeDef]) {
        let mut graph = GraphDef::new();

        let nodes = RepeatedField::from(node_list.to_vec());
        graph.set_node(nodes);

        let mut version = VersionDef::new();
        version.set_producer(22);
        graph.set_versions(version);

        let stats = RunMetadata::new();

        self.writer.add_graph(graph, stats);
    }
    pub fn add_embedding(&mut self) {unimplemented!();}
    pub fn add_pr_curve(&mut self) {unimplemented!();}
    pub fn add_pr_curve_raw(&mut self) {unimplemented!();}
    pub fn add_custom_scalars_multilinechart(&mut self) {unimplemented!();}
    pub fn add_custom_scalars_marginchart(&mut self) {unimplemented!();}
    pub fn add_custom_scalars(&mut self) {unimplemented!();}
    pub fn add_mesh(&mut self) {unimplemented!();}

    pub fn flush(&mut self) {
        self.writer.flush();
    }
}
--------------------------------------------------------------------------------
/data-pipe/src/dataloader/mnist.rs:
--------------------------------------------------------------------------------
use crate::dataloader::{DataLoader, DataSlice};
use auto_diff::{Var, AutoDiffError};
use std::path::Path;
use std::io;
use std::fs::File;
use std::io::Read;
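// The MNIST files use the IDX format: a big-endian u32 magic number (2051 for
// images, 2049 for labels), big-endian u32 dimension sizes, then the raw u8
// payload.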
pub struct Mnist {
    //path: PathBuf,
    train: Var,
    test: Var,
    train_label: Var,
    test_label: Var,
}
impl Mnist {
    pub fn new() -> Mnist {
        // TODO download the data if it is not there.
        unimplemented!()
    }
    pub fn load(path: &Path) -> Mnist {
        let train_fn = path.join("train-images-idx3-ubyte");
        let test_fn = path.join("t10k-images-idx3-ubyte");
        let train_label_fn = path.join("train-labels-idx1-ubyte");
        let test_label_fn = path.join("t10k-labels-idx1-ubyte");

        let train_img;
        let test_img;
        let train_label;
        let test_label;
        if path.exists() {
            train_img = Self::load_images(train_fn);
            test_img = Self::load_images(test_fn);
            train_label = Self::load_labels(train_label_fn);
            test_label = Self::load_labels(test_label_fn);
        } else {
            // TODO download the data if it is not there.
            unimplemented!()
        }

        Mnist {
            //path: PathBuf::from(path),
            train: train_img,
            test: test_img,
            train_label,
            test_label,
        }
    }

    fn load_images<P: AsRef<Path>>(path: P) -> Var {
        let mut reader = io::BufReader::new(File::open(path).expect("failed to open image file"));
        let magic = Self::read_as_u32(&mut reader);
        if magic != 2051 {
            panic!("Invalid magic number. expected 2051, got {}", magic)
        }
        let num_image = Self::read_as_u32(&mut reader) as usize;
        let rows = Self::read_as_u32(&mut reader) as usize;
        let cols = Self::read_as_u32(&mut reader) as usize;
        assert!(rows == 28 && cols == 28);

        // read images, scaling each byte to [0, 1]
        let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
        let _ = reader.read_exact(buf.as_mut());
        let ret: Vec<f64> = buf.into_iter().map(|x| (x as f64) / 255.).collect();
        Var::new(&ret[..], &[num_image, rows, cols])
    }
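    // Labels are one byte per example; widen them to f64 so they can be held
    // in a Var like the image data.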
    fn load_labels<P: AsRef<Path>>(path: P) -> Var {
        let mut reader = io::BufReader::new(File::open(path).expect("failed to open label file"));
        let magic = Self::read_as_u32(&mut reader);
        if magic != 2049 {
            panic!("Invalid magic number. expected 2049, got {}", magic);
        }
        let num_label = Self::read_as_u32(&mut reader) as usize;
        // read labels
        let mut buf: Vec<u8> = vec![0u8; num_label];
        let _ = reader.read_exact(buf.as_mut());
        let ret: Vec<f64> = buf.into_iter().map(|x| x as f64).collect();
        Var::new(&ret[..], &[num_label])
    }

    // IDX integers are stored big-endian.
    fn read_as_u32<T: Read>(reader: &mut T) -> u32 {
        let mut buf: [u8; 4] = [0, 0, 0, 0];
        let _ = reader.read_exact(&mut buf);
        u32::from_be_bytes(buf)
    }
}
impl DataLoader for Mnist {
    fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError> {
        match slice {
            Some(DataSlice::Train) => Ok(self.train.size()),
            Some(DataSlice::Test) => Ok(self.test.size()),
            None => {
                // Combined size: train and test stacked along the first dimension.
                let n = self.train.size()[0] + self.test.size()[0];
                let mut new_size = self.train.size();
                new_size[0] = n;
                Ok(new_size)
            },
            _ => Err(AutoDiffError::new("TODO")),
        }
    }
    fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
        match slice {
            Some(DataSlice::Train) => {
                // Take row `index` and the full extent of every other dimension.
                let dim = self.train.size().len();
                let mut index_block = vec![(index, index+1)];
                index_block.append(
                    &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
                        .map(|(x, y)| (*x, *y)).collect());
                let data = self.train.get_patch(&index_block, None)?;
                let label = self.train_label.get_patch(&[(index, index+1)], None)?;
                self.train.reset_net();
                self.train_label.reset_net();
                Ok((data, label))
            },
            Some(DataSlice::Test) => {
                let dim = self.test.size().len();
                let mut index_block = vec![(index, index+1)];
                index_block.append(
                    &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
                        .map(|(x, y)| (*x, *y)).collect());
                let data = self.test.get_patch(&index_block, None)?;
                let label = self.test_label.get_patch(&[(index, index+1)], None)?;
                self.test.reset_net();
                self.test_label.reset_net();
                Ok((data, label))
            },
            _ => Err(AutoDiffError::new("only train and test")),
        }
    }
    fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
        match slice {
            Some(DataSlice::Train) => {
                let dim = self.train.size().len();
                let mut index_block = vec![(start, end)];
                index_block.append(
                    &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
                        .map(|(x, y)| (*x, *y)).collect());
                let data = self.train.get_patch(&index_block, None)?;
                let label = self.train_label.get_patch(&[(start, end)], None)?;
                self.train.reset_net();
                self.train_label.reset_net();
                Ok((data, label))
            },
            Some(DataSlice::Test) => {
                let dim = self.test.size().len();
                let mut index_block = vec![(start, end)];
                index_block.append(
                    &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
                        .map(|(x, y)| (*x, *y)).collect());
                let data = self.test.get_patch(&index_block, None)?;
                let label = self.test_label.get_patch(&[(start, end)], None)?;
                self.test.reset_net();
                self.test_label.reset_net();
                Ok((data, label))
            },
            _ => Err(AutoDiffError::new("only train and test")),
        }
    }
}
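// Illustrative use:
//     let mnist = Mnist::load(Path::new("../auto-diff/examples/data/mnist/"));
//     let (imgs, labels) = mnist.get_batch(0, 16, Some(DataSlice::Train)).unwrap();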
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mnist() {
        let mnist = Mnist::load(Path::new("../auto-diff/examples/data/mnist/"));
        let (t0, _l0) = mnist.get_item(0, Some(DataSlice::Test)).unwrap();
        println!("{:?}", t0);
    }
}
--------------------------------------------------------------------------------
/tensor-rs/src/quaternion.rs:
--------------------------------------------------------------------------------
use std::ops::Div;

#[derive(PartialEq, Debug)]
pub struct Quaternion<T> {
    d: (T, T, T, T),
}

impl<T> Default for Quaternion<T>
where T: num_traits::Float {
    fn default() -> Self {
        Quaternion {
            d: (T::one(), T::zero(), T::zero(), T::zero())
        }
    }
}

impl<T> Div<T> for Quaternion<T>
where T: num_traits::Float {
    type Output = Self;

    fn div(self, rhs: T) -> Self::Output {
        Quaternion {
            d: (self.d.0/rhs, self.d.1/rhs, self.d.2/rhs, self.d.3/rhs)
        }
    }
}

impl<T> Quaternion<T>
where T: num_traits::Float {

    pub fn new(a: T, b: T, c: T, d: T) -> Self {
        Quaternion {
            d: (a, b, c, d)
        }
    }

    pub fn scalar_part(&self) -> T {
        self.d.0
    }

    pub fn vector_part(&self) -> (T, T, T) {
        (self.d.1, self.d.2, self.d.3)
    }

    pub fn conjugate(&self) -> Self {
        Quaternion {
            d: (self.d.0, -self.d.1, -self.d.2, -self.d.3)
        }
    }

    pub fn dot(&self, o: &Quaternion<T>) -> T {
        self.d.0*o.d.0
            + self.d.1*o.d.1
            + self.d.2*o.d.2
            + self.d.3*o.d.3
    }

    pub fn len(&self) -> T {
        T::sqrt(self.dot(self))
    }

    pub fn norm(&self) -> T {
        self.len()
    }

    pub fn inverse(&self) -> Self {
        self.conjugate()/self.dot(self)
    }

    /// Make it a unit quaternion.
    pub fn normalize(&self) -> Self {
        let n = self.norm();

        Quaternion {
            d: (self.d.0/n, self.d.1/n, self.d.2/n, self.d.3/n)
        }
    }

    /// Quaternion (Hamilton) multiplication.
    pub fn qm(&self, o: &Quaternion<T>) -> Self {
        Quaternion {
            d: (self.d.0*o.d.0 - self.d.1*o.d.1 - self.d.2*o.d.2 - self.d.3*o.d.3,
                self.d.0*o.d.1 + self.d.1*o.d.0 + self.d.2*o.d.3 - self.d.3*o.d.2,
                self.d.0*o.d.2 - self.d.1*o.d.3 + self.d.2*o.d.0 + self.d.3*o.d.1,
                self.d.0*o.d.3 + self.d.1*o.d.2 - self.d.2*o.d.1 + self.d.3*o.d.0,)
        }
    }

    /// Create a quaternion ready to apply to a vector for rotation:
    /// q = (cos(theta/2), sin(theta/2) * axis / |axis|).
    pub fn rotation_around_axis(axis: (T, T, T), theta: T) -> Self {
        let a = T::cos(theta/(T::one() + T::one()));
        let coef = T::sin(theta/(T::one() + T::one()));
        let norm = T::sqrt(axis.0*axis.0 + axis.1*axis.1 + axis.2*axis.2);

        Quaternion {
            d: (a, coef*axis.0/norm,
                coef*axis.1/norm,
                coef*axis.2/norm)
        }
    }

    pub fn rotate_around_x(theta: T) -> Self {
        Self::rotation_around_axis((T::one(), T::zero(), T::zero()), theta)
    }

    pub fn rotate_around_y(theta: T) -> Self {
        Self::rotation_around_axis((T::zero(), T::one(), T::zero()), theta)
    }

    pub fn rotate_around_z(theta: T) -> Self {
        Self::rotation_around_axis((T::zero(), T::zero(), T::one()), theta)
    }

    /// Apply unit quaternion to 3d vector for rotation.
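    /// Computes q * (0, v) * conj(q); e.g. rotating (0, 0, 1) by pi/2 around
    /// the x axis yields approximately (0, -1, 0), as the tests below verify.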
    pub fn apply_rotation(&self, v: (T, T, T)) -> (T, T, T) {
        if self.norm() != T::one() {
            println!("Apply a non unit quaternion for rotation!");
        }

        let x = Quaternion {
            d: (T::zero(), v.0, v.1, v.2)
        };
        let xp = self.qm(&x).qm(&self.conjugate());
        (xp.d.1, xp.d.2, xp.d.3)
    }

    pub fn rotate_around_axis(axis: (T, T, T), theta: T, v: (T, T, T)) -> (T, T, T) {
        let q = Self::rotation_around_axis(axis, theta);
        q.apply_rotation(v)
    }

    /// Raise a unit quaternion to the power t:
    /// (cos(omega), sin(omega) * n)^t = (cos(t*omega), sin(t*omega) * n).
    pub fn unit_exp(&self, t: T) -> Self {
        if self.norm() != T::one() {
            println!("unit_exp needs a unit quaternion!");
        }

        let omega = T::acos(self.d.0);

        if T::sin(omega) == T::zero() {
            return Self::default();
        }

        let i = self.d.1/T::sin(omega);
        let j = self.d.2/T::sin(omega);
        let k = self.d.3/T::sin(omega);

        Quaternion {
            d: (T::cos(t*omega), T::sin(t*omega)*i,
                T::sin(t*omega)*j, T::sin(t*omega)*k)
        }
    }

    /// Spherical linear interpolation: slerp(p, q, t) = p * (p^-1 * q)^t.
    pub fn slerp(p: &Self, q: &Self, t: T) -> Self {
        if p.norm() != T::one() || q.norm() != T::one() {
            println!("slerp needs unit quaternions!");
        }

        let p1 = p.normalize();
        let q1 = q.normalize();

        p1.qm(&p1.inverse().qm(&q1).unit_exp(t))
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qm() {
        let a = Quaternion::<f64>::new(1., 2., 3., 4.);
        let b = Quaternion::<f64>::new(2., 3., 4., 5.);

        let c = a.qm(&b);
        assert_eq!(c, Quaternion::<f64>::new(-36., 6., 12., 12.,))
    }

    #[test]
    fn test_rotate_around_axis() {
        let a = Quaternion::<f64>::rotate_around_x(3.1415/2.);
        let v = (0., 0., 1.);
        let r = a.apply_rotation(v);

        assert!(f64::abs(r.0-0.) + f64::abs(r.1 + 1.) + f64::abs(r.2-0.) < 0.001);
    }

    #[test]
    fn test_unit_exp() {
        let b = Quaternion::<f64>::default();
        let b1 = b.unit_exp(1.);
        assert_eq!(b1, Quaternion::<f64>::new(1., 0., 0., 0.));

        let b = Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5);
        let b1 = b.unit_exp(0.);
        assert_eq!(b1, Quaternion::<f64>::new(1., 0., 0., 0.));

        //let b = Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5);
        //let b1 = b.unit_exp(0.5);
        //assert_eq!(b1, Quaternion::<f64>::new(1., 0., 0., 0.));

        //let b = Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5);
        //let b1 = b.unit_exp(1.);
        //assert_eq!(b1, Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5));
    }

    #[test]
    fn test_slerp() {
        let a = Quaternion::<f64>::rotate_around_x(3.1415/2.);
        let b = Quaternion::<f64>::default();
        let c = Quaternion::<f64>::slerp(&a, &b, 0.5);

        let v = (0., 0., 1.);
        let r = c.apply_rotation(v);

        assert_eq!(r, (0.0, -0.7070904020014415, 0.7071231599922606));

        //assert_eq!(c, Quaternion::<f64>::default());
    }
}
--------------------------------------------------------------------------------
/tensor-rs/src/tensor_impl/lapack_tensor/elemwise.rs:
--------------------------------------------------------------------------------
use crate::tensor_impl::gen_tensor::GenTensor;
#[cfg(feature = "use-blas-lapack")]
use super::blas_api::BlasAPI;
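// Broadcast contract shared by the macros below: a one-element tensor on
// either side broadcasts against the other operand; equal element counts are
// combined element-wise; otherwise only right-hand broadcast is supported,
// i.e. y must have fewer dimensions and match the trailing dimensions of x.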
#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_add {
    ($a:ty, $b: ident) => {
        pub fn $b(
            x: &GenTensor<$a>,
            y: &GenTensor<$a>,
        ) -> GenTensor<$a> {
            let real_x;
            let mut real_y = y.get_data().clone();
            let mut real_size = x.numel();
            // The result takes the shape of the larger operand.
            let mut real_shape = x.size().clone();
            let real_x_vec;
            if x.numel() == 1 && y.numel() > 1 {
                real_x_vec = vec![x.get_data()[0]; y.numel()];
                real_x = &real_x_vec;
                real_size = y.numel();
                real_shape = y.size().clone();
            } else if x.numel() > 1 && y.numel() == 1 {
                real_x = x.get_data();
                real_y = vec![real_y[0]; x.numel()];
                real_size = x.numel();
            } else if x.numel() == y.numel() {
                real_x = x.get_data();
            } else {
                if x.numel() < y.numel() {
                    panic!("right-hand broadcast only.");
                }
                if x.size().len() <= y.size().len() {
                    panic!("unmatched dimension. {}, {}", x.size().len(), y.size().len());
                }
                for i in 0..y.size().len() {
                    if y.size()[y.size().len()-i-1] != x.size()[x.size().len()-i-1] {
                        panic!("unmatched size.");
                    }
                }
                real_x = x.get_data();
                real_y = real_y.repeat(x.numel()/y.numel());
            }

            // y := 1.0 * x + y
            BlasAPI::<$a>::axpy(real_size,
                                1.0 as $a,
                                real_x, 1,
                                &mut real_y, 1);
            GenTensor::<$a>::new_move(real_y, real_shape)
        }
    }
}

#[cfg(feature = "use-blas-lapack")]
blas_add!(f32, add_f32);

#[cfg(feature = "use-blas-lapack")]
blas_add!(f64, add_f64);
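// BLAS axpy computes y := a*x + y in place. blas_add above calls it with
// a = 1.0; blas_sub below reuses it with a = -1.0 against a copy of x.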
{}, {}", 96 | x.size().len(), y.size().len()); 97 | } 98 | for i in 0..y.size().len() { 99 | if y.size()[y.size().len()-i-1] != x.size()[x.size().len()-i-1] { 100 | panic!("unmatched size."); 101 | } 102 | } 103 | let mut real_x_vec = x.get_data().clone(); 104 | let real_y_vec = y.get_data().repeat(x.numel()/y.numel()); 105 | let real_size = x.numel(); 106 | BlasAPI::<$a>::axpy(real_size, 107 | -1.0 as $a, 108 | &real_y_vec, 1, 109 | &mut real_x_vec, 1); 110 | return GenTensor::<$a>::new_move(real_x_vec, x.size().clone()); 111 | } 112 | } 113 | } 114 | } 115 | 116 | #[cfg(feature = "use-blas-lapack")] 117 | blas_sub!(f32, sub_f32); 118 | 119 | #[cfg(feature = "use-blas-lapack")] 120 | blas_sub!(f64, sub_f64); 121 | 122 | #[cfg(test)] 123 | mod tests { 124 | use crate::tensor_impl::gen_tensor::GenTensor; 125 | use super::*; 126 | 127 | #[test] 128 | #[cfg(feature = "use-blas-lapack")] 129 | fn test_add() { 130 | let a = GenTensor::::ones(&[1, 2, 3]); 131 | let b = GenTensor::::ones(&[1, 2, 3]); 132 | let c = add_f32(&a, &b); 133 | let em = GenTensor::::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]); 134 | assert_eq!(c, em); 135 | 136 | let a = GenTensor::::ones(&[1, 2, 3]); 137 | let b = GenTensor::::ones(&[1, 2, 3]); 138 | let c = add_f64(&a, &b); 139 | let em = GenTensor::::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]); 140 | assert_eq!(c, em); 141 | 142 | let a = GenTensor::::ones(&[1, 2, 3]); 143 | let b = GenTensor::::ones(&[3]); 144 | let c = add_f64(&a, &b); 145 | let em = GenTensor::::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]); 146 | assert_eq!(c, em); 147 | } 148 | 149 | #[test] 150 | #[cfg(feature = "use-blas-lapack")] 151 | fn test_sub() { 152 | let a = GenTensor::::ones(&[1, 2, 3]); 153 | let b = GenTensor::::ones(&[1, 2, 3]); 154 | let c = sub_f32(&a, &b); 155 | let em = GenTensor::::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]); 156 | assert_eq!(c, em); 157 | 158 | let a = GenTensor::::ones(&[1, 2, 3]); 159 | let b = GenTensor::::ones(&[1, 2, 3]); 160 | let c = sub_f64(&a, &b); 161 | let em = GenTensor::::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]); 162 | assert_eq!(c, em); 163 | 164 | let a = GenTensor::::ones(&[1, 2, 3]); 165 | let b = GenTensor::::ones(&[3]); 166 | let c = sub_f64(&a, &b); 167 | let em = GenTensor::::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]); 168 | assert_eq!(c, em); 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /auto-diff/examples/cnn_mnist.rs: -------------------------------------------------------------------------------- 1 | use tensor_rs::tensor::{PaddingMode}; 2 | use auto_diff::op::{Linear, OpCall, Conv2d}; 3 | use auto_diff::optim::{SGD, MiniBatch}; 4 | use auto_diff::Var; 5 | use rand::prelude::*; 6 | use ::rand::prelude::StdRng; 7 | extern crate openblas_src; 8 | 9 | //use tensorboard_rs::summary_writer::SummaryWriter; 10 | 11 | mod mnist; 12 | use mnist::{load_images, load_labels}; 13 | 14 | fn main() { 15 | let train_img = load_images("examples/data/mnist/train-images-idx3-ubyte"); 16 | let test_img = load_images("examples/data/mnist/t10k-images-idx3-ubyte"); 17 | let train_label = load_labels("examples/data/mnist/train-labels-idx1-ubyte"); 18 | let test_label = load_labels("examples/data/mnist/t10k-labels-idx1-ubyte"); 19 | 20 | let train_size = train_img.size(); 21 | let n = train_size[0]; 22 | let h = train_size[1]; 23 | let w = train_size[2]; 24 | let train_data = train_img.reshape(&vec![n, 1, h, w]).unwrap(); 25 | 26 | let test_size = 
    let train_data = train_img.reshape(&vec![n, 1, h, w]).unwrap();

    let test_size = test_img.size();
    let n = test_size[0];
    let h = test_size[1];
    let w = test_size[2];
    let test_data = test_img.reshape(&vec![n, 1, h, w]).unwrap();

    train_data.reset_net();
    train_label.reset_net();
    test_data.reset_net();
    test_label.reset_net();

    let patch_size = 16;
    //let class_size = 10;

    // build the model
    // let mut m = Module::new();
    // let mut rng = RNG::new();
    // rng.set_seed(123);
    //
    // // 28 - (3x3) - 28 - (3x3,2) - 14 - (view) - 196 - (linear, 98.0) - 98 - (linear, 10) - 10
    //
    // let op1 = Conv2d::new(1, 32, (3,3), (1,1), (1,1), (1,1), true, PaddingMode::Zeros);
    // rng.normal_(op1.get_values()[0], 0., 1.);
    // rng.normal_(op1.get_values()[1], 0., 1.);
    // let conv1 = Op::new(Box::new(op1));
    //
    // let op2 = Conv2d::new(32, 64, (3,3), (2,2), (1,1), (1,1), true, PaddingMode::Zeros);
    // rng.normal_(op2.get_values()[0], 0., 1.);
    // rng.normal_(op2.get_values()[1], 0., 1.);
    // let conv2 = Op::new(Box::new(op2));
    //
    // let view = Op::new(Box::new(View::new(&[patch_size, 14*14*64])));
    //
    // let op3 = Linear::new(Some(14*14*64), Some(14*14), true);
    // rng.normal_(op3.weight(), 0., 1.);
    // rng.normal_(op3.bias(), 0., 1.);
    // let linear3 = Op::new(Box::new(op3));
    //
    // let op4 = Linear::new(Some(14*14), Some(10), true);
    // rng.normal_(op4.weight(), 0., 1.);
    // rng.normal_(op4.bias(), 0., 1.);
    // let linear4 = Op::new(Box::new(op4));
    //
    // let mut acts = Vec::new();
    // for i in 0..3 {
    //     let act1 = Op::new(Box::new(ReLU::new()));
    //     acts.push(act1);
    // }
    //
    // let input = m.var();
    // let output = input
    //     .to(&conv1)
    //     .to(&acts[0])
    //     .to(&conv2)
    //     .to(&acts[1])
    //     .to(&view)
    //     .to(&linear3)
    //     .to(&acts[2])
    //     .to(&linear4)
    //     ;
    // let label = m.var();
    //
    // let loss = crossentropyloss(&output, &label);
    //
    // let rng = RNG::new();
    // let minibatch = MiniBatch::new(rng, patch_size);
    //
    // let mut lr = 0.01;
    // let mut opt = SGD::new(lr);
    //
    // let mut writer = SummaryWriter::new(&("./logdir".to_string()));


    let mut rng = StdRng::seed_from_u64(671);

    let mut op1 = Conv2d::new(1, 32, (3,3), (1,1), (1,1), (1,1), true, PaddingMode::Zeros);
    op1.set_weight(Var::normal(&mut rng, &op1.weight().size(), 0., 1.));
    op1.set_bias(Var::normal(&mut rng, &op1.bias().size(), 0., 1.));

    let mut op2 = Conv2d::new(32, 64, (3,3), (2,2), (1,1), (1,1), true, PaddingMode::Zeros);
    op2.set_weight(Var::normal(&mut rng, &op2.weight().size(), 0., 1.));
    op2.set_bias(Var::normal(&mut rng, &op2.bias().size(), 0., 1.));

    let mut op3 = Linear::new(Some(14*14*64), Some(14*14), true);
    op3.set_weight(Var::normal(&mut rng, &[14*14*64, 14*14], 0., 1.));
    op3.set_bias(Var::normal(&mut rng, &[14*14], 0., 1.));

    let mut op4 = Linear::new(Some(14*14), Some(10), true);
    op4.set_weight(Var::normal(&mut rng, &[14*14, 10], 0., 1.));
    op4.set_bias(Var::normal(&mut rng, &[10], 0., 1.));

    //println!("{}, {}", &train_data, &train_label);
    let rng = StdRng::seed_from_u64(671);
    let mut minibatch = MiniBatch::new(rng, 16);

    // let mut writer = SummaryWriter::new(&("./logdir".to_string()));
    let (input, label) = minibatch.next(&train_data, &train_label).unwrap();
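    // The graph is defined by running it: this first batch builds the network
    // once; the loop below swaps in fresh data with set() and replays the
    // recorded graph via rerun()/bp()/step().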
println!("here0"); 123 | 124 | let output1 = op1.call(&[&input]).unwrap().pop().unwrap(); println!("here"); 125 | let output1_1 = output1.relu().unwrap(); println!("here2"); 126 | let output2 = op2.call(&[&output1_1]).unwrap().pop().unwrap(); println!("here3"); 127 | let output2_1 = output2.relu().unwrap().view(&[patch_size, 14*14*64]).unwrap(); println!("her4"); 128 | let output3 = op3.call(&[&output2_1]).unwrap().pop().unwrap(); println!("here5"); 129 | let output3_1 = output3.relu().unwrap(); println!("her6"); 130 | let output = op4.call(&[&output3_1]).unwrap().pop().unwrap(); println!("here7"); 131 | 132 | let loss = output.cross_entropy_loss(&label).unwrap(); println!("here8"); 133 | 134 | let lr = 0.1; 135 | let mut opt = SGD::new(lr); 136 | 137 | println!("{:?}", loss); 138 | 139 | // 140 | // 141 | for i in 1..900 { 142 | println!("index: {}", i); 143 | 144 | //let (mdata, mlabel) = minibatch.next(&train_data, &train_label).unwrap(); 145 | let (input_next, label_next) = minibatch.next(&train_data, &train_label).unwrap(); 146 | input.set(&input_next); 147 | label.set(&label_next); 148 | println!("load data done"); 149 | 150 | loss.rerun().unwrap(); println!("rerun"); 151 | loss.bp().unwrap(); println!("bp"); 152 | loss.step(&mut opt).unwrap(); println!("step"); 153 | 154 | if i % 10 == 0 { 155 | 156 | let (input_next, label_next) = minibatch.next(&test_data, &test_label).unwrap(); 157 | input.set(&input_next); 158 | label.set(&label_next); 159 | loss.rerun().unwrap(); 160 | 161 | println!("test loss: {:?}", loss); 162 | 163 | //let loss_value = loss.get().get_scale_f32(); 164 | 165 | let tsum = output.clone().argmax(Some(&[1]), false).unwrap().eq_elem(&test_label).unwrap().mean(None, false); 166 | //let accuracy = tsum.get_scale_f32(); 167 | //println!("{}, loss: {}, accuracy: {}", i, loss_value, accuracy); 168 | println!("test error: {:?}", tsum); 169 | 170 | //writer.add_scalar(&"cnn/run1/accuracy".to_string(), accuracy, i); 171 | //writer.flush(); 172 | } 173 | 174 | //println!("{}, loss: {}", i, loss.get().get_scale_f32()); 175 | //writer.add_scalar(&"cnn/run1/test_loss".to_string(), loss.get().get_scale_f32(), i); 176 | //writer.flush(); 177 | // 178 | //if i != 0 && i % 300 == 0 { 179 | // lr = lr / 3.; 180 | // opt = SGD::new(lr); 181 | //} 182 | } 183 | } 184 | --------------------------------------------------------------------------------