├── src
│   ├── macros
│   │   ├── mod.rs
│   │   ├── utils.rs
│   │   └── aop_check.rs
│   ├── proto
│   │   └── mod.rs
│   ├── util
│   │   ├── mod.rs
│   │   ├── rng.rs
│   │   ├── blocking_queue.rs
│   │   ├── io.rs
│   │   ├── insert_splits.rs
│   │   ├── mkl_alternate.rs
│   │   ├── im2col.rs
│   │   └── math_functions.rs
│   ├── layers
│   │   ├── mod.rs
│   │   ├── neuron_layer.rs
│   │   ├── bnll_layer.rs
│   │   ├── absval_layer.rs
│   │   ├── clip_layer.rs
│   │   ├── batch_reindex_layer.rs
│   │   ├── concat_layer.rs
│   │   ├── argmax_layer.rs
│   │   ├── bias_layer.rs
│   │   ├── base_data_layer.rs
│   │   ├── accuracy_layer.rs
│   │   └── batch_norm_layer.rs
│   ├── lib.rs
│   ├── internal_thread.rs
│   ├── common.rs
│   ├── layer_factory.rs
│   ├── synced_mem.rs
│   ├── filler.rs
│   ├── layer.rs
│   ├── data_transformer.rs
│   └── blob.rs
├── .cargo
│   └── config.toml
├── .gitignore
├── .github
│   └── workflows
│       ├── rust-stable-latest.yml
│       ├── rust-stable-min.yml
│       └── rust-nightly.yml
├── Cargo.toml
├── res
│   └── docs-header.html
├── LICENSE
└── Readme.md

/src/macros/mod.rs:
--------------------------------------------------------------------------------
#[macro_use]
pub mod aop_check;
#[macro_use]
pub mod utils;
--------------------------------------------------------------------------------
/.cargo/config.toml:
--------------------------------------------------------------------------------
[build]
rustdocflags = [ "--html-in-header", "./res/docs-header.html" ]
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/target
Cargo.lock

# JetBrains
.idea/

# explicitly ignore the temporary protobuf-generated file
src/proto/caffe.rs
--------------------------------------------------------------------------------
/src/proto/mod.rs:
--------------------------------------------------------------------------------
//! Generated files are imported from here.

include!(concat!(env!("OUT_DIR"), "/proto_gen/mod.rs"));

// enable CLion Rust plugin analysis
// pub mod caffe;
--------------------------------------------------------------------------------
/src/macros/utils.rs:
--------------------------------------------------------------------------------

/// Stub out GPU calls as unavailable.
macro_rules! no_gpu {
    () => {
        assert!(false, "Cannot use GPU in CPU-only Caffe: check mode.");
    };
}
--------------------------------------------------------------------------------
/src/util/mod.rs:
--------------------------------------------------------------------------------
pub mod mkl_alternate;
pub mod math_functions;
pub mod insert_splits;
pub mod io;
pub mod upgrade_proto;
pub mod im2col;
pub mod rng;
pub mod blocking_queue;
--------------------------------------------------------------------------------
/src/layers/mod.rs:
--------------------------------------------------------------------------------
pub mod neuron_layer;
pub mod absval_layer;
pub mod accuracy_layer;
pub mod argmax_layer;
pub mod base_conv_layer;
pub mod base_data_layer;
pub mod batch_norm_layer;
pub mod batch_reindex_layer;
pub mod bias_layer;
pub mod bnll_layer;
pub mod clip_layer;
pub mod concat_layer;
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
#[macro_use] extern crate log;
#[macro_use] extern crate static_init;
#[macro_use] extern crate paste;

#[macro_use]
mod macros;
mod proto;
mod util;
mod common;
mod synced_mem;
mod blob;
mod filler;
mod internal_thread;
mod data_transformer;

mod layer;
#[macro_use]
mod layer_factory;
mod net;
mod layers;

#[cfg(test)]
mod tests {
    use test_env_log::test;

    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);
    }
}
--------------------------------------------------------------------------------
/.github/workflows/rust-stable-latest.yml:
--------------------------------------------------------------------------------
name: Rust

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]
  workflow_dispatch:

env:
  CARGO_TERM_COLOR: always

jobs:
  build:
    name: Build and run test on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macOS-latest]

    steps:
    - name: Setup protoc
      uses: arduino/setup-protoc@v1
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
    - uses: actions/checkout@v2
    - name: Build
      run: cargo build --verbose
    - name: Run tests
      run: cargo test --verbose
--------------------------------------------------------------------------------
/src/util/rng.rs:
--------------------------------------------------------------------------------
use std::cell::RefCell;
use std::rc::Rc;

use rand::RngCore;
use rand::distributions::{Uniform, Distribution};

use crate::common::{CaffeRng, Caffe};


pub fn caffe_rng() -> Rc<RefCell<CaffeRng>> {
    Caffe::rng()
}

pub fn caffe_rng_rand() -> u32 {
    Caffe::rng_rand()
}

pub fn shuffle<T>(slice: &mut [T], gen: &mut dyn RngCore) {
    if slice.is_empty() {
        return;
    }

    for i in (1..slice.len()).rev() {
        // Fisher-Yates: the swap target must be drawn from 0..=i (inclusive);
        // an exclusive upper bound would never let an element stay in place,
        // which biases the shuffle.
        let dist = Uniform::new_inclusive(0, i);
        slice.swap(i, dist.sample(gen));
    }
}

pub fn shuffle_caffe_rng<T>(slice: &mut [T]) {
    shuffle(slice, caffe_rng().as_ref().borrow_mut().generator());
}
--------------------------------------------------------------------------------
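`shuffle` above is an in-place Fisher–Yates pass driven by any `RngCore`. A minimal sketch of how it could be exercised from an in-crate test (`util` is a private module in `lib.rs`, so the call site must live inside the crate; the seed value is arbitrary):

```rust
use rand::SeedableRng;
use rand::rngs::StdRng;

use crate::util::rng::shuffle;

#[test]
fn shuffle_is_a_permutation() {
    let mut v: Vec<i32> = (0..10).collect();
    // A fixed seed makes the resulting permutation reproducible across runs.
    let mut rng = StdRng::seed_from_u64(42);
    shuffle(&mut v, &mut rng);

    // Same multiset of elements, (almost certainly) in a different order.
    let mut sorted = v.clone();
    sorted.sort_unstable();
    assert_eq!(sorted, (0..10).collect::<Vec<i32>>());
}
```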
/.github/workflows/rust-stable-min.yml:
--------------------------------------------------------------------------------
name: Rust - Stable Minimum

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

env:
  CARGO_TERM_COLOR: always

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - name: Setup protoc
      uses: arduino/setup-protoc@v1
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
    - name: Install v1.51 rust-toolchain
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
        default: true
        override: true
        toolchain: 1.51.0
    - uses: actions/checkout@v2
    - name: Check
      uses: actions-rs/cargo@v1
      with:
        command: check
--------------------------------------------------------------------------------
/.github/workflows/rust-nightly.yml:
--------------------------------------------------------------------------------
name: Rust - Nightly

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

env:
  CARGO_TERM_COLOR: always

jobs:
  build:
    name: Nightly check on ubuntu-latest
    runs-on: ubuntu-latest

    steps:
    - name: Setup protoc
      uses: arduino/setup-protoc@v1
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
    - uses: actions/checkout@v2
    - name: Install nightly rust-toolchain
      uses: actions-rs/toolchain@v1
      with:
        profile: minimal
        default: true
        override: true
        toolchain: nightly
    - name: Build
      uses: actions-rs/cargo@v1
      with:
        command: check
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "caffe-rs"
version = "0.1.0"
authors = ["mx "]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
log = "0.4"
cblas = "0.3"
bytes = { version = "1.0" }
protobuf = { version = "2", features = ["with-bytes"] }
static_init = "^1"
paste = "1"
rand = { version = "0.8", features = ["std_rng"] }
rand_distr = "0.4"
mt19937 = "2"
float_next_after = "0.1"
#num-traits = "0.2"

[dev-dependencies]
env_logger = "*"
test-env-log = "0.2"

[build-dependencies]
protobuf = { version = "2", features = ["with-bytes"] }
protoc-rust = "2"

[package.metadata.docs.rs]
rustdoc-args = [ "--html-in-header", "./res/docs-header.html" ]
--------------------------------------------------------------------------------
/res/docs-header.html:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/util/blocking_queue.rs:
--------------------------------------------------------------------------------
//! Simple blocking queue implementation which takes [this repo](https://github.com/Julian6bG/rust-blockinqueue)
//! as a reference.

use std::sync::{Arc, Mutex};
use std::sync::mpsc::{Sender, Receiver, channel};


pub struct BlockingQueue<T> {
    sender: Sender<T>,
    receiver: Arc<Mutex<Receiver<T>>>,
}

impl<T> Clone for BlockingQueue<T> {
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
            receiver: self.receiver.clone(),
        }
    }
}

impl<T> BlockingQueue<T> {
    pub fn new() -> Self {
        let (sender, receiver) = channel();
        Self {
            sender,
            receiver: Arc::new(Mutex::new(receiver)),
        }
    }

    pub fn push(&self, v: T) {
        self.sender.send(v).unwrap();
    }

    pub fn pop(&self) -> T {
        self.receiver.lock().unwrap().recv().unwrap()
    }

    pub fn try_pop(&self) -> Option<T> {
        self.receiver.lock().unwrap().try_recv().ok()
    }
}
--------------------------------------------------------------------------------
/src/layers/neuron_layer.rs:
--------------------------------------------------------------------------------
use crate::blob::{BlobType};
use crate::layer::{LayerImpl, BlobVec};
use crate::proto::caffe::{LayerParameter};


/// An interface for layers that take one blob as input ($ x $) and produce one
/// equally-sized blob as output ($ y $), where each element of the output
/// depends only on the corresponding input element.
pub struct NeuronLayer<T: BlobType> {
    layer: LayerImpl<T>,
}

impl<T: BlobType> NeuronLayer<T> {
    pub fn new(param: &LayerParameter) -> Self {
        NeuronLayer {
            layer: LayerImpl::new(param)
        }
    }

    #[inline]
    pub fn get_impl(&self) -> &LayerImpl<T> {
        &self.layer
    }

    #[inline]
    pub fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        &mut self.layer
    }

    #[inline]
    pub fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        top[0].borrow_mut().reshape_like(&*bottom[0].as_ref().borrow());
    }

    #[inline]
    pub fn exact_num_bottom_blobs(&self) -> i32 {
        1
    }

    #[inline]
    pub fn exact_num_top_blobs(&self) -> i32 {
        1
    }
}
--------------------------------------------------------------------------------
/src/util/io.rs:
--------------------------------------------------------------------------------
use std::fs::{File};

use protobuf::{Message, CodedInputStream};
use protobuf::text_format::print_to_string;


pub fn read_proto_from_text_file<T: Message>(filename: &str, proto: &mut T) -> bool {
    // todo: the protobuf impl does not support text_format currently; use a binary read instead.
    // todo: the user should transform the .prototxt file into a binary format.
    let mut file = File::open(filename).unwrap();
    let mut istream = CodedInputStream::new(&mut file);
    let success = proto.merge_from(&mut istream);
    success.is_ok()
}

#[inline]
pub fn read_proto_from_text_file_or_die<T: Message>(filename: &str, proto: &mut T) {
    let r = read_proto_from_text_file(filename, proto);
    assert!(r);
}

pub fn write_proto_to_text_file<T: Message>(proto: &T, filename: &str) {
    // todo: the protobuf impl only supports serializing to a string currently; use an OutputStream in the future.
    let msg = print_to_string(proto);
    std::fs::write(filename, msg).unwrap();
}

pub fn read_proto_from_binary_file<T: Message>(filename: &str, proto: &mut T) -> bool {
    let mut file = File::open(filename).unwrap();
    let mut istream = CodedInputStream::new(&mut file);
    let success = proto.merge_from(&mut istream);
    success.is_ok()
}

#[inline]
pub fn read_proto_from_binary_file_or_die<T: Message>(filename: &str, proto: &mut T) {
    let r = read_proto_from_binary_file(filename, proto);
    assert!(r);
}

pub fn write_proto_to_binary_file<T: Message>(proto: &T, filename: &str) {
    let mut file = File::create(filename).unwrap();
    proto.write_to_writer(&mut file).unwrap();
}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
COPYRIGHT

All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.

All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.

Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.

LICENSE

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

CONTRIBUTION AGREEMENT

By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.
--------------------------------------------------------------------------------
/src/macros/aop_check.rs:
--------------------------------------------------------------------------------

macro_rules! impl_op_check {
    ($op:tt, $left:expr, $right:expr) => {
        if !($left $op $right) {
            let lv = $left;
            let rv = $right;
            panic!("check failed: {:?}({:?}) {:?} {:?}({:?}).",
                   stringify!($left),
                   lv,
                   stringify!($op),
                   stringify!($right),
                   rv);
        }
    };
    ($op:tt, $left:expr, $right:expr, $msg:literal) => {
        if !($left $op $right) {
            let lv = $left;
            let rv = $right;
            panic!("check failed: {:?}({:?}) {:?} {:?}({:?}). msg: {:?}",
                   stringify!($left),
                   lv,
                   stringify!($op),
                   stringify!($right),
                   rv,
                   $msg);
        }
    };
    ($op:tt, $left:expr, $right:expr, $fmt:literal, $($element:expr),*) => {
        if !($left $op $right) {
            let lv = $left;
            let rv = $right;
            panic!("check failed: {:?}({:?}) {:?} {:?}({:?}). msg: {:?}",
                   stringify!($left),
                   lv,
                   stringify!($op),
                   stringify!($right),
                   rv,
                   format!($fmt, $($element),*));
        }
    }
}

// #[macro_export]
macro_rules! check_eq {
    ($left:expr, $right:expr, $($e:expr),*) => {
        impl_op_check!(==, $left, $right, $($e),*);
    };
    ($left:expr, $right:expr) => {
        impl_op_check!(==, $left, $right);
    };
}

// #[macro_export]
macro_rules! check_le {
    ($left:expr, $right:expr, $($e:expr),*) => {
        impl_op_check!(<=, $left, $right, $($e),*);
    };
    ($left:expr, $right:expr) => {
        impl_op_check!(<=, $left, $right);
    };
}

macro_rules! check_lt {
    ($left:expr, $right:expr, $($e:expr),*) => {
        impl_op_check!(<, $left, $right, $($e),*);
    };
    ($left:expr, $right:expr) => {
        impl_op_check!(<, $left, $right);
    };
}

macro_rules! check_ge {
    ($left:expr, $right:expr, $($e:expr),*) => {
        impl_op_check!(>=, $left, $right, $($e),*);
    };
    ($left:expr, $right:expr) => {
        impl_op_check!(>=, $left, $right);
    };
}

macro_rules! check_gt {
    ($left:expr, $right:expr, $($e:expr),*) => {
        impl_op_check!(>, $left, $right, $($e),*);
    };
    ($left:expr, $right:expr) => {
        impl_op_check!(>, $left, $right);
    };
}
--------------------------------------------------------------------------------
/src/internal_thread.rs:
--------------------------------------------------------------------------------
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::JoinHandle;

use crate::common::{CaffeBrew, Caffe};
use crate::util::rng::caffe_rng_rand;


#[derive(Default)]
pub struct InternalThreadImpl {
    pub thread: Option<JoinHandle<()>>,
    pub interrupt: Arc<AtomicBool>,
}

pub struct CancelToken {
    interrupt: Arc<AtomicBool>,
}

impl CancelToken {
    pub fn new(ir: &Arc<AtomicBool>) -> Self {
        CancelToken {
            interrupt: Arc::clone(ir)
        }
    }

    pub fn is_cancelled(&self) -> bool {
        self.interrupt.load(Ordering::Relaxed)
    }
}


/// Trait encapsulating a std::thread for use in a base class. The child class acquires
/// the ability to run a single thread by implementing the virtual function `internal_thread_entry`.
pub trait InternalThread {
    type EntryData: Send + 'static;

    fn get_thread(&self) -> &InternalThreadImpl;

    fn get_thread_mut(&mut self) -> &mut InternalThreadImpl;

    fn get_entry_data(&mut self) -> Box<Self::EntryData>;

    /// Implement this method in your subclass with the code you want your thread to run.
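    ///
    /// A minimal sketch of an implementor (the `Worker` type and its entry loop are
    /// illustrative only, not part of the crate):
    ///
    /// ```ignore
    /// struct Worker {
    ///     thread: InternalThreadImpl,
    /// }
    ///
    /// impl InternalThread for Worker {
    ///     type EntryData = Vec<i32>;
    ///
    ///     fn get_thread(&self) -> &InternalThreadImpl { &self.thread }
    ///     fn get_thread_mut(&mut self) -> &mut InternalThreadImpl { &mut self.thread }
    ///     fn get_entry_data(&mut self) -> Box<Vec<i32>> { Box::new(vec![1, 2, 3]) }
    ///
    ///     fn internal_thread_entry(token: CancelToken, data: Box<Vec<i32>>) {
    ///         // Run until `stop_internal_thread` flips the shared interrupt flag.
    ///         while !token.is_cancelled() {
    ///             // ... consume `data` here ...
    ///         }
    ///     }
    /// }
    /// ```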
    fn internal_thread_entry(token: CancelToken, data: Box<Self::EntryData>);

    /// Caffe's thread local state will be initialized using the current
    /// thread values, e.g. device id, solver index etc. The random seed
    /// is initialized using caffe_rng_rand.
    fn start_internal_thread(&mut self) {
        assert!(!self.is_started(), "Threads should persist and not be restarted.");

        let _device = 0;
        let mode = Caffe::mode();
        let rand_seed = caffe_rng_rand();
        let solver_count = Caffe::solver_count();
        let solver_rank = Caffe::solver_rank();
        let multiprocess = Caffe::multiprocess();
        let data = self.get_entry_data();

        let th = self.get_thread_mut();
        let token = CancelToken::new(&th.interrupt);
        th.thread = Some(std::thread::spawn(move || {
            Caffe::set_mode(mode);
            Caffe::set_random_seed(rand_seed as u64);
            Caffe::set_solver_count(solver_count);
            Caffe::set_solver_rank(solver_rank);
            Caffe::set_multiprocess(multiprocess);
            Self::internal_thread_entry(token, data);
        }));
    }

    /// Will block until the internal thread has exited.
    fn stop_internal_thread(&mut self) {
        let th = self.get_thread_mut();
        th.interrupt.store(true, Ordering::Relaxed);
        if let Some(t) = th.thread.take() {
            t.join().unwrap();
        }
    }

    fn is_started(&self) -> bool {
        self.get_thread().thread.is_some()
    }
}
--------------------------------------------------------------------------------
/src/layers/bnll_layer.rs:
--------------------------------------------------------------------------------
use crate::blob::BlobType;
use crate::layer::{CaffeLayer, LayerImpl, BlobVec};
use crate::layers::neuron_layer::NeuronLayer;
use crate::proto::caffe::LayerParameter;


pub struct BNLLLayer<T: BlobType> {
    layer: NeuronLayer<T>,
}

impl<T: BlobType> BNLLLayer<T> {
    pub fn new(param: &LayerParameter) -> Self {
        Self {
            layer: NeuronLayer::new(param),
        }
    }
}

const BNLL_THRESHOLD: f64 = 50.0;

impl<T: BlobType> CaffeLayer for BNLLLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        self.layer.get_impl()
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        self.layer.get_impl_mut()
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        self.layer.reshape(bottom, top);
    }

    fn layer_type(&self) -> &'static str {
        "BNLL"
    }

    fn exact_num_bottom_blobs(&self) -> i32 {
        self.layer.exact_num_bottom_blobs()
    }

    fn exact_num_top_blobs(&self) -> i32 {
        self.layer.exact_num_top_blobs()
    }

    fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let mut t0 = top[0].borrow_mut();
        let count = b0.count();
        let bottom_data = b0.cpu_data();
        let top_data = t0.mutable_cpu_data();
        for i in 0..count {
            let bottom = bottom_data[i];
            top_data[i] = if bottom > T::default() {
                bottom + T::ln(T::exp(-bottom) + T::from_i32(1))
            } else {
                T::ln(T::exp(bottom) + T::from_i32(1))
            };
        }
    }

    fn forward_gpu(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        no_gpu!();
    }

    fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>, bottom: &BlobVec<T>) {
        if !propagate_down[0] {
            return;
        }

        let t0 = top[0].as_ref().borrow();
        let mut b0 = bottom[0].borrow_mut();
        let top_diff = t0.cpu_diff();
        let count = b0.count();
        let mem_ref = b0.mutable_cpu_mem_ref();
        let threshold = T::from_f64(BNLL_THRESHOLD);
        for i in 0..count {
            // dy/dx = e^x / (e^x + 1); the exponent is clipped at BNLL_THRESHOLD
            // for numerical stability.
            let exp_val = T::exp(T::min(mem_ref.data[i], threshold));
            mem_ref.diff[i] = (top_diff[i] * exp_val) / (exp_val + T::from_i32(1));
        }
    }

    fn backward_gpu(&mut self, _top: &BlobVec<T>, _propagate_down: &Vec<bool>, _bottom: &BlobVec<T>) {
        no_gpu!();
    }
}

register_layer_class!(BNLL);
--------------------------------------------------------------------------------
/src/layers/absval_layer.rs:
--------------------------------------------------------------------------------
use super::neuron_layer::NeuronLayer;

use crate::blob::{BlobType, BlobMemRefMut};
use crate::layer::{CaffeLayer, LayerImpl, BlobVec, def_layer_setup};
use crate::proto::caffe::{LayerParameter};
use crate::util::math_functions::CaffeNum;


/// Computes $ y = |x| $.
///
/// Bottom input Blob vector (length 1)
/// - $ (N \times C \times H \times W) $ the inputs $ x $.
///
/// Top output Blob vector (length 1)
/// - $ (N \times C \times H \times W) $ the computed outputs $ y = |x| $.
pub struct AbsValLayer<T: BlobType> {
    layer: NeuronLayer<T>,
}

impl<T: BlobType> AbsValLayer<T> {
    pub fn new(param: &LayerParameter) -> Self {
        AbsValLayer {
            layer: NeuronLayer::new(param)
        }
    }
}

impl<T: BlobType> CaffeLayer for AbsValLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        self.layer.get_impl()
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        self.layer.get_impl_mut()
    }

    fn layer_setup(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        def_layer_setup(self, bottom, top);
        assert_ne!(top[0].as_ptr(), bottom[0].as_ptr(),
                   "{:?} Layer does not allow in-place computation.",
                   self.layer_type());
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        self.layer.reshape(bottom, top);
    }

    fn layer_type(&self) -> &'static str {
        "AbsVal"
    }

    fn exact_num_bottom_blobs(&self) -> i32 {
        1
    }

    fn exact_num_top_blobs(&self) -> i32 {
        1
    }

    fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let count = top[0].as_ref().borrow().count();
        let mut top_data = top[0].borrow_mut();
        T::caffe_abs(count, bottom[0].as_ref().borrow().cpu_data(), top_data.mutable_cpu_data());
    }

    fn forward_gpu(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        no_gpu!();
    }

    fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>, bottom: &BlobVec<T>) {
        if propagate_down[0] {
            let count = top[0].as_ref().borrow().count();
            let top_diff = top[0].as_ref().borrow();
            let top_diff = top_diff.cpu_diff();
            let mut bottom_mut = bottom[0].borrow_mut();
            let bottom_mut = bottom_mut.mutable_cpu_mem_ref();
            T::caffe_cpu_sign(count, bottom_mut.data, bottom_mut.diff);
            T::caffe_mul_assign(count, bottom_mut.diff, top_diff);
        }
    }

    fn backward_gpu(&mut self, _top: &BlobVec<T>, _propagate_down: &Vec<bool>, _bottom: &BlobVec<T>) {
        no_gpu!();
    }
}

register_layer_class!(AbsVal);
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
# Caffe-rs

[![GitHub](https://img.shields.io/badge/GitHub-mematrix/caffe--rs-lightgrey?style=flat&logo=github&color=orange)](https://github.com/mematrix/caffe-rs)
[![License](https://img.shields.io/badge/license-BSD-blue.svg)](LICENSE)

[![Rust](https://github.com/mematrix/caffe-rs/actions/workflows/rust-stable-latest.yml/badge.svg)](https://github.com/mematrix/caffe-rs/actions/workflows/rust-stable-latest.yml)
[![Rust - Nightly](https://github.com/mematrix/caffe-rs/actions/workflows/rust-nightly.yml/badge.svg)](https://github.com/mematrix/caffe-rs/actions/workflows/rust-nightly.yml)
[![Rust - Stable Minimum](https://github.com/mematrix/caffe-rs/actions/workflows/rust-stable-min.yml/badge.svg)](https://github.com/mematrix/caffe-rs/actions/workflows/rust-stable-min.yml)

## Toolchain Required
With **no unstable** features enabled, this project needs toolchain version `>=1.51.0` to build.

## Build

```shell
cd path/to/caffe-rs
cargo build
```

## Usage

### Register a custom layer
The library exports two macros:

- `register_layer_class!`, which registers a layer class that impls the trait `caffe_rs::layer::CaffeLayer` and has a public fn named `new` taking a single `caffe_rs::proto::caffe::LayerParameter` param.

```rust
use caffe_rs::proto::caffe::LayerParameter;
use caffe_rs::blob::BlobType;
use caffe_rs::layer::CaffeLayer;
use std::marker::PhantomData;

struct TestLayer<T: BlobType> {
    phantom: PhantomData<T>,
}

impl<T: BlobType> CaffeLayer for TestLayer<T> {
    type DataType = T;
    // Impl the necessary functions.
}

impl<T: BlobType> TestLayer<T> {
    pub fn new(_param: &LayerParameter) -> Self {
        TestLayer {
            phantom: PhantomData
        }
    }
}

// Note: the name does not contain the trailing 'Layer'.
register_layer_class!(Test);
```

- `register_layer_creator!`, which registers a custom layer creator function. The creator signature is defined like `test` below:

```rust
use caffe_rs::proto::caffe::LayerParameter;
use caffe_rs::layer::SharedLayer;
use caffe_rs::blob::BlobType;

fn test<T: BlobType>(p: &LayerParameter) -> SharedLayer<T> {
    unimplemented!();
}

// register `TestLayer` with a creator.
register_layer_creator!(Test, test);
```

**Important**: you need to import the dependency crate `paste` to use the register macros.

In your `Cargo.toml`:

```toml
[dependencies]
paste = "1"
```

and import the macro at your crate root (`src/lib.rs` or `src/main.rs`):

```rust
#[macro_use] extern crate paste;
```

## License
**Caffe-rs** is released under the [BSD 2-Clause license](https://github.com/mematrix/caffe-rs/blob/master/LICENSE), the same as the [Caffe project](https://github.com/BVLC/caffe/).
--------------------------------------------------------------------------------
/src/layers/clip_layer.rs:
--------------------------------------------------------------------------------
use crate::blob::BlobType;
use crate::layer::{CaffeLayer, LayerImpl, BlobVec};
use crate::layers::neuron_layer::NeuronLayer;
use crate::proto::caffe::LayerParameter;


/// Clip: $$ y = \max(min, \min(max, x)) $$.
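///
/// For example, with `min = 0` and `max = 6` (a ReLU6-style clip), the input
/// `[-2.0, 0.5, 7.0]` produces `[0.0, 0.5, 6.0]`; on the backward pass the
/// gradient passes through only for `0.5`, the one element inside `[min, max]`.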
pub struct ClipLayer<T: BlobType> {
    layer: NeuronLayer<T>,
}

impl<T: BlobType> ClipLayer<T> {
    /// `param` provides **ClipParameter** clip_param, with **ClipLayer options**:
    /// - min
    /// - max
    pub fn new(param: &LayerParameter) -> Self {
        Self {
            layer: NeuronLayer::new(param),
        }
    }
}

impl<T: BlobType> CaffeLayer for ClipLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        self.layer.get_impl()
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        self.layer.get_impl_mut()
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        self.layer.reshape(bottom, top);
    }

    fn layer_type(&self) -> &'static str {
        "Clip"
    }

    fn exact_num_bottom_blobs(&self) -> i32 {
        self.layer.exact_num_bottom_blobs()
    }

    fn exact_num_top_blobs(&self) -> i32 {
        self.layer.exact_num_top_blobs()
    }

    fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let mut t0 = top[0].borrow_mut();
        let bottom_data = b0.cpu_data();
        let count = b0.count();
        let top_data = t0.mutable_cpu_data();

        let min = self.layer.get_impl().layer_param.get_clip_param().get_min();
        let max = self.layer.get_impl().layer_param.get_clip_param().get_max();
        let min = T::from_f32(min);
        let max = T::from_f32(max);

        for i in 0..count {
            top_data[i] = T::max(min, T::min(bottom_data[i], max));
        }
    }

    fn forward_gpu(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        no_gpu!();
    }

    fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>,
                    bottom: &BlobVec<T>) {
        if !propagate_down[0] {
            return;
        }

        let mut b0 = bottom[0].borrow_mut();
        let t0 = top[0].as_ref().borrow();

        let count = b0.count();
        let top_diff = t0.cpu_diff();
        let bottom_ref = b0.mutable_cpu_mem_ref();

        let min = self.layer.get_impl().layer_param.get_clip_param().get_min();
        let max = self.layer.get_impl().layer_param.get_clip_param().get_max();
        let min = T::from_f32(min);
        let max = T::from_f32(max);

        for i in 0..count {
            let data = bottom_ref.data[i];
            let in_range = (data >= min && data <= max) as i32;
            bottom_ref.diff[i] = top_diff[i] * T::from_i32(in_range);
        }
    }

    fn backward_gpu(&mut self, _top: &BlobVec<T>, _propagate_down: &Vec<bool>,
                    _bottom: &BlobVec<T>) {
        no_gpu!();
    }
}

register_layer_class!(Clip);
--------------------------------------------------------------------------------
/src/common.rs:
--------------------------------------------------------------------------------
use std::cell::RefCell;
use std::rc::Rc;

use mt19937::MT19937;
use rand::{RngCore, thread_rng, SeedableRng};


#[derive(Copy, Clone)]
pub enum CaffeBrew {
    CPU,
    GPU,
}

pub struct CaffeRng {
    rng: MT19937,
}

// Random seeding. The C++ source only reads '/dev/urandom' to fetch a seed.
// Use `ThreadRng` to get a cross-platform, cryptographically secure PRNG.
fn cluster_seed_gen() -> u64 {
    thread_rng().next_u64()
}

impl CaffeRng {
    pub fn new() -> Self {
        CaffeRng {
            rng: MT19937::seed_from_u64(cluster_seed_gen()),
        }
    }

    pub fn new_with_seed(seed: u64) -> Self {
        CaffeRng {
            rng: MT19937::seed_from_u64(seed),
        }
    }

    pub fn generator(&mut self) -> &mut dyn RngCore {
        &mut self.rng
    }
}

pub struct Caffe {
    mode: CaffeBrew,
    solver_count: i32,
    solver_rank: i32,
    multiprocess: bool,
    random_generator: Option<Rc<RefCell<CaffeRng>>>,
}

impl Caffe {
    const fn new() -> Self {
        Caffe {
            mode: CaffeBrew::CPU,
            solver_count: 1,
            solver_rank: 0,
            multiprocess: false,
            random_generator: None,
        }
    }
}

thread_local! {
    static CAFFE: RefCell<Caffe> = RefCell::new(Caffe::new());
}

impl Caffe {
    pub fn mode() -> CaffeBrew {
        CAFFE.with(|f| {
            f.borrow().mode
        })
    }

    pub fn set_mode(mode: CaffeBrew) {
        CAFFE.with(|f| {
            (*f.borrow_mut()).mode = mode;
        });
    }

    pub fn set_random_seed(seed: u64) {
        // RNG seed
        CAFFE.with(|f| {
            f.borrow_mut().random_generator.replace(Rc::new(RefCell::new(CaffeRng::new_with_seed(seed))));
        });
    }

    pub fn rng() -> Rc<RefCell<CaffeRng>> {
        CAFFE.with(|f| {
            f.borrow_mut().random_generator
                .get_or_insert_with(|| Rc::new(RefCell::new(CaffeRng::new())))
                .clone()
        })
    }

    pub fn rng_rand() -> u32 {
        CAFFE.with(|f| {
            f.borrow_mut().random_generator
                .get_or_insert_with(|| Rc::new(RefCell::new(CaffeRng::new())))
                .as_ref().borrow_mut()
                .rng.next_u32()
        })
    }

    pub fn set_device(device_id: i32) {
        // todo: gpu
        unimplemented!();
    }

    pub fn device_query() {
        // todo: gpu
        unimplemented!();
    }

    pub fn check_device(device_id: i32) -> bool {
        // todo: gpu
        unimplemented!();
    }

    pub fn find_device(start_id: i32) -> i32 {
        // todo: gpu
        unimplemented!();
    }

    pub fn solver_count() -> i32 {
        CAFFE.with(|f| {
            f.borrow().solver_count
        })
    }

    pub fn set_solver_count(val: i32) {
        CAFFE.with(|f| {
            f.borrow_mut().solver_count = val;
        });
    }

    pub fn solver_rank() -> i32 {
        CAFFE.with(|f| {
            f.borrow().solver_rank
        })
    }

    pub fn set_solver_rank(val: i32) {
        CAFFE.with(|f| {
            f.borrow_mut().solver_rank = val;
        });
    }

    pub fn multiprocess() -> bool {
        CAFFE.with(|f| {
            f.borrow().multiprocess
        })
    }

    pub fn set_multiprocess(val: bool) {
        CAFFE.with(|f| {
            f.borrow_mut().multiprocess = val;
        });
    }

    pub fn root_solver() -> bool {
        CAFFE.with(|f| {
            f.borrow().solver_rank
        }) == 0
    }
}
--------------------------------------------------------------------------------
/src/layers/batch_reindex_layer.rs:
--------------------------------------------------------------------------------
use crate::blob::BlobType;
use crate::layer::{LayerImpl, CaffeLayer, BlobVec};
use crate::proto::caffe::LayerParameter;
use crate::util::math_functions::caffe_set;


/// Index into the input blob along its first axis.
///
/// This layer can be used to select, reorder, and even replicate examples in a
/// batch. The second blob is cast to int and treated as an index into the
/// first axis of the first blob.
pub struct BatchReindexLayer<T: BlobType> {
    layer: LayerImpl<T>,
}

impl<T: BlobType> BatchReindexLayer<T> {
    pub fn new(param: &LayerParameter) -> Self {
        Self {
            layer: LayerImpl::new(param),
        }
    }

    fn check_batch_reindex(&self, initial_num: i32, final_num: usize, ridx_data: &[T]) {
        assert!(final_num <= ridx_data.len());
        for i in 0..final_num {
            // SAFETY: the `assert` above guards that `i` is in bounds.
            let d = *unsafe { ridx_data.get_unchecked(i) };
            assert!(d >= T::default(), "Index specified for reindex layer was negative.");
            assert!(d < T::from_i32(initial_num), "Index specified for reindex layer was greater than batch size.");
        }
    }
}

impl<T: BlobType> CaffeLayer for BatchReindexLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        &self.layer
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        &mut self.layer
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let b1 = bottom[1].as_ref().borrow();
        assert_eq!(1, b1.num_axes());

        let shape = b0.shape();
        let mut new_shape = shape.clone();
        if new_shape.is_empty() {
            new_shape.push(b1.shape_idx(0));
        } else {
            new_shape[0] = b1.shape_idx(0);
        }

        top[0].borrow_mut().reshape(&new_shape);
    }

    fn layer_type(&self) -> &'static str {
        "BatchReindex"
    }

    fn exact_num_bottom_blobs(&self) -> i32 {
        2
    }

    fn exact_num_top_blobs(&self) -> i32 {
        1
    }

    fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let b1 = bottom[1].as_ref().borrow();
        self.check_batch_reindex(b0.shape_idx(0), b1.count(), b1.cpu_data());

        let mut t0 = top[0].borrow_mut();
        let t0_count = t0.count();
        if t0_count == 0 {
            return;
        }

        let inner_dim = b0.count() / b0.shape_idx(0) as usize;
        let d_in = b0.cpu_data();
        let permut = b1.cpu_data();
        let out = t0.mutable_cpu_data();
        for index in 0..t0_count {
            let n = index / inner_dim;
            let in_n = permut[n].to_usize();
            out[index] = d_in[in_n * inner_dim + index % inner_dim];
        }
    }

    fn forward_gpu(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        no_gpu!();
    }

    fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>, bottom: &BlobVec<T>) {
        assert!(!propagate_down[1], "Cannot backprop to index.");
        if !propagate_down[0] {
            return;
        }

        let mut b0 = bottom[0].borrow_mut();
        let b1 = bottom[1].as_ref().borrow();
        let t0 = top[0].as_ref().borrow();
        let count = b0.count();
        let inner_dim = count / b0.shape_idx(0) as usize;
        let bot_diff = b0.mutable_cpu_diff();
        let permut = b1.cpu_data();
        let top_diff = t0.cpu_diff();
        caffe_set(count, T::default(), bot_diff);
        for index in 0..t0.count() {
            let n = index / inner_dim;
            let in_n = permut[n].to_usize();
            bot_diff[in_n * inner_dim + index % inner_dim] += top_diff[index];
        }
    }

    fn backward_gpu(&mut self, _top: &BlobVec<T>, _propagate_down: &Vec<bool>, _bottom: &BlobVec<T>) {
        no_gpu!();
    }
}

register_layer_class!(BatchReindex);
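To make the reindex semantics concrete, here is a small self-contained sketch of the same gather-along-axis-0 computation that `forward_cpu` performs, using plain slices instead of `Blob`s (all names here are illustrative):

```rust
/// Gather rows of `d_in` (shape `[n, inner_dim]`) according to `permut`.
fn batch_reindex(d_in: &[f32], inner_dim: usize, permut: &[usize]) -> Vec<f32> {
    let mut out = vec![0.0; permut.len() * inner_dim];
    for index in 0..out.len() {
        let n = index / inner_dim;  // output row
        let in_n = permut[n];       // source row selected by the index blob
        out[index] = d_in[in_n * inner_dim + index % inner_dim];
    }
    out
}

fn main() {
    // Two examples of size 3; select example 1, then example 0 twice.
    let batch = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0];
    let out = batch_reindex(&batch, 3, &[1, 0, 0]);
    assert_eq!(out, vec![10.0, 20.0, 30.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
```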
--------------------------------------------------------------------------------
/src/layer_factory.rs:
--------------------------------------------------------------------------------
use std::any::{TypeId, Any};
use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::rc::Rc;

use static_init::dynamic;
use paste::paste;

use crate::proto::caffe::LayerParameter;
use crate::blob::BlobType;
use crate::layer::{Layer, CaffeLayer, LayerImpl, BlobVec, SharedLayer};


pub type LayerCreator<T> = fn(&LayerParameter) -> SharedLayer<T>;

#[derive(Default)]
struct CreatorRegistry<T: BlobType> {
    registry: HashMap<String, LayerCreator<T>>,
}

impl<T: BlobType> CreatorRegistry<T> {
    pub fn new() -> Self {
        Default::default()
    }

    /// Add a creator.
    pub fn add_creator(&mut self, ty: &str, creator: LayerCreator<T>) {
        assert!(!self.registry.contains_key(ty), "Layer type {:?} already registered.", ty);

        self.registry.insert(ty.to_string(), creator);
    }

    /// Get a layer using a `LayerParameter`.
    pub fn create_layer(&self, param: &LayerParameter) -> SharedLayer<T> {
        let ty = param.get_field_type();
        match self.registry.get(ty) {
            Some(creator) => creator(param),
            None => panic!("Unknown layer type: {:?} (known types: {:?})", ty, self.layer_type_list_string()),
        }
    }

    pub fn layer_type_list(&self) -> Vec<String> {
        let mut layer_types = Vec::with_capacity(self.registry.len());
        for (k, _) in &self.registry {
            layer_types.push(k.clone());
        }

        layer_types
    }

    fn layer_type_list_string(&self) -> String {
        self.layer_type_list().join(" ,")
    }
}

fn add_creator_impl<T: BlobType>(ty: &str, creator: LayerCreator<T>) {
    let mut lock = REGISTRY.write();
    let registry = lock.entry(TypeId::of::<T>()).or_insert_with(|| Box::new(CreatorRegistry::<T>::new()));
    let registry = registry.downcast_mut::<CreatorRegistry<T>>().unwrap();
    registry.add_creator(ty, creator);
}

fn create_layer_impl<T: BlobType>(param: &LayerParameter) -> SharedLayer<T> {
    let mut lock = REGISTRY.write();
    let registry = lock.entry(TypeId::of::<T>()).or_insert_with(|| Box::new(CreatorRegistry::<T>::new()));
    let registry = registry.downcast_ref::<CreatorRegistry<T>>().unwrap();
    registry.create_layer(param)
}

#[dynamic]
static mut REGISTRY: HashMap<TypeId, Box<dyn Any + Send + Sync>> = HashMap::new();

pub struct LayerRegistry<T: BlobType> {
    phantom: PhantomData<T>,
}

impl<T: BlobType> LayerRegistry<T> {
    pub fn new(ty: &str, creator: LayerCreator<T>) -> Self {
        Self::add_creator(ty, creator);
        LayerRegistry {
            phantom: PhantomData
        }
    }

    pub fn add_creator(ty: &str, creator: LayerCreator<T>) {
        add_creator_impl(ty, creator);
    }

    pub fn create_layer(param: &LayerParameter) -> SharedLayer<T> {
        create_layer_impl(param)
    }
}


#[macro_export]
macro_rules! register_layer_creator {
    ($t:ident, $creator:path) => {
        paste! {
            #[dynamic(init)]
            static [<CREATOR_ $t:upper _F32>]: $crate::layer_factory::LayerRegistry<f32> =
                $crate::layer_factory::LayerRegistry::<f32>::new(stringify!($t), $creator);

            #[dynamic(init)]
            static [<CREATOR_ $t:upper _F64>]: $crate::layer_factory::LayerRegistry<f64> =
                $crate::layer_factory::LayerRegistry::<f64>::new(stringify!($t), $creator);
        }
    };
}

#[macro_export]
macro_rules! register_layer_class {
    ($t:ident) => {
        paste! {
            pub fn [<create_ $t:snake _layer>]<T: $crate::blob::BlobType>(param: &crate::proto::caffe::LayerParameter)
                -> std::rc::Rc<std::cell::RefCell<$crate::layer::Layer<T>>> {
                std::rc::Rc::new(std::cell::RefCell::new($crate::layer::Layer::new(Box::new([<$t Layer>]::<T>::new(param)))))
            }

            register_layer_creator!($t, self::[<create_ $t:snake _layer>]);
        }
    };
}


#[cfg(test)]
mod test {
    use super::*;

    fn test<T: BlobType>(p: &LayerParameter) -> SharedLayer<T> {
        unimplemented!();
    }

    static TTT: LayerCreator<f32> = test::<f32>;

    register_layer_creator!(PhantomData, self::test);

    struct TestLayer<T: BlobType> {
        phantom: PhantomData<T>,
    }

    impl<T: BlobType> CaffeLayer for TestLayer<T> {
        type DataType = T;

        fn get_impl(&self) -> &LayerImpl<T> {
            todo!()
        }

        fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
            todo!()
        }

        fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
            todo!()
        }

        fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
            todo!()
        }

        fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>, bottom: &BlobVec<T>) {
            todo!()
        }
    }

    impl<T: BlobType> TestLayer<T> {
        pub fn new(_param: &LayerParameter) -> Self {
            TestLayer {
                phantom: PhantomData
            }
        }
    }

    register_layer_class!(Test);
}
--------------------------------------------------------------------------------
/src/util/insert_splits.rs:
--------------------------------------------------------------------------------
use std::collections::HashMap;

use protobuf::{Chars, Clear};

use crate::proto::caffe::{NetParameter, LayerParameter};


/// Copy NetParameters with SplitLayers added to replace any shared bottom
/// blobs with unique bottom blobs provided by the SplitLayer.
pub fn insert_splits(param: &NetParameter, param_split: &mut NetParameter) {
    // Initialize by copying from the input NetParameter.
    param_split.clone_from(param);
    param_split.clear_layer();

    let mut blob_name_to_last_top_idx = HashMap::<String, (usize, usize)>::new();
    let mut bottom_idx_to_source_top_idx = HashMap::<(usize, usize), (usize, usize)>::new();
    let mut top_idx_to_bottom_count = HashMap::<(usize, usize), usize>::new();
    let mut top_idx_to_loss_weight = HashMap::<(usize, usize), f32>::new();
    let mut top_idx_to_bottom_split_idx = HashMap::<(usize, usize), usize>::new();
    let mut layer_idx_to_layer_name = HashMap::<usize, String>::new();

    for i in 0..param.get_layer().len() {
        let layer_param = &param.get_layer()[i];
        layer_idx_to_layer_name.insert(i, layer_param.get_name().to_string());
        for j in 0..layer_param.get_bottom().len() {
            let blob_name: &str = layer_param.get_bottom()[j].as_ref();
            if !blob_name_to_last_top_idx.contains_key(blob_name) {
                assert!(false, "Unknown bottom blob '{:?}' (layer '{:?}', bottom index {:?})",
                        blob_name, layer_param.get_name(), j);
            }

            let top_idx = blob_name_to_last_top_idx[blob_name];
            bottom_idx_to_source_top_idx.insert((i, j), top_idx);
            *top_idx_to_bottom_count.entry(top_idx).or_default() += 1;
        }
        for j in 0..layer_param.get_top().len() {
            let blob_name: &str = layer_param.get_top()[j].as_ref();
            blob_name_to_last_top_idx.insert(blob_name.to_string(), (i, j));
        }
        // A use of a top blob as a loss should be handled similarly to the use of a top blob as
        // a bottom blob to another layer.
        let last_loss = std::cmp::min(layer_param.get_loss_weight().len(), layer_param.get_top().len());
        for j in 0..last_loss {
            let blob_name: &str = layer_param.get_top()[j].as_ref();
            let top_idx = blob_name_to_last_top_idx[blob_name];
            let loss = layer_param.get_loss_weight()[j];
            top_idx_to_loss_weight.insert(top_idx, loss);
            if loss != 0f32 {
                *top_idx_to_bottom_count.entry(top_idx).or_default() += 1;
            }
        }
    }

    for i in 0..param.get_layer().len() {
        let layer_top_len;
        let layer_idx;

        {
            layer_idx = param_split.get_layer().len();
            let layer_param = param_split.mut_layer().push_default();
            layer_param.clone_from(&param.get_layer()[i]);
            layer_top_len = layer_param.get_top().len();

            // Replace any shared bottom blobs with split layer outputs.
            for j in 0..layer_param.get_bottom().len() {
                let top_idx = bottom_idx_to_source_top_idx[&(i, j)];
                let split_count = top_idx_to_bottom_count[&top_idx];
                if split_count > 1 {
                    let layer_name = layer_idx_to_layer_name[&top_idx.0].as_str();
                    let split_idx = top_idx_to_bottom_split_idx.entry(top_idx).or_default();
                    let value = split_blob_name(layer_name, layer_param.get_bottom()[j].as_ref(),
                                                top_idx.1, *split_idx);
                    *split_idx += 1;
                    layer_param.mut_bottom()[j] = Chars::from(value);
                }
            }
        }

        // Create split layer for any top blobs used by other layers as bottom blobs more than once.
        for j in 0..layer_top_len {
            let top_idx = (i, j);
            let split_count = top_idx_to_bottom_count.get(&top_idx).map_or(0usize, |v| *v);
            if split_count > 1 {
                let layer_name = layer_idx_to_layer_name[&i].as_str();
                let mut split_layer_param = LayerParameter::new();
                let loss_weight = top_idx_to_loss_weight.get(&top_idx).map_or(0f32, |v| *v);
                {
                    let blob_name: &str = param_split.get_layer()[layer_idx].get_top()[j].as_ref();
                    configure_split_layer(layer_name, blob_name, j, split_count, loss_weight, &mut split_layer_param);
                }
                param_split.mut_layer().push(split_layer_param);

                if loss_weight != 0f32 {
                    param_split.mut_layer()[layer_idx].clear_loss_weight();
                    *top_idx_to_bottom_split_idx.entry(top_idx).or_default() += 1;
                }
            }
        }
    }
}

pub fn configure_split_layer(layer_name: &str, blob_name: &str, blob_idx: usize, split_count: usize,
                             loss_weight: f32, split_layer_param: &mut LayerParameter) {
    split_layer_param.clear();
    split_layer_param.mut_bottom().push(Chars::from(blob_name));
    split_layer_param.set_name(Chars::from(split_layer_name(layer_name, blob_name, blob_idx)));
    split_layer_param.set_field_type(Chars::from("Split"));

    for k in 0..split_count {
        split_layer_param.mut_top().push(Chars::from(split_blob_name(layer_name, blob_name, blob_idx, k)));
        if loss_weight != 0f32 {
            if k == 0 {
                split_layer_param.mut_loss_weight().push(loss_weight);
            } else {
                split_layer_param.mut_loss_weight().push(0f32);
            }
        }
    }
}

pub fn split_layer_name(layer_name: &str, blob_name: &str, blob_idx: usize) -> String {
    format!("{}_{}_{}_split", blob_name, layer_name, blob_idx)
}

pub fn split_blob_name(layer_name: &str, blob_name: &str, blob_idx: usize, split_idx: usize) -> String {
    format!("{}_{}_{}_split_{}", blob_name, layer_name, blob_idx, split_idx)
}
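As a concrete example of the naming scheme: if layer `conv1` produces top blob `data` at top index 0 and that blob feeds two consumers, `insert_splits` adds a `Split` layer named `data_conv1_0_split` whose tops are `data_conv1_0_split_0` and `data_conv1_0_split_1`, and rewrites each consumer's bottom to one of those names (`conv1` and `data` are made-up identifiers):

```rust
assert_eq!(split_layer_name("conv1", "data", 0), "data_conv1_0_split");
assert_eq!(split_blob_name("conv1", "data", 0, 1), "data_conv1_0_split_1");
```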
--------------------------------------------------------------------------------
/src/layers/concat_layer.rs:
--------------------------------------------------------------------------------
use crate::blob::{BlobType, MAX_BLOB_AXES};
use crate::layer::{LayerImpl, CaffeLayer, BlobVec, def_layer_setup};
use crate::proto::caffe::LayerParameter;
use crate::util::math_functions::caffe_copy;


/// Takes at least two `Blob`s and concatenates them along either the num
/// or channel dimension, outputting the result.
pub struct ConcatLayer<T: BlobType> {
    layer: LayerImpl<T>,
    count: i32,
    num_concats: i32,
    concat_input_size: i32,
    concat_axis: i32,
}

impl<T: BlobType> ConcatLayer<T> {
    pub fn new(param: &LayerParameter) -> Self {
        Self {
            layer: LayerImpl::new(param),
            count: 0,
            num_concats: 0,
            concat_input_size: 0,
            concat_axis: 0,
        }
    }
}

impl<T: BlobType> CaffeLayer for ConcatLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        &self.layer
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        &mut self.layer
    }

    fn layer_setup(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        let concat_param = self.layer.layer_param.get_concat_param();
        assert!(!(concat_param.has_axis() && concat_param.has_concat_dim()),
                "Either axis or concat_dim should be specified; not both.");
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let num_axes = b0.num_axes();
        let concat_param = self.layer.layer_param.get_concat_param();
        if concat_param.has_concat_dim() {
            self.concat_axis = concat_param.get_concat_dim() as i32;
            assert!(self.concat_axis >= 0, "casting concat_dim from uint32 to int32 produced \
                negative result; concat_dim must satisfy 0 <= concat_dim < {}", MAX_BLOB_AXES);
            assert!(self.concat_axis < num_axes, "concat_dim out of range.");
        } else {
            self.concat_axis = b0.canonical_axis_index(concat_param.get_axis()) as i32;
        }

        // Initialize with the first blob.
        let mut top_shape = b0.shape().clone();
        self.num_concats = b0.count_range(0, self.concat_axis as usize);
        self.concat_input_size = b0.count_range_to_end((self.concat_axis + 1) as usize);
        let mut bottom_count_sum = b0.count();
        for i in 1..bottom.len() {
            let bi = bottom[i].as_ref().borrow();
            assert_eq!(num_axes, bi.num_axes(), "All inputs must have the same #axes.");
            for j in 0..num_axes {
                if j == self.concat_axis { continue; }
                assert_eq!(top_shape[j as usize], bi.shape_idx(j),
                           "All inputs must have the same shape, except at concat_axis.");
            }

            bottom_count_sum += bi.count();
            top_shape[self.concat_axis as usize] += bi.shape_idx(self.concat_axis);
        }

        let mut t0 = top[0].borrow_mut();
        // Reshape first, then check: the element count must be read after the
        // reshape so the sum of the bottom counts is compared against the new
        // top shape.
        t0.reshape(&top_shape);
        let top_count = t0.count();
        assert_eq!(bottom_count_sum, top_count);
        if bottom.len() == 1 {
            t0.share_data(&*b0);
            t0.share_diff(&*b0);
        }
    }

    fn layer_type(&self) -> &'static str {
        "Concat"
    }

    fn min_bottom_blobs(&self) -> i32 {
        1
    }

    fn exact_num_top_blobs(&self) -> i32 {
        1
    }

    fn forward_cpu(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        if bottom.len() == 1 {
            return;
        }

        let mut t0 = top[0].borrow_mut();
        let top_concat_axis = t0.shape_idx(self.concat_axis);
        let top_data = t0.mutable_cpu_data();
        let mut offset_concat_axis = 0;
        for bi in bottom {
            let blob = bi.as_ref().borrow();
            let bottom_data = blob.cpu_data();
            let bottom_concat_axis = blob.shape_idx(self.concat_axis);
            for n in 0..self.num_concats {
                let size = (bottom_concat_axis * self.concat_input_size) as usize;
                let top_offset = ((n * top_concat_axis + offset_concat_axis) * self.concat_input_size) as usize;
                caffe_copy(size, &bottom_data[(n as usize * size)..], &mut top_data[top_offset..]);
            }
            offset_concat_axis += bottom_concat_axis;
        }
    }

    fn forward_gpu(&mut self, _bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        no_gpu!();
    }

    fn backward_cpu(&mut self, top: &BlobVec<T>, propagate_down: &Vec<bool>,
                    bottom: &BlobVec<T>) {
        if bottom.len() == 1 {
            return;
        }

        let t0 = top[0].as_ref().borrow();
        let top_diff = t0.cpu_diff();
        let mut offset_concat_axis = 0;
        let top_concat_axis = t0.shape_idx(self.concat_axis);
        for i in 0..bottom.len() {
            let mut blob = bottom[i].borrow_mut();
            let bottom_concat_axis = blob.shape_idx(self.concat_axis);
            if propagate_down[i] {
                let bottom_diff = blob.mutable_cpu_diff();
                for n in 0..self.num_concats {
                    let size = (bottom_concat_axis * self.concat_input_size) as usize;
                    let top_offset = ((n * top_concat_axis + offset_concat_axis) * self.concat_input_size) as usize;
                    caffe_copy(size, &top_diff[top_offset..], &mut bottom_diff[(n as usize * size)..]);
                }
            }
            offset_concat_axis += bottom_concat_axis;
        }
    }

    fn backward_gpu(&mut self, _top: &BlobVec<T>, _propagate_down: &Vec<bool>,
                    _bottom: &BlobVec<T>) {
        no_gpu!();
    }
}

register_layer_class!(Concat);
--------------------------------------------------------------------------------
/src/layers/argmax_layer.rs:
--------------------------------------------------------------------------------
use std::cmp::Ordering;

use crate::blob::BlobType;
use crate::layer::{CaffeLayer, LayerImpl, BlobVec};
use crate::proto::caffe::LayerParameter;


/// Compute the index of the $ K $ max values for each datum across all dimensions
/// $ (C \times H \times W) $.
///
/// Intended for use after a classification layer to produce a prediction. If parameter
/// out_max_val is set to true, output is a vector of pairs (max_ind, max_val) for each
/// image. The axis parameter specifies an axis along which to maximise.
///
/// **NOTE**: **does not** implement `Backwards` operation.
pub struct ArgMaxLayer<T: BlobType> {
    layer: LayerImpl<T>,
    out_max_val: bool,
    top_k: usize,
    has_axis: bool,
    axis: i32,
}

impl<T: BlobType> ArgMaxLayer<T> {
    /// `param` provides **ArgMaxParameter** argmax_param, with **ArgMaxLayer options**:
    /// - top_k (**optional uint, default `1`**). the number $ K $ of maximal items to output.
    /// - out_max_val (**optional bool, default `false`**). if set, output a vector of pairs
    ///   (max_ind, max_val); unless axis is set, then output max_val along the specified axis.
    /// - axis (**optional int**). if set, maximise along the specified axis; else maximise
    ///   the flattened trailing dimensions for each index of the first / num dimension.
    pub fn new(param: &LayerParameter) -> Self {
        ArgMaxLayer {
            layer: LayerImpl::new(param),
            out_max_val: false,
            top_k: 0,
            has_axis: false,
            axis: 0,
        }
    }
}

impl<T: BlobType> CaffeLayer for ArgMaxLayer<T> {
    type DataType = T;

    fn get_impl(&self) -> &LayerImpl<T> {
        &self.layer
    }

    fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
        &mut self.layer
    }

    fn layer_setup(&mut self, bottom: &BlobVec<T>, _top: &BlobVec<T>) {
        let argmax_param = self.layer.layer_param.get_argmax_param();
        self.out_max_val = argmax_param.get_out_max_val();
        self.top_k = argmax_param.get_top_k() as usize;
        self.has_axis = argmax_param.has_axis();
        check_ge!(self.top_k, 1, "top k must not be less than 1.");

        if self.has_axis {
            let b0 = bottom[0].as_ref().borrow();
            self.axis = b0.canonical_axis_index(argmax_param.get_axis()) as i32;
            check_ge!(self.axis, 0, "axis must not be less than 0.");
            check_le!(self.axis, b0.num_axes(), "axis must be less than or equal to the number of axes.");
            check_le!(self.top_k as i32, b0.shape_idx(self.axis),
                      "top_k must be less than or equal to the dimension of the axis.");
        } else {
            let count = bottom[0].as_ref().borrow().count_range_to_end(1);
            check_le!(self.top_k as i32, count, "top_k must be less than or equal to the \
                dimension of the flattened bottom blob per instance.");
        }
    }

    fn reshape(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
        let b0 = bottom[0].as_ref().borrow();
        let mut num_top_axes = b0.num_axes();
        if num_top_axes < 3 {
            num_top_axes = 3;
        }

        let mut shape;
        if self.has_axis {
            // Produces max_ind or max_val per axis
            shape = b0.shape().clone();
            shape[self.axis as usize] = self.top_k as i32;
        } else {
            let num_top_axes = num_top_axes as usize;
            shape = Vec::with_capacity(num_top_axes);
            shape.resize(num_top_axes, 1);
            shape[0] = b0.shape_idx(0);
            // Produces max_ind
            shape[2] = self.top_k as i32;
            if self.out_max_val {
                // Produces max_ind and max_val
                shape[1] = 2;
            }
        }

        top[0].borrow_mut().reshape(&shape);
    }

    fn layer_type(&self) -> &'static str {
"ArgMax" 104 | } 105 | 106 | fn exact_num_bottom_blobs(&self) -> i32 { 107 | 1 108 | } 109 | 110 | fn exact_num_top_blobs(&self) -> i32 { 111 | 1 112 | } 113 | 114 | /// `bottom` input Blob vector (length 1). 115 | /// - $ (N \times C \times H \times W) $ the inputs $ x $. 116 | /// 117 | /// `top` output Blob vector (length 1). 118 | /// - $ (N \times 1 \times K) $ or, if `out_max_val` $ (N \times 2 \times K) $ unless `axis` 119 | /// is set than e.g. $ (N \times K \times H \times W) $ if `axis == 1` the computed outputs 120 | /// $$ y\_n = \arg\max\limits\_i x\_{ni} $$ (for $ K = 1 $). 121 | fn forward_cpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 122 | let b0 = bottom[0].as_ref().borrow(); 123 | let mut t0 = top[0].borrow_mut(); 124 | let bottom_data = b0.cpu_data(); 125 | let top_data = t0.mutable_cpu_data(); 126 | let dim; 127 | let axis_dist; 128 | if self.has_axis { 129 | dim = b0.shape_idx(self.axis); 130 | // Distance between values of axis in blob 131 | axis_dist = b0.count_range_to_end(self.axis as usize) / dim; 132 | } else { 133 | dim = b0.count_range_to_end(1); 134 | axis_dist = 1; 135 | } 136 | 137 | let dim = dim as usize; 138 | let axis_dist = axis_dist as usize; 139 | let num = b0.count() / dim; 140 | let mut bottom_data_vector: Vec<(T, usize)> = Vec::with_capacity(dim); 141 | bottom_data_vector.resize(dim, Default::default()); 142 | let cmp = |a: &(T, usize), b: &(T, usize)| { 143 | // Treat the `NAN` as the minimum value. 144 | b.partial_cmp(a).unwrap_or_else(|| if (*a).0.is_nan_v() { Ordering::Greater } else { Ordering::Less }) 145 | }; 146 | for i in 0..num { 147 | for j in 0..dim { 148 | bottom_data_vector[j] = (bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j); 149 | } 150 | // C++ std::partial_sort 151 | let (first, _, _) = bottom_data_vector.select_nth_unstable_by(self.top_k, &cmp); 152 | first.sort_unstable_by(&cmp); 153 | 154 | for j in 0..self.top_k { 155 | if self.out_max_val { 156 | if self.has_axis { 157 | // Produces max_val per axis 158 | top_data[(i / axis_dist * self.top_k + j) * axis_dist + i % axis_dist] = 159 | bottom_data_vector[j].0; 160 | } else { 161 | // Produces max_ind and max_val 162 | top_data[2usize * i * self.top_k + j] = T::from_usize(bottom_data_vector[j].1); 163 | top_data[2usize * i * self.top_k + self.top_k + j] = bottom_data_vector[j].0; 164 | } 165 | } else { 166 | // Produces max_ind per axis 167 | top_data[(i / axis_dist * self.top_k + j) * axis_dist + i % axis_dist] = 168 | T::from_usize(bottom_data_vector[j].1); 169 | } 170 | } 171 | } 172 | } 173 | 174 | /// Not implemented (non-differentiable function). 175 | fn backward_cpu(&mut self, _top: &BlobVec, _propagate_down: &Vec, _bottom: &BlobVec) { 176 | unimplemented!(); 177 | } 178 | } 179 | 180 | register_layer_class!(ArgMax); 181 | -------------------------------------------------------------------------------- /src/util/mkl_alternate.rs: -------------------------------------------------------------------------------- 1 | use cblas::{sscal, dscal, saxpy, daxpy}; 2 | 3 | // Functions that caffe uses but are not present if MKL is not linked. 4 | 5 | 6 | /// Simple macro to generate an unsafe loop on two slice with given loop size. 7 | macro_rules! check_loop_unsafe { 8 | ($i:ident, $n:tt, $ex:expr) => { 9 | for $i in 0..$n { 10 | // SAFETY: the `assert` check guards that the index is not out-of-bounds. 
11 | unsafe { 12 | $ex; 13 | } 14 | } 15 | }; 16 | ($i:ident, $n:tt, $a:tt, $y:tt, $ex:expr) => { 17 | assert!($a.len() >= $n && $y.len() >= $n); 18 | check_loop_unsafe!($i, $n, $ex); 19 | }; 20 | ($i:ident, $n:tt, $a:tt, $b:tt, $y:tt, $ex:expr) => { 21 | assert!($a.len() >= $n && $b.len() >= $n && $y.len() >= $n); 22 | check_loop_unsafe!($i, $n, $ex); 23 | }; 24 | } 25 | 26 | // A simple way to define the vsl unary functions. The operation should be in the 27 | // form e.g. y[i] = sqrt(a[i]) 28 | 29 | pub fn vs_sqr(n: usize, a: &[f32], y: &mut [f32]) { 30 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) * *a.get_unchecked(i)); 31 | } 32 | 33 | pub fn vd_sqr(n: usize, a: &[f64], y: &mut [f64]) { 34 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) * *a.get_unchecked(i)); 35 | } 36 | 37 | pub fn vs_sqrt(n: usize, a: &[f32], y: &mut [f32]) { 38 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).sqrt()); 39 | } 40 | 41 | pub fn vd_sqrt(n: usize, a: &[f64], y: &mut [f64]) { 42 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).sqrt()); 43 | } 44 | 45 | pub fn vs_exp(n: usize, a: &[f32], y: &mut [f32]) { 46 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).exp()); 47 | } 48 | 49 | pub fn vd_exp(n: usize, a: &[f64], y: &mut [f64]) { 50 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).exp()); 51 | } 52 | 53 | pub fn vs_ln(n: usize, a: &[f32], y: &mut [f32]) { 54 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).ln()); 55 | } 56 | 57 | pub fn vd_ln(n: usize, a: &[f64], y: &mut [f64]) { 58 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).ln()); 59 | } 60 | 61 | pub fn vs_abs(n: usize, a: &[f32], y: &mut [f32]) { 62 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).abs()); 63 | } 64 | 65 | pub fn vd_abs(n: usize, a: &[f64], y: &mut [f64]) { 66 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).abs()); 67 | } 68 | 69 | /// Output is 1 for the positives, 0 for zero, and -1 for the negatives. 70 | pub fn vs_sign(n: usize, a: &[f32], y: &mut [f32]) { 71 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = vs_get_sign(*a.get_unchecked(i)) as f32); 72 | } 73 | 74 | /// Output is 1 for the positives, 0 for zero, and -1 for the negatives. 75 | pub fn vd_sign(n: usize, a: &[f64], y: &mut [f64]) { 76 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = vd_get_sign(*a.get_unchecked(i)) as f64); 77 | } 78 | 79 | /// Returns 1 if the input has its sign bit set (is negative, include -0.0, NAN with neg sign). 80 | pub fn vs_sgn_bit(n: usize, a: &[f32], y: &mut [f32]) { 81 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).is_sign_negative() as i32 as f32); 82 | } 83 | 84 | /// Returns 1 if the input has its sign bit set (is negative, include -0.0, NAN with neg sign). 
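/// A small usage sketch with made-up values, mirroring the f32 variant above:
///
/// ``` ignore
/// let a = [-0.0f64, 3.0, -2.0];
/// let mut y = [0.0f64; 3];
/// vd_sgn_bit(3, &a, &mut y); // y == [1.0, 0.0, 1.0]
/// ```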
85 | pub fn vd_sgn_bit(n: usize, a: &[f64], y: &mut [f64]) { 86 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).is_sign_negative() as i32 as f64); 87 | } 88 | 89 | pub fn vs_fabs(n: usize, a: &[f32], y: &mut [f32]) { 90 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).abs()); 91 | } 92 | 93 | pub fn vd_fabs(n: usize, a: &[f64], y: &mut [f64]) { 94 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).abs()); 95 | } 96 | 97 | // A simple way to define the vsl unary functions with singular parameter b. 98 | // The operation should be in the form e.g. y[i] = pow(a[i], b) 99 | 100 | pub fn vs_powx(n: usize, a: &[f32], b: f32, y: &mut [f32]) { 101 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).powf(b)); 102 | } 103 | 104 | pub fn vd_powx(n: usize, a: &[f64], b: f64, y: &mut [f64]) { 105 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) = a.get_unchecked(i).powf(b)); 106 | } 107 | 108 | // A simple way to define the vsl binary functions. The operation should be in the 109 | // form e.g. y[i] = a[i] + b[i] 110 | 111 | pub fn vs_add(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 112 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) + *b.get_unchecked(i)); 113 | } 114 | 115 | pub fn vd_add(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 116 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) + *b.get_unchecked(i)); 117 | } 118 | 119 | pub fn vs_sub(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 120 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) - *b.get_unchecked(i)); 121 | } 122 | 123 | pub fn vd_sub(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 124 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) - *b.get_unchecked(i)); 125 | } 126 | 127 | pub fn vs_mul(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 128 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) * *b.get_unchecked(i)); 129 | } 130 | 131 | pub fn vd_mul(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 132 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) * *b.get_unchecked(i)); 133 | } 134 | 135 | pub fn vs_div(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 136 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) / *b.get_unchecked(i)); 137 | } 138 | 139 | pub fn vd_div(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 140 | check_loop_unsafe!(i, n, a, b, y, *y.get_unchecked_mut(i) = *a.get_unchecked(i) / *b.get_unchecked(i)); 141 | } 142 | 143 | // AssignOps impls. The operation is in the form e.g. y[i] *= a[i] 144 | 145 | pub fn vs_sub_assign(n: usize, y: &mut [f32], a: &[f32]) { 146 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) -= *a.get_unchecked(i)); 147 | } 148 | 149 | pub fn vd_sub_assign(n: usize, y: &mut [f64], a: &[f64]) { 150 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) -= *a.get_unchecked(i)); 151 | } 152 | 153 | pub fn vs_mul_assign(n: usize, y: &mut [f32], a: &[f32]) { 154 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) *= *a.get_unchecked(i)); 155 | } 156 | 157 | pub fn vd_mul_assign(n: usize, y: &mut [f64], a: &[f64]) { 158 | check_loop_unsafe!(i, n, a, y, *y.get_unchecked_mut(i) *= *a.get_unchecked(i)); 159 | } 160 | 161 | // In addition, MKL comes with an additional function axpby that is not present in standard 162 | // blas. 
We will simply use a two-step (inefficient, of course) way to mimic that.
163 | pub fn cblas_saxpby(n: i32, alpha: f32, x: &[f32], inc_x: i32, beta: f32, y: &mut [f32], inc_y: i32) {
164 |     unsafe {
165 |         sscal(n, beta, y, inc_y);
166 |         saxpy(n, alpha, x, inc_x, y, inc_y);
167 |     }
168 | }
169 |
170 | pub fn cblas_daxpby(n: i32, alpha: f64, x: &[f64], inc_x: i32, beta: f64, y: &mut [f64], inc_y: i32) {
171 |     unsafe {
172 |         dscal(n, beta, y, inc_y);
173 |         daxpy(n, alpha, x, inc_x, y, inc_y);
174 |     }
175 | }
176 |
177 | // Other dependent functions.
178 | // mark: maybe use a generic impl.
179 | /// Output is 1 for the positives, 0 for zero, and -1 for the negatives.
180 | pub fn vs_get_sign(val: f32) -> i8 {
181 |     ((0f32 < val) as i8) - ((val < 0f32) as i8)
182 | }
183 |
184 | /// Output is 1 for the positives, 0 for zero, and -1 for the negatives.
185 | pub fn vd_get_sign(val: f64) -> i8 {
186 |     ((0f64 < val) as i8) - ((val < 0f64) as i8)
187 | }
188 |
--------------------------------------------------------------------------------
/src/layers/bias_layer.rs:
--------------------------------------------------------------------------------
1 | use std::rc::Rc;
2 |
3 | use cblas::Transpose;
4 |
5 | use crate::blob::{BlobType, Blob};
6 | use crate::filler::get_filler;
7 | use crate::layer::{LayerImpl, CaffeLayer, BlobVec, def_layer_setup, make_shared_blob};
8 | use crate::proto::caffe::LayerParameter;
9 | use crate::util::math_functions::{caffe_set, caffe_copy};
10 |
11 |
12 | pub struct BiasLayer<T: BlobType> {
13 |     layer: LayerImpl<T>,
14 |     bias_multiplier: Blob<T>,
15 |     outer_dim: i32,
16 |     bias_dim: i32,
17 |     inner_dim: i32,
18 |     dim: i32,
19 | }
20 |
21 | impl<T: BlobType> BiasLayer<T> {
22 |     pub fn new(param: &LayerParameter) -> Self {
23 |         Self {
24 |             layer: LayerImpl::new(param),
25 |             bias_multiplier: Blob::new(),
26 |             outer_dim: 0,
27 |             bias_dim: 0,
28 |             inner_dim: 0,
29 |             dim: 0,
30 |         }
31 |     }
32 | }
33 |
34 | impl<T: BlobType> CaffeLayer for BiasLayer<T> {
35 |     type DataType = T;
36 |
37 |     fn get_impl(&self) -> &LayerImpl<T> {
38 |         &self.layer
39 |     }
40 |
41 |     fn get_impl_mut(&mut self) -> &mut LayerImpl<T> {
42 |         &mut self.layer
43 |     }
44 |
45 |     fn layer_setup(&mut self, bottom: &BlobVec<T>, top: &BlobVec<T>) {
46 |         if bottom.len() == 1 && !self.layer.blobs.is_empty() {
47 |             info!("Skipping parameter initialization");
48 |         } else if bottom.len() == 1 {
49 |             // bias is a learned parameter; initialize it
50 |             let param = self.layer.layer_param.get_bias_param();
51 |             let b0 = bottom[0].as_ref().borrow();
52 |             let axis = b0.canonical_axis_index(param.get_axis());
53 |             let num_axes = param.get_num_axes();
54 |             assert!(num_axes >= -1, "num_axes must be non-negative, or -1 to extend to the end of bottom[0]");
55 |             if num_axes >= 0 {
56 |                 assert!(b0.num_axes() >= (axis as i32 + num_axes),
57 |                         "bias blob's shape extends past bottom[0]'s shape when applied \
58 |                          starting with bottom[0] axis = {}",
59 |                         axis);
60 |             }
61 |
62 |             let bias_shape = if num_axes == -1 {
63 |                 &b0.shape()[axis..]
64 | } else { 65 | &b0.shape()[axis..(axis + num_axes as usize)] 66 | }; 67 | let mut blob = Blob::with_shape(bias_shape); 68 | let filler = get_filler(param.get_filler()); 69 | filler.fill(&mut blob); 70 | self.layer.blobs.push(make_shared_blob(blob)); 71 | } 72 | 73 | self.layer.param_propagate_down.resize(self.layer.blobs.len(), true); 74 | } 75 | 76 | fn reshape(&mut self, bottom: &BlobVec, top: &BlobVec) { 77 | let param = self.layer.layer_param.get_bias_param(); 78 | let b1 = bottom.get(1).map(|b| b.as_ref().borrow()); 79 | let blobs = &self.layer.blobs; 80 | let bias = b1.unwrap_or_else(|| blobs[0].as_ref().borrow()); 81 | // Always set axis == 0 in special case where bias is a scalar 82 | // (num_axes == 0). Mathematically equivalent for any choice of axis, so the 83 | // actual setting can be safely ignored; and computation is most efficient 84 | // with axis == 0 and (therefore) outer_dim_ == 1. 85 | let b0 = bottom[0].as_ref().borrow(); 86 | let axis = if bias.num_axes() == 0 { 87 | 0 88 | } else { 89 | b0.canonical_axis_index(param.get_axis()) as i32 90 | }; 91 | assert!(b0.num_axes() >= axis + bias.num_axes(), 92 | "bias blob's shape extends past bottom[0]'s shape when applied starting with bottom[0] axis = {}", 93 | axis); 94 | for i in 0..bias.num_axes() { 95 | assert_eq!(b0.shape_idx(axis + i), bias.shape_idx(i), 96 | "dimension mismatch between bottom[0]->shape({}) and bias->shape({})", 97 | axis + i, i); 98 | } 99 | 100 | self.outer_dim = b0.count_range(0, axis as usize); 101 | self.bias_dim = bias.count() as i32; 102 | self.inner_dim = b0.count_range_to_end((axis + bias.num_axes()) as usize); 103 | self.dim = self.bias_dim * self.inner_dim; 104 | if !Rc::ptr_eq(&bottom[0], &top[0]) { 105 | top[0].borrow_mut().reshape_like(&*bottom[0].as_ref().borrow()); 106 | } 107 | self.bias_multiplier.reshape(&[self.inner_dim]); 108 | if self.bias_multiplier.cpu_data()[self.inner_dim as usize - 1usize] != T::from_i32(1) { 109 | caffe_set(self.inner_dim as usize, T::from_i32(1), self.bias_multiplier.mutable_cpu_data()); 110 | } 111 | } 112 | 113 | fn layer_type(&self) -> &'static str { 114 | "Bias" 115 | } 116 | 117 | fn min_bottom_blobs(&self) -> i32 { 118 | 1 119 | } 120 | 121 | fn max_bottom_blobs(&self) -> i32 { 122 | 2 123 | } 124 | 125 | fn exact_num_top_blobs(&self) -> i32 { 126 | 1 127 | } 128 | 129 | fn forward_cpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 130 | let b1 = bottom.get(1).map(|b| b.as_ref().borrow()); 131 | let blobs = &self.layer.blobs; 132 | let bias = b1.unwrap_or_else(|| blobs[0].as_ref().borrow()); 133 | let bias_data = bias.cpu_data(); 134 | let mut t0 = top[0].borrow_mut(); 135 | let top_data = t0.mutable_cpu_data(); 136 | if !Rc::ptr_eq(&bottom[0], &top[0]) { 137 | let b0 = bottom[0].as_ref().borrow(); 138 | let bottom_data = b0.cpu_data(); 139 | caffe_copy(b0.count(), bottom_data, top_data); 140 | } 141 | 142 | let mut top_offset = 0usize; 143 | for n in 0..self.outer_dim { 144 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, self.bias_dim, self.inner_dim, 1, 145 | T::from_i32(1), bias_data, self.bias_multiplier.cpu_data(), 146 | T::from_i32(1), &mut top_data[top_offset..]); 147 | top_offset += self.dim as usize; 148 | } 149 | } 150 | 151 | fn forward_gpu(&mut self, _bottom: &BlobVec, _top: &BlobVec) { 152 | no_gpu!(); 153 | } 154 | 155 | fn backward_cpu(&mut self, top: &BlobVec, propagate_down: &Vec, bottom: &BlobVec) { 156 | if propagate_down[0] && !Rc::ptr_eq(&bottom[0], &top[0]) { 157 | let t0 = top[0].as_ref().borrow(); 158 | let mut b0 = 
bottom[0].borrow_mut(); 159 | let count = b0.count(); 160 | caffe_copy(count, t0.cpu_diff(), b0.mutable_cpu_diff()); 161 | } 162 | 163 | // in-place, we don't need to do anything with the data diff 164 | let bias_param = bottom.len() == 1; 165 | if (!bias_param && propagate_down[1]) || (bias_param && self.layer.param_propagate_down[0]) { 166 | let t0 = top[0].as_ref().borrow(); 167 | let b1 = bottom.get(1).map(|b| b.borrow_mut()); 168 | let blobs = &self.layer.blobs; 169 | let mut bias = b1.unwrap_or_else(|| blobs[0].borrow_mut()); 170 | 171 | let top_diff = t0.cpu_diff(); 172 | let bias_diff = bias.mutable_cpu_diff(); 173 | let mut accum = bias_param; 174 | let mut top_offset = 0usize; 175 | for n in 0..self.outer_dim { 176 | T::caffe_cpu_gemv(Transpose::None, self.bias_dim, self.inner_dim, T::from_i32(1), 177 | &top_diff[top_offset..], self.bias_multiplier.cpu_data(), 178 | T::from_i32(accum as i32), bias_diff); 179 | top_offset += self.dim as usize; 180 | accum = true; 181 | } 182 | } 183 | } 184 | 185 | fn backward_gpu(&mut self, _top: &BlobVec, _propagate_down: &Vec, _bottom: &BlobVec) { 186 | no_gpu!(); 187 | } 188 | } 189 | 190 | register_layer_class!(Bias); 191 | -------------------------------------------------------------------------------- /src/layers/base_data_layer.rs: -------------------------------------------------------------------------------- 1 | use crate::blob::{BlobType, Blob, ArcBlob}; 2 | use crate::data_transformer::DataTransformer; 3 | use crate::internal_thread::{InternalThread, CancelToken, InternalThreadImpl}; 4 | use crate::layer::{LayerImpl, BlobVec}; 5 | use crate::proto::caffe::{TransformationParameter, LayerParameter}; 6 | use crate::util::blocking_queue::BlockingQueue; 7 | 8 | 9 | /// Provides base for data layers that feed blobs to the Net. 10 | pub struct BaseDataLayerImpl { 11 | pub layer: LayerImpl, 12 | pub transform_param: TransformationParameter, 13 | pub data_transformer: Option>, 14 | pub output_labels: bool, 15 | } 16 | 17 | impl BaseDataLayerImpl { 18 | pub fn new(param: &LayerParameter) -> Self { 19 | BaseDataLayerImpl { 20 | layer: LayerImpl::new(param), 21 | transform_param: param.get_transform_param().clone(), 22 | data_transformer: None, 23 | output_labels: false, 24 | } 25 | } 26 | } 27 | 28 | pub trait BaseDataLayer { 29 | type BaseDataType: BlobType; 30 | 31 | fn get_data_impl(&self) -> &BaseDataLayerImpl; 32 | 33 | fn get_data_impl_mut(&mut self) -> &mut BaseDataLayerImpl; 34 | 35 | fn data_layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec); 36 | 37 | /// LayerSetUp: implements common data layer setup functionality, and calls 38 | /// `data_layer_setUp` to do special data layer setup for individual layer types. 39 | /// This method may not be overridden except by the `BasePrefetchingDataLayer`. 
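/// A sketch of the intended call order (all names from this trait): the shared
/// `layer_setup` below wires up the `DataTransformer` and the `output_labels`
/// flag, then delegates shape-specific setup to the subclass's `data_layer_setup`.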
40 | fn layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec) { 41 | let data = self.get_data_impl_mut(); 42 | data.output_labels = top.len() != 1; 43 | data.data_transformer = Some(DataTransformer::new(&data.transform_param, data.layer.phase)); 44 | data.data_transformer.as_mut().unwrap().init_rand(); 45 | 46 | // The subclasses should setup the size of bottom and top 47 | self.data_layer_setup(bottom, top); 48 | } 49 | } 50 | 51 | 52 | #[derive(Default)] 53 | pub struct Batch { 54 | pub data: ArcBlob, 55 | pub label: ArcBlob, 56 | } 57 | 58 | pub struct BasePrefetchingDataLayerImpl { 59 | pub base: BaseDataLayerImpl, 60 | 61 | pub prefetch: Vec>, 62 | pub prefetch_free: BlockingQueue>, 63 | pub prefetch_full: BlockingQueue>, 64 | pub prefetch_current: Option>, 65 | 66 | pub transformed_data: Blob, 67 | 68 | pub thread: InternalThreadImpl, 69 | } 70 | 71 | impl BasePrefetchingDataLayerImpl { 72 | pub fn new(param: &LayerParameter) -> Self { 73 | let prefetch_count = param.get_data_param().get_prefetch() as usize; 74 | let mut this = Self { 75 | base: BaseDataLayerImpl::new(param), 76 | prefetch: Vec::with_capacity(prefetch_count), 77 | prefetch_free: BlockingQueue::new(), 78 | prefetch_full: BlockingQueue::new(), 79 | prefetch_current: None, 80 | transformed_data: Blob::new(), 81 | thread: InternalThreadImpl::default(), 82 | }; 83 | this.prefetch.resize_with(prefetch_count, Default::default); 84 | 85 | this 86 | } 87 | } 88 | 89 | pub trait BasePrefetchingDataLayer: BaseDataLayer { 90 | type PrefetchDataType: Send + 'static; 91 | 92 | fn get_prefetch(&self) -> &BasePrefetchingDataLayerImpl; 93 | 94 | fn get_prefetch_mut(&mut self) -> &mut BasePrefetchingDataLayerImpl; 95 | 96 | fn forward_cpu(&mut self, _bottom: &BlobVec, top: &BlobVec) { 97 | let base = self.get_prefetch_mut(); 98 | base.prefetch_current.take().map(|b| base.prefetch_free.push(b)); 99 | let mut batch = base.prefetch_full.pop(); 100 | // Reshape to loaded data. 101 | let t0 = &top[0]; 102 | let mut data = std::mem::take(&mut batch.data).into_blob().ok().unwrap(); 103 | t0.borrow_mut().reshape_like(&data); 104 | t0.borrow_mut().set_cpu_data(&data.cpu_data_shared()); 105 | if base.base.output_labels { 106 | let mut label = std::mem::take(&mut batch.label).into_blob().ok().unwrap(); 107 | let t1 = &top[1]; 108 | t1.borrow_mut().reshape_like(&label); 109 | t1.borrow_mut().set_cpu_data(&label.cpu_data_shared()); 110 | } 111 | 112 | // `batch` value are taken and only leaving default value. 
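// A sketch of the double-buffering cycle implemented here: batches circulate
// prefetch_free -> load_batch (worker thread) -> prefetch_full -> forward_cpu,
// and the consumed batch parked in `prefetch_current` is pushed back onto
// `prefetch_free` at the top of the next forward pass.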
113 | base.prefetch_current.replace(batch); 114 | } 115 | 116 | fn forward_gpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 117 | no_gpu!(); 118 | } 119 | 120 | fn get_sync_data(&self) -> Self::PrefetchDataType; 121 | 122 | fn load_batch(data: &mut Self::PrefetchDataType, batch: &mut Batch) where Self: Sized; 123 | } 124 | 125 | 126 | pub struct ThreadEntryData { 127 | pub prefetch_free: BlockingQueue>, 128 | pub prefetch_full: BlockingQueue>, 129 | pub prefetch_data: S, 130 | } 131 | 132 | impl InternalThread for U 133 | where 134 | U: BasePrefetchingDataLayer { 135 | type EntryData = ThreadEntryData; 136 | 137 | fn get_thread(&self) -> &InternalThreadImpl { 138 | &self.get_prefetch().thread 139 | } 140 | 141 | fn get_thread_mut(&mut self) -> &mut InternalThreadImpl { 142 | &mut self.get_prefetch_mut().thread 143 | } 144 | 145 | fn get_entry_data(&mut self) -> Box { 146 | let base = self.get_prefetch(); 147 | let data = self.get_sync_data(); 148 | Box::new(ThreadEntryData { 149 | prefetch_free: base.prefetch_free.clone(), 150 | prefetch_full: base.prefetch_full.clone(), 151 | prefetch_data: data 152 | }) 153 | } 154 | 155 | fn internal_thread_entry(token: CancelToken, data: Box) { 156 | let mut prefetch_data = data.prefetch_data; 157 | let prefetch_free = data.prefetch_free; 158 | let prefetch_full = data.prefetch_full; 159 | while !token.is_cancelled() { 160 | let mut batch = prefetch_free.pop(); 161 | Self::load_batch(&mut prefetch_data, &mut batch); 162 | prefetch_full.push(batch); 163 | } 164 | } 165 | } 166 | 167 | 168 | pub trait BasePrefetchingDataLayerSetup { 169 | type BlobDataType: BlobType; 170 | 171 | fn layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec); 172 | } 173 | 174 | impl BasePrefetchingDataLayerSetup for T 175 | where 176 | T: BasePrefetchingDataLayer { 177 | type BlobDataType = T::BaseDataType; 178 | 179 | fn layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec) { 180 | BaseDataLayer::layer_setup(self, bottom, top); 181 | 182 | let base = self.get_prefetch_mut(); 183 | 184 | // Before starting the prefetch thread, we make cpu_data and gpu_data 185 | // calls so that the prefetch thread does not accidentally make simultaneous 186 | // memory alloc calls when the main thread is running. 187 | let output_labels = base.base.output_labels; 188 | for prefetch in &mut base.prefetch { 189 | let mut batch = std::mem::take(prefetch); 190 | 191 | let mut data = batch.data.into_blob().ok().unwrap(); 192 | data.mutable_cpu_data(); 193 | batch.data = ArcBlob::from(data).ok().unwrap(); 194 | if output_labels { 195 | let mut labels = batch.label.into_blob().ok().unwrap(); 196 | labels.mutable_cpu_data(); 197 | batch.label = ArcBlob::from(labels).ok().unwrap(); 198 | } 199 | 200 | *prefetch = batch; 201 | } 202 | 203 | info!("Initializing prefetch."); 204 | // base.base.data_transformer.unwrap().init_rand(); 205 | self.start_internal_thread(); 206 | info!("Prefetch initialized.") 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /src/util/im2col.rs: -------------------------------------------------------------------------------- 1 | use crate::blob::BlobType; 2 | use crate::util::math_functions::caffe_set; 3 | 4 | 5 | /// Function uses casting from int to unsigned to compare if value of 6 | /// parameter a is greater or equal to zero and lower than value of 7 | /// parameter b. 
The b parameter is of type signed and is always positive, 8 | /// therefore its value is always lower than `0x800...` where casting 9 | /// negative value of a parameter converts it to value higher than `0x800...` 10 | /// The casting allows to use one condition instead of two. 11 | fn is_a_ge_zero_and_a_lt_b(a: i32, b: i32) -> bool { 12 | (a as u32) < (b as u32) 13 | } 14 | 15 | fn im2col_nd_core_cpu(data_input: &[T], im2col: bool, num_spatial_axes: usize, im_shape: &[i32], 16 | col_shape: &[i32], kernel_shape: &[i32], pad: &[i32], stride: &[i32], 17 | dilation: &[i32], data_output: &mut [T]) { 18 | if !im2col { 19 | let mut im_size = im_shape[0]; 20 | for i in 0..num_spatial_axes { 21 | im_size *= im_shape[i + 1]; 22 | } 23 | caffe_set(im_size as usize, T::default(), data_output); 24 | } 25 | 26 | let mut kernel_size = 1; 27 | for i in 0..num_spatial_axes { 28 | kernel_size *= kernel_shape[i]; 29 | } 30 | 31 | let channels_col = col_shape[0]; 32 | let mut d_offset = Vec::with_capacity(num_spatial_axes); 33 | d_offset.resize(num_spatial_axes, 0); 34 | let mut d_iter = Vec::with_capacity(num_spatial_axes); 35 | d_iter.resize(num_spatial_axes, 0); 36 | for c_col in 0..channels_col { 37 | // Loop over spatial axes in reverse order to compute a per-axis offset. 38 | let mut offset = c_col; 39 | for d_i in (0..num_spatial_axes).rev() { 40 | if d_i < num_spatial_axes - 1 { 41 | offset /= kernel_shape[d_i + 1]; 42 | } 43 | d_offset[d_i] = offset % kernel_shape[d_i]; 44 | } 45 | 46 | let mut incremented = true; 47 | while incremented { 48 | // Loop over spatial axes in forward order to compute the indices in the 49 | // image and column, and whether the index lies in the padding. 50 | let mut index_col = c_col; 51 | let mut index_im = c_col / kernel_size; 52 | let mut is_padding = false; 53 | for d_i in 0..num_spatial_axes { 54 | let d = d_iter[d_i]; 55 | let d_im = d * stride[d_i] - pad[d_i] + d_offset[d_i] * dilation[d_i]; 56 | is_padding |= (d_im < 0) || (d_im >= im_shape[d_i + 1]); 57 | index_col *= col_shape[d_i + 1]; 58 | index_col += d; 59 | index_im *= im_shape[d_i + 1]; 60 | index_im += d_im; 61 | } 62 | 63 | if im2col { 64 | if is_padding { 65 | data_output[index_col as usize] = T::default(); 66 | } else { 67 | data_output[index_col as usize] = data_input[index_im as usize]; 68 | } 69 | } else if !is_padding { 70 | // col2im 71 | data_output[index_im as usize] = data_input[index_col as usize]; 72 | } 73 | 74 | // Loop over spatial axes in reverse order to choose an index, like counting. 
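// E.g. with two spatial extents [2, 3], the index tuple walks
// (0,0) -> (0,1) -> (0,2) -> (1,0) -> ... -> (1,2); `incremented` stays false
// only when the last tuple wraps around, which ends the enclosing while-loop.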
75 | incremented = false; 76 | for d_i in (0..num_spatial_axes).rev() { 77 | let d_max = col_shape[d_i + 1]; 78 | check_lt!(d_iter[d_i], d_max); 79 | if d_iter[d_i] == d_max - 1 { 80 | d_iter[d_i] = 0; 81 | } else { 82 | // d_iter[d_i] < d_max - 1 83 | d_iter[d_i] += 1; 84 | incremented = true; 85 | break; 86 | } 87 | } 88 | } 89 | } 90 | } 91 | 92 | pub fn im2col_nd_cpu(data_im: &[T], num_spatial_axes: usize, im_shape: &[i32], 93 | col_shape: &[i32], kernel_shape: &[i32], pad: &[i32], stride: &[i32], 94 | dilation: &[i32], data_col: &mut [T]) { 95 | const K_IM2COL: bool = true; 96 | im2col_nd_core_cpu(data_im, K_IM2COL, num_spatial_axes, im_shape, col_shape, 97 | kernel_shape, pad, stride, dilation, data_col); 98 | } 99 | 100 | pub fn im2col_cpu(data_im: &[T], channels: i32, height: i32, width: i32, kernel_h: i32, kernel_w: i32, 101 | pad_h: i32, pad_w: i32, stride_h: i32, stride_w: i32, 102 | dilation_h: i32, dilation_w: i32, data_col: &mut [T]) { 103 | let output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 104 | let output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 105 | let channel_size = (height * width) as usize; 106 | 107 | let mut channel_offset = 0usize; 108 | let mut col_offset = 0usize; 109 | for _channel in 0..channels { 110 | let data_im = &data_im[channel_offset..]; 111 | for kernel_row in 0..kernel_h { 112 | for kernel_col in 0..kernel_w { 113 | let mut input_row = -pad_h + kernel_row * dilation_h; 114 | for _output_rows in 0..output_h { 115 | if !is_a_ge_zero_and_a_lt_b(input_row, height) { 116 | for _output_cols in 0..output_w { 117 | data_col[col_offset] = T::default(); 118 | col_offset += 1; 119 | } 120 | } else { 121 | let mut input_col = -pad_w + kernel_col * dilation_w; 122 | for _output_col in 0..output_w { 123 | if is_a_ge_zero_and_a_lt_b(input_col, width) { 124 | data_col[col_offset] = data_im[(input_row * width + input_col) as usize]; 125 | col_offset += 1; 126 | } else { 127 | data_col[col_offset] = T::default(); 128 | col_offset += 1; 129 | } 130 | input_col += stride_w; 131 | } 132 | } 133 | 134 | input_row += stride_h; 135 | } 136 | } 137 | } 138 | 139 | channel_offset += channel_size; 140 | } 141 | } 142 | 143 | pub fn col2im_nd_cpu(data_col: &[T], num_spatial_axes: usize, im_shape: &[i32], 144 | col_shape: &[i32], kernel_shape: &[i32], pad: &[i32], stride: &[i32], 145 | dilation: &[i32], data_im: &mut [T]) { 146 | const K_IM2COL: bool = false; 147 | im2col_nd_core_cpu(data_col, K_IM2COL, num_spatial_axes, im_shape, col_shape, 148 | kernel_shape, pad, stride, dilation, data_im); 149 | } 150 | 151 | pub fn col2im_cpu(data_col: &[T], channels: i32, height: i32, width: i32, kernel_h: i32, kernel_w: i32, 152 | pad_h: i32, pad_w: i32, stride_h: i32, stride_w: i32, 153 | dilation_h: i32, dilation_w: i32, data_im: &mut [T]) { 154 | caffe_set((height * width * channels) as usize, T::default(), data_im); 155 | let output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 156 | let output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 157 | let output_w = output_w as usize; 158 | let channel_size = (height * width) as usize; 159 | let mut channel_offset = 0usize; 160 | let mut col_offset = 0usize; 161 | for _channel in 0..channels { 162 | let data_im = &mut data_im[channel_offset..]; 163 | for kernel_row in 0..kernel_h { 164 | for kernel_col in 0..kernel_w { 165 | let mut input_row = -pad_h + kernel_row * dilation_h; 166 | for _output_rows in 
0..output_h { 167 | if !is_a_ge_zero_and_a_lt_b(input_row, height) { 168 | col_offset += output_w; 169 | } else { 170 | let mut input_col = -pad_w + kernel_col * dilation_w; 171 | for _output_col in 0..output_w { 172 | if is_a_ge_zero_and_a_lt_b(input_col, width) { 173 | data_im[(input_row * width + input_col) as usize] += data_col[col_offset]; 174 | } 175 | col_offset += 1; 176 | input_col += stride_w; 177 | } 178 | } 179 | input_row += stride_h; 180 | } 181 | } 182 | } 183 | 184 | channel_offset += channel_size; 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/layers/accuracy_layer.rs: -------------------------------------------------------------------------------- 1 | use crate::blob::{BlobType, Blob}; 2 | use crate::layer::{CaffeLayer, LayerImpl, BlobVec}; 3 | use crate::proto::caffe::LayerParameter; 4 | use crate::util::math_functions::caffe_set; 5 | 6 | 7 | /// Computes the classification accuracy for a one-of-many classification task. 8 | pub struct AccuracyLayer { 9 | layer: LayerImpl, 10 | label_axis: i32, 11 | outer_num: i32, 12 | inner_num: i32, 13 | top_k: u32, 14 | /// Whether to ignore instances with a certain label. 15 | has_ignore_label: bool, 16 | /// The label indicating that an instance should be ignored. 17 | ignore_label: i32, 18 | /// Keeps counts of the number of samples per class. 19 | nums_buffer: Blob, 20 | } 21 | 22 | impl AccuracyLayer { 23 | /// `param` provides **AccuracyParameter** accuracy_param, with **AccuracyLayer options**: 24 | /// - top_k (**optional, default `1`**). Sets the maximum rank $ k $ at which a prediction 25 | /// is considered correct. For example, if $ k = 5 $, a prediction is counted correct if 26 | /// the correct label is among the top 5 predicted labels. 27 | pub fn new(param: &LayerParameter) -> Self { 28 | AccuracyLayer { 29 | layer: LayerImpl::new(param), 30 | label_axis: 0, 31 | outer_num: 0, 32 | inner_num: 0, 33 | top_k: 0, 34 | has_ignore_label: false, 35 | ignore_label: 0, 36 | nums_buffer: Blob::new() 37 | } 38 | } 39 | } 40 | 41 | impl CaffeLayer for AccuracyLayer { 42 | type DataType = T; 43 | 44 | fn get_impl(&self) -> &LayerImpl { 45 | &self.layer 46 | } 47 | 48 | fn get_impl_mut(&mut self) -> &mut LayerImpl { 49 | &mut self.layer 50 | } 51 | 52 | fn layer_setup(&mut self, _bottom: &BlobVec, _top: &BlobVec) { 53 | self.top_k = self.layer.layer_param.get_accuracy_param().get_top_k(); 54 | self.has_ignore_label = self.layer.layer_param.get_accuracy_param().has_ignore_label(); 55 | if self.has_ignore_label { 56 | self.ignore_label = self.layer.layer_param.get_accuracy_param().get_ignore_label(); 57 | } 58 | } 59 | 60 | fn reshape(&mut self, bottom: &BlobVec, top: &BlobVec) { 61 | let b0 = bottom[0].as_ref().borrow(); 62 | let b1 = bottom[1].as_ref().borrow(); 63 | check_le!(self.top_k as usize, b0.count() / b1.count(), 64 | "top_k must be less than or equal to the number of classes."); 65 | 66 | let axis = b0.canonical_axis_index(self.layer.layer_param.get_accuracy_param().get_axis()); 67 | self.label_axis = axis as i32; 68 | self.outer_num = b0.count_range(0, axis); 69 | self.inner_num = b0.count_range_to_end(axis + 1); 70 | assert_eq!(self.outer_num * self.inner_num, b1.count() as i32, 71 | "Number of labels must match number of predictions; e.g., if label axis == 1 \ 72 | and prediction shape is (N, C, H, W), label count (number of labels) must be \ 73 | N*H*W, with integer values in {{0, 1, ..., C-1}}."); 74 | 75 | // Accuracy is a scalar; 0 axes. 
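// (An empty shape vector still yields count() == 1, the empty product, so
// top[0] holds a single scalar; this mirrors the `vector<int> top_shape(0)`
// idiom of the C++ AccuracyLayer.)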
76 | top[0].borrow_mut().reshape(&Vec::new()); 77 | if top.len() > 1 { 78 | // Per-class accuracy is a vector; 1 axes. 79 | let top_shape_per_class = vec![b0.shape_idx(self.label_axis)]; 80 | top[1].borrow_mut().reshape(&top_shape_per_class); 81 | self.nums_buffer.reshape(&top_shape_per_class); 82 | } 83 | } 84 | 85 | fn layer_type(&self) -> &'static str { 86 | "Accuracy" 87 | } 88 | 89 | fn exact_num_bottom_blobs(&self) -> i32 { 90 | 2 91 | } 92 | 93 | fn min_top_blobs(&self) -> i32 { 94 | 1 95 | } 96 | 97 | /// If there are two top blobs, then the second blob will contain accuracies per class. 98 | fn max_top_blobs(&self) -> i32 { 99 | 2 100 | } 101 | 102 | /// *Params:* 103 | /// 104 | /// `bottom` input Blob vector (length 2): 105 | /// - $ (N \times C \times H \times W) $, the predictions $ x $, a Blob with values in 106 | /// $ [-\infty, +\infty] $ indicating the predicted score for each of the $ K = CHW $ 107 | /// classes, Each $ x_n $ is mapped to a predicted label $ \hat{l}_n $ given by its 108 | /// maximal index: $ \hat{l}\_n = \arg\max\limits\_k x\_{nk} $. 109 | /// - $ (N \times 1 \times 1 \times 1) $, the labels $ l $, an integer-valued Blob with values 110 | /// $ l_n \in [0, 1, 2, ..., K - 1] $ indicating the correct class label among the $ K $ classes. 111 | /// 112 | /// `top` output Blob vector (length 1): 113 | /// - $ (1 \times 1 \times 1 \times 1) $, the computed accuracy: $$ 114 | /// \frac{1}{N} \sum\limits_{n=1}^N \delta\{ \hat{l}_n = l_n \} 115 | /// $$, where $$ 116 | /// \delta\\{\mathrm{condition}\\} = \left\\{ 117 | /// \begin{array}{lr} 118 | /// 1 & \mathrm{if condition} \\\\ 119 | /// 0 & \mathrm{otherwise} 120 | /// \end{array} \right. 121 | /// $$. 122 | fn forward_cpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 123 | let mut accuracy = T::default(); 124 | let b0 = bottom[0].as_ref().borrow(); 125 | let b1 = bottom[1].as_ref().borrow(); 126 | let bottom_data = b0.cpu_data(); 127 | let bottom_label = b1.cpu_data(); 128 | let dim = b0.count() as i32 / self.outer_num; 129 | let num_labels = b0.shape_idx(self.label_axis); 130 | if top.len() > 1 { 131 | caffe_set(self.nums_buffer.count(), T::default(), self.nums_buffer.mutable_cpu_data()); 132 | let mut t1 = top[1].borrow_mut(); 133 | let count = t1.count(); 134 | caffe_set(count, T::default(), t1.mutable_cpu_data()); 135 | } 136 | 137 | let mut count = 0; 138 | for i in 0..self.outer_num { 139 | for j in 0..self.inner_num { 140 | let label_value = bottom_label[(i * self.inner_num + j) as usize].to_i32(); 141 | if self.has_ignore_label && self.ignore_label == label_value { 142 | continue; 143 | } 144 | 145 | check_ge!(label_value, 0); 146 | check_lt!(label_value, num_labels); 147 | if top.len() > 1 { 148 | self.nums_buffer.mutable_cpu_data()[label_value as usize] += T::from_i32(1); 149 | } 150 | 151 | let prob_of_true_class = bottom_data[(i * dim + label_value * self.inner_num + j) as usize]; 152 | let mut num_better_predictions = -1; // true_class also counts as "better" 153 | let top_k = self.top_k as i32; 154 | // Top-k accuracy 155 | let mut k = 0; 156 | while k < num_labels && num_better_predictions < top_k { 157 | let v = bottom_data[(i * dim + k * self.inner_num + j) as usize]; 158 | num_better_predictions += (v >= prob_of_true_class) as i32; 159 | k += 1; 160 | } 161 | // Check if there are less than top_k predictions 162 | if num_better_predictions < top_k { 163 | accuracy += T::from_i32(1); 164 | if top.len() > 1 { 165 | let mut t1 = top[1].borrow_mut(); 166 | t1.mutable_cpu_data()[label_value as 
usize] += T::from_i32(1); 167 | } 168 | } 169 | 170 | count += 1; 171 | } 172 | } 173 | 174 | top[0].borrow_mut().mutable_cpu_data()[0] = if count == 0 { 175 | T::default() 176 | } else { 177 | accuracy / T::from_i32(count) 178 | }; 179 | if top.len() > 1 { 180 | let mut t1 = top[1].borrow_mut(); 181 | for i in 0..t1.count() { 182 | let num = self.nums_buffer.cpu_data()[i]; 183 | let v = if num.is_zero() { 184 | T::default() 185 | } else { 186 | t1.cpu_data()[i] / num 187 | }; 188 | t1.mutable_cpu_data()[i] = v; 189 | } 190 | } 191 | // Accuracy layer should not be used as a loss function. 192 | } 193 | 194 | fn forward_gpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 195 | no_gpu!(); 196 | } 197 | 198 | /// Not implemented -- AccuracyLayer cannot be used as a loss. 199 | fn backward_cpu(&mut self, _top: &BlobVec, propagate_down: &Vec, _bottom: &BlobVec) { 200 | for &prop_down in propagate_down { 201 | if prop_down { 202 | unimplemented!(); 203 | } 204 | } 205 | } 206 | 207 | fn backward_gpu(&mut self, top: &BlobVec, propagate_down: &Vec, bottom: &BlobVec) { 208 | no_gpu!(); 209 | } 210 | } 211 | 212 | register_layer_class!(Accuracy); 213 | -------------------------------------------------------------------------------- /src/synced_mem.rs: -------------------------------------------------------------------------------- 1 | use std::alloc::{alloc_zeroed, dealloc, Layout}; 2 | use std::cell::{RefCell, Ref, RefMut}; 3 | use std::rc::Rc; 4 | use std::sync::{Arc, Mutex}; 5 | 6 | 7 | struct MemPtr { 8 | ptr: *mut T, 9 | count: usize, 10 | } 11 | 12 | /// Makes the raw memory handle be sent between threads. 13 | unsafe impl Send for MemPtr {} 14 | 15 | impl MemPtr { 16 | pub fn new(count: usize) -> Self { 17 | let layout = Layout::array::(count).unwrap(); 18 | trace!("alloc owned memory, layout: {:?}", layout); 19 | MemPtr { 20 | ptr: unsafe { alloc_zeroed(layout) as *mut T }, 21 | count, 22 | } 23 | } 24 | 25 | pub fn raw_parts(&self) -> (*const T, usize) { 26 | (self.ptr as *const T, self.count) 27 | } 28 | 29 | pub fn raw_parts_mut(&mut self) -> (*mut T, usize) { 30 | (self.ptr, self.count) 31 | } 32 | 33 | pub fn as_slice(&self) -> &[T] { 34 | unsafe { std::slice::from_raw_parts(self.ptr, self.count) } 35 | } 36 | 37 | pub fn as_mut_slice(&mut self) -> &mut [T] { 38 | unsafe { std::slice::from_raw_parts_mut(self.ptr, self.count) } 39 | } 40 | } 41 | 42 | impl Drop for MemPtr { 43 | fn drop(&mut self) { 44 | let layout = Layout::array::(self.count).unwrap(); 45 | trace!("dealloc owned memory, ptr: {:?}, size: {} * {}", self.ptr, self.count, std::mem::size_of::()); 46 | unsafe { dealloc(self.ptr as *mut u8, layout); } 47 | } 48 | } 49 | 50 | 51 | #[derive(Clone)] 52 | pub struct MemShared { 53 | mem: Rc>>, 54 | offset: isize, 55 | } 56 | 57 | impl MemShared { 58 | /// Make a new instance which pointer is offset by a length `offset * std::mem::size_of::()`. 
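/// A usage sketch with hypothetical sizes: a view shifted by two elements
/// shares the same backing `MemPtr`.
///
/// ``` ignore
/// let tail = shared.offset(2); // assuming some `shared: MemShared<f32>`
/// ```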
59 | pub fn offset(&self, offset: i32) -> Self { 60 | MemShared { 61 | mem: self.mem.clone(), 62 | offset: self.offset + offset as isize, 63 | } 64 | } 65 | } 66 | 67 | 68 | /// Manages memory allocation and synchronization between the host (CPU) and device (GPU) 69 | pub struct SyncedMemory { 70 | cpu_mem: Option>>>, 71 | count: usize, 72 | cpu_offset: isize, 73 | } 74 | 75 | impl Default for SyncedMemory { 76 | fn default() -> Self { 77 | SyncedMemory { 78 | cpu_mem: Default::default(), 79 | count: 0, 80 | cpu_offset: 0, 81 | } 82 | } 83 | } 84 | 85 | impl SyncedMemory { 86 | pub fn new_uninit() -> Self { 87 | Default::default() 88 | } 89 | 90 | /// Construct an instance without allocating memory. Note that `count` is the num of item of type `T`, 91 | /// so the actual memory size in bytes is `std::mem::size_of::() * count`. 92 | pub fn new(count: usize) -> Self { 93 | SyncedMemory { 94 | cpu_mem: Default::default(), 95 | count, 96 | cpu_offset: 0 97 | } 98 | } 99 | 100 | #[inline] 101 | pub fn count(&self) -> usize { 102 | self.count 103 | } 104 | 105 | #[inline] 106 | pub fn bytes_size(&self) -> usize { 107 | std::mem::size_of::() * self.count 108 | } 109 | 110 | fn sync_to_cpu(&mut self) { 111 | if let Option::None = self.cpu_mem { 112 | trace!("Synced CPU memory type from uninitialized to alloc."); 113 | self.cpu_mem = Some(Rc::new(RefCell::new(MemPtr::new(self.count)))); 114 | } 115 | } 116 | 117 | pub fn cpu_data(&mut self) -> &[T] { 118 | self.sync_to_cpu(); 119 | let (ptr, _) = RefCell::borrow(self.cpu_mem.as_ref().unwrap()).raw_parts(); 120 | unsafe { std::slice::from_raw_parts(ptr.offset(self.cpu_offset), self.count) } 121 | } 122 | 123 | pub fn cpu_data_raw(&mut self) -> (*const T, usize) { 124 | self.sync_to_cpu(); 125 | let (ptr, _) = self.cpu_mem.as_ref().unwrap().as_ref().borrow().raw_parts(); 126 | (unsafe { ptr.offset(self.cpu_offset) }, self.count) 127 | } 128 | 129 | pub fn cpu_data_shared(&mut self) -> MemShared { 130 | self.sync_to_cpu(); 131 | MemShared { 132 | mem: Rc::clone(self.cpu_mem.as_ref().unwrap()), 133 | offset: self.cpu_offset, 134 | } 135 | } 136 | 137 | pub fn try_map_cpu_data(&self, f: F) -> Option where F: FnOnce(&[T]) -> U { 138 | self.cpu_mem.as_ref().map(|ptr| { 139 | let (ptr, _) = RefCell::borrow((*ptr).as_ref()).raw_parts(); 140 | f(unsafe { std::slice::from_raw_parts(ptr.offset(self.cpu_offset), self.count) }) 141 | }) 142 | } 143 | 144 | pub fn mutable_cpu_data(&mut self) -> &mut [T] { 145 | self.sync_to_cpu(); 146 | let (ptr, _) = RefCell::borrow_mut(self.cpu_mem.as_ref().unwrap()).raw_parts_mut(); 147 | unsafe { std::slice::from_raw_parts_mut(ptr.offset(self.cpu_offset), self.count) } 148 | } 149 | 150 | pub fn mutable_cpu_data_raw(&mut self) -> (*mut T, usize) { 151 | self.sync_to_cpu(); 152 | let (ptr, _) = self.cpu_mem.as_ref().unwrap().borrow_mut().raw_parts_mut(); 153 | (unsafe { ptr.offset(self.cpu_offset) }, self.count) 154 | } 155 | 156 | pub fn try_map_cpu_mut_data(&mut self, f: F) -> Option where F: FnOnce(&mut [T]) -> U { 157 | self.cpu_mem.as_ref().map(|ptr| { 158 | let (ptr, _) = RefCell::borrow_mut((*ptr).as_ref()).raw_parts_mut(); 159 | f(unsafe { std::slice::from_raw_parts_mut(ptr.offset(self.cpu_offset), self.count) }) 160 | }) 161 | } 162 | 163 | pub fn set_cpu_data(&mut self, data: &MemShared) { 164 | let mem_ptr = data.mem.as_ptr(); 165 | let &MemPtr { ptr, count } = unsafe { &*mem_ptr }; 166 | 167 | trace!("Set a borrowed slice of CPU memory. 
ptr: {:?}, len: {} * {}; offset: {}", 168 | ptr, count, std::mem::size_of::(), data.offset); 169 | if data.offset < 0 { 170 | panic!("Set a borrowed memory but which offset({}) < 0.", data.offset); 171 | } 172 | if self.count + data.offset as usize > count { 173 | panic!("Set a slice which length ({} - offset({}) = {}) less than the memory need ({}).", 174 | count, data.offset, count as isize - data.offset, self.count); 175 | } 176 | 177 | self.cpu_mem = Some(Rc::clone(&data.mem)); 178 | self.cpu_offset = data.offset; 179 | } 180 | } 181 | 182 | 183 | pub struct ArcSyncedMemory { 184 | cpu_mem: Option>>>, 185 | count: usize, 186 | cpu_offset: isize, 187 | } 188 | 189 | impl ArcSyncedMemory { 190 | pub fn new() -> Self { 191 | Self { 192 | cpu_mem: Default::default(), 193 | count: 0, 194 | cpu_offset: 0, 195 | } 196 | } 197 | 198 | pub fn from(mut mem: SyncedMemory) -> Result> { 199 | let cpu_mem = mem.cpu_mem.map(|r| Rc::try_unwrap(r)); 200 | if cpu_mem.as_ref().map_or(false, |r| r.is_err()) { 201 | mem.cpu_mem = cpu_mem.map(|r| r.err().unwrap()); 202 | return Result::Err(mem); 203 | } 204 | 205 | let cpu_mem = cpu_mem.map( 206 | |r| Arc::new(Mutex::new(r.ok().unwrap().into_inner())) 207 | ); 208 | let arc_mem = ArcSyncedMemory { 209 | cpu_mem, 210 | count: mem.count, 211 | cpu_offset: mem.cpu_offset, 212 | }; 213 | Result::Ok(arc_mem) 214 | } 215 | 216 | pub fn into_mem(mut self) -> Result, Self> { 217 | let cpu_mem = self.cpu_mem.map(|a| Arc::try_unwrap(a)); 218 | if cpu_mem.as_ref().map_or(false, |r| r.is_err()) { 219 | self.cpu_mem = cpu_mem.map(|r| r.err().unwrap()); 220 | return Result::Err(self); 221 | } 222 | 223 | let cpu_mem = cpu_mem.map( 224 | |r| Rc::new(RefCell::new(r.ok().unwrap().into_inner().unwrap())) 225 | ); 226 | let mem = SyncedMemory { 227 | cpu_mem, 228 | count: self.count, 229 | cpu_offset: self.cpu_offset, 230 | }; 231 | Result::Ok(mem) 232 | } 233 | } 234 | 235 | 236 | #[cfg(test)] 237 | use test_env_log::test; 238 | 239 | #[test] 240 | fn mem_ptr_test_new() { 241 | let _ = MemPtr::::new(54); 242 | } 243 | 244 | #[test] 245 | fn synced_mem_test_uninit() { 246 | let mut s = SyncedMemory::new_uninit(); 247 | let slice: &[i32] = s.cpu_data(); 248 | info!("New uninitialized memory, ptr: {:?}, len: {}", slice.as_ptr(), slice.len()); 249 | } 250 | 251 | #[test] 252 | fn synced_mem_test_new() { 253 | let mut s = SyncedMemory::new(78); 254 | { 255 | let mut slice = s.mutable_cpu_data(); 256 | info!("Get mutable slice from SyncedMemory: {:#?}", slice); 257 | let mut count = 0u8; 258 | for x in slice { 259 | count += 1; 260 | *x = count; 261 | } 262 | } 263 | 264 | let slice = s.cpu_data(); 265 | info!("Get const slice from SyncedMemory: {:#?}", slice); 266 | } 267 | 268 | #[test] 269 | fn synced_mem_test_slice() { 270 | let mem = MemShared { 271 | mem: Rc::new(RefCell::new(MemPtr::new(12))), 272 | offset: 2, 273 | }; 274 | { 275 | let mut s = SyncedMemory::new(9); 276 | info!("Set slice data"); 277 | s.set_cpu_data(&mem); 278 | 279 | let mut slice = s.mutable_cpu_data(); 280 | info!("Get mutable slice from SyncedMemory: {:#?}", slice); 281 | let mut count = 2u8; 282 | for x in slice { 283 | *x = count * 2u8; 284 | count += 1; 285 | } 286 | } 287 | 288 | info!("Print original slice: {:#?}", RefCell::borrow(&mem.mem).as_slice()); 289 | } 290 | -------------------------------------------------------------------------------- /src/filler.rs: -------------------------------------------------------------------------------- 1 | //! 
Fillers are random number generators that fills a blob using the specified 2 | //! algorithm. The expectation is that they are only going to be used during 3 | //! initialization time and will not involve any GPUs. 4 | 5 | use crate::blob::{BlobType, Blob}; 6 | use crate::proto::caffe::{FillerParameter, FillerParameter_VarianceNorm}; 7 | use crate::util::math_functions::{caffe_rng_uniform, caffe_rng_gaussian, caffe_rng_bernoulli_i32}; 8 | 9 | 10 | /// Fills a Blob with constant or randomly-generated data. 11 | pub trait Filler { 12 | fn fill(&self, blob: &mut Blob); 13 | } 14 | 15 | 16 | /// Fills a Blob with constant values $ x = 0 $. 17 | pub struct ConstantFiller { 18 | filler_param: FillerParameter, 19 | } 20 | 21 | impl ConstantFiller { 22 | pub fn new(param: &FillerParameter) -> Self { 23 | ConstantFiller { 24 | filler_param: param.clone() 25 | } 26 | } 27 | } 28 | 29 | impl Filler for ConstantFiller { 30 | fn fill(&self, blob: &mut Blob) { 31 | let count = blob.count(); 32 | let data = blob.mutable_cpu_data(); 33 | let value = self.filler_param.get_value(); 34 | debug_assert_eq!(count, data.len()); 35 | assert_ne!(count, 0); 36 | 37 | data.fill(T::from_f32(value)); 38 | assert_eq!(self.filler_param.get_sparse(), -1, "Sparsity not supported by this Filler."); 39 | } 40 | } 41 | 42 | 43 | /// Fills a Blob with uniformly distributed values $ x\sim U(a, b) $. 44 | pub struct UniformFiller { 45 | filler_param: FillerParameter, 46 | } 47 | 48 | impl UniformFiller { 49 | pub fn new(param: &FillerParameter) -> Self { 50 | UniformFiller { 51 | filler_param: param.clone() 52 | } 53 | } 54 | } 55 | 56 | impl Filler for UniformFiller { 57 | fn fill(&self, blob: &mut Blob) { 58 | let count = blob.count(); 59 | assert_ne!(count, 0); 60 | caffe_rng_uniform(count, T::from_f32(self.filler_param.get_min()), 61 | T::from_f32(self.filler_param.get_max()), blob.mutable_cpu_data()); 62 | assert_eq!(self.filler_param.get_sparse(), -1, "Sparsity not supported by this Filler."); 63 | } 64 | } 65 | 66 | 67 | /// Fills a Blob with Gaussian-distributed values $ x = a $. 68 | pub struct GaussianFiller { 69 | filler_param: FillerParameter, 70 | } 71 | 72 | impl GaussianFiller { 73 | pub fn new(param: &FillerParameter) -> Self { 74 | GaussianFiller { 75 | filler_param: param.clone(), 76 | } 77 | } 78 | } 79 | 80 | impl Filler for GaussianFiller { 81 | fn fill(&self, blob: &mut Blob) { 82 | let count = blob.count(); 83 | assert_ne!(count, 0); 84 | caffe_rng_gaussian(count, T::from_f32(self.filler_param.get_mean()), 85 | T::from_f32(self.filler_param.get_std()), blob.mutable_cpu_data()); 86 | let sparse = self.filler_param.get_sparse(); 87 | assert!(sparse >= -1); 88 | if sparse >= 0 { 89 | // Sparse initialization is implemented for "weight" blobs; i.e. matrices. 90 | // These have num == channels == 1; width is number of inputs; height is 91 | // number of outputs. The 'sparse' variable specifies the mean number 92 | // of non-zero input weights for a given output. 93 | assert!(blob.num_axes() >= 1); 94 | let num_outputs = blob.shape_idx(0); 95 | let non_zero_probability = T::from_i32(sparse) / T::from_i32(num_outputs); 96 | let mut rand_vec = Vec::with_capacity(count); 97 | rand_vec.resize(count, 0); 98 | caffe_rng_bernoulli_i32(count, non_zero_probability, &mut rand_vec); 99 | let data = blob.mutable_cpu_data(); 100 | for i in 0..count { 101 | // SAFETY: Blob data size and `rand_vec` data size both equal to `count`. 
102 | unsafe { *data.get_unchecked_mut(i) *= T::from_i32(*rand_vec.get_unchecked(i)); } 103 | } 104 | } 105 | } 106 | } 107 | 108 | 109 | /// Fills a Blob with values $ x \in [0, 1] $ such that $ \forall i \sum_j x_{ij} = 1 $. 110 | pub struct PositiveUnitballFiller { 111 | filler_param: FillerParameter, 112 | } 113 | 114 | impl PositiveUnitballFiller { 115 | pub fn new(param: &FillerParameter) -> Self { 116 | PositiveUnitballFiller { 117 | filler_param: param.clone(), 118 | } 119 | } 120 | } 121 | 122 | impl Filler for PositiveUnitballFiller { 123 | fn fill(&self, blob: &mut Blob) { 124 | let count = blob.count(); 125 | assert_ne!(count, 0); 126 | caffe_rng_uniform(count, T::from_i32(0), T::from_i32(1), blob.mutable_cpu_data()); 127 | // We expect the filler to not be called very frequently, so we will 128 | // just use a simple implementation 129 | let num = blob.shape_idx(0) as usize; 130 | let dim = count / num; 131 | assert_ne!(dim, 0); 132 | let data = blob.mutable_cpu_data(); 133 | for i in 0..num { 134 | let mut sum = T::default(); 135 | for j in 0..dim { 136 | // SAFETY: max value of `i*dim+j` is `count` which is the data size. 137 | sum += unsafe { *data.get_unchecked(i * dim + j) }; 138 | } 139 | for j in 0..dim { 140 | // SAFETY: max value of `i*dim+j` is `count` which is the data size. 141 | unsafe { *data.get_unchecked_mut(i * dim + j) /= sum; } 142 | } 143 | } 144 | 145 | assert_eq!(self.filler_param.get_sparse(), -1, "Sparsity not supported by this Filler."); 146 | } 147 | } 148 | 149 | 150 | /// Fills a Blob with values $ x \sim U(-a, +a) $ where $ a $ is set inversely proportional 151 | /// to number of incoming nodes, outgoing nodes, or their average. 152 | /// 153 | /// A Filler based on the paper \[Bengio and Glorot 2010\]: Understanding the difficulty 154 | /// of training deep feedforward neuralnetworks. 155 | /// 156 | /// It fills the incoming matrix by randomly sampling uniform data from [-scale, scale] where 157 | /// scale = sqrt(3 / n) where n is the fan_in, fan_out, or their average, depending on the 158 | /// variance_norm option. You should make sure the input blob has shape (num, a, b, c) where 159 | /// a * b * c = fan_in and num * b * c = fan_out. Note that this is currently not the case 160 | /// for inner product layers. 161 | pub struct XavierFiller { 162 | filler_param: FillerParameter, 163 | } 164 | 165 | impl XavierFiller { 166 | pub fn new(param: &FillerParameter) -> Self { 167 | XavierFiller { 168 | filler_param: param.clone(), 169 | } 170 | } 171 | } 172 | 173 | impl Filler for XavierFiller { 174 | fn fill(&self, blob: &mut Blob) { 175 | let count = blob.count(); 176 | assert_ne!(count, 0); 177 | let n = get_fan(&self.filler_param, blob); 178 | let scale = T::sqrt(T::from_i32(3) / n); 179 | let mut neg_scale = T::default(); 180 | neg_scale -= scale; 181 | caffe_rng_uniform(count, neg_scale, scale, blob.mutable_cpu_data()); 182 | assert_eq!(self.filler_param.get_sparse(), -1, "Sparsity not supported by this Filler."); 183 | } 184 | } 185 | 186 | // Used in `XavierFiller`, `MSRAFiller`. 
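// For a blob shaped (num, a, b, c), as the filler docs above assume:
// fan_in = a*b*c = count / shape(0), fan_out = num*b*c = count / shape(1),
// and the AVERAGE option takes (fan_in + fan_out) / 2.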
187 | fn get_fan(param: &FillerParameter, blob: &Blob) -> T { 188 | let count = blob.count(); 189 | let fan_in = count / blob.shape_idx(0) as usize; 190 | // Compatibility with ND blobs 191 | let fan_out = if blob.num_axes() > 1 { count / blob.shape_idx(1) as usize } else { count }; 192 | let variance = param.get_variance_norm(); 193 | if variance == FillerParameter_VarianceNorm::AVERAGE { 194 | T::from_f64((fan_in + fan_out) as f64 / 2f64) 195 | } else if variance == FillerParameter_VarianceNorm::FAN_OUT { 196 | T::from_usize(fan_out) 197 | } else { 198 | T::from_usize(fan_in) 199 | } 200 | } 201 | 202 | /// Fills a Blob with values $ x \sim N(0, \sigma^2) $ where $ \sigma^2 $ is set inversely 203 | /// proportional to number of incoming nodes, outgoing nodes, or their average. 204 | /// 205 | /// A Filler based on the paper [He, Zhang, Ren and Sun 2015]: Specifically accounts for 206 | /// ReLU nonlinearities. 207 | /// 208 | /// Aside: for another perspective on the scaling factor, see the derivation of [Saxe, 209 | /// McClelland, and Ganguli 2013 (v3)]. 210 | /// 211 | /// It fills the incoming matrix by randomly sampling Gaussian data with std = sqrt(2 / n) 212 | /// where n is the fan_in, fan_out, or their average, depending on the variance_norm option. 213 | /// You should make sure the input blob has shape (num, a, b, c) where a * b * c = fan_in 214 | /// and num * b * c = fan_out. Note that this is currently not the case for inner product layers. 215 | pub struct MSRAFiller { 216 | filler_param: FillerParameter, 217 | } 218 | 219 | impl MSRAFiller { 220 | pub fn new(param: &FillerParameter) -> Self { 221 | MSRAFiller { 222 | filler_param: param.clone() 223 | } 224 | } 225 | } 226 | 227 | impl Filler for MSRAFiller { 228 | fn fill(&self, blob: &mut Blob) { 229 | let count = blob.count(); 230 | assert_ne!(count, 0); 231 | let n = get_fan(&self.filler_param, blob); 232 | let mut std = T::from_i32(2); 233 | std /= n; 234 | std = T::sqrt(std); 235 | caffe_rng_gaussian(count, T::default(), std, blob.mutable_cpu_data()); 236 | assert_eq!(self.filler_param.get_sparse(), -1, "Sparsity not supported by this Filler."); 237 | } 238 | } 239 | 240 | 241 | /// Fills a Blob with coefficients for bilinear interpolation. 242 | /// 243 | /// A common use case is with the DeconvolutionLayer acting as upsampling. You can upsample 244 | /// a feature map with shape of (B, C, H, W) by any integer factor using the following proto. 245 | /// 246 | /// ``` proto 247 | /// layer { 248 | /// name: "upsample" 249 | /// type: "Deconvolution" 250 | /// bottom: "{{bottom_name}}" 251 | /// top: "{{top_name}}" 252 | /// convolution_param { 253 | /// kernel_size: {{2 * factor - factor % 2}} 254 | /// stride: {{factor}} 255 | /// num_output: {{C}} 256 | /// group: {{C}} 257 | /// pad: {{ceil((factor - 1) / 2.)}} 258 | /// weight_filler: { type: "bilinear" } 259 | /// bias_term: false 260 | /// } 261 | /// param { lr_mult: 0 decay_mult: 0 } 262 | /// } 263 | /// ``` 264 | /// 265 | /// Please use this by replacing `{{}}` with your values. By specifying 266 | /// `num_output: {{C}} group: {{C}}`, it behaves as channel-wise convolution. The filter 267 | /// shape of this deconvolution layer will be (C, 1, K, K) where K is `kernel_size`, and 268 | /// this filler will set a (K, K) interpolation kernel for every channel of the filter 269 | /// identically. The resulting shape of the top feature map will be (B, C, factor * H, factor * W). 
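/// In other words, with f = ceil(K / 2) and c = (K - 1) / (2 * f), each of the
/// K x K weights is set to (1 - |x / f - c|) * (1 - |y / f - c|) for pixel
/// coordinates (x, y), which is what `fill` below computes element by element.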
270 | /// Note that the learning rate and the weight decay are set to 0 in order to keep coefficient 271 | /// values of bilinear interpolation unchanged during training. If you apply this to an image, 272 | /// this operation is equivalent to the following call in Python with `Scikit.Image`. 273 | /// 274 | /// ``` python 275 | /// out = skimage.transform.rescale(img, factor, mode='constant', cval=0) 276 | /// ``` 277 | pub struct BilinearFiller { 278 | filler_layer: FillerParameter, 279 | } 280 | 281 | impl BilinearFiller { 282 | pub fn new(param: &FillerParameter) -> Self { 283 | BilinearFiller { 284 | filler_layer: param.clone() 285 | } 286 | } 287 | } 288 | 289 | impl Filler for BilinearFiller { 290 | fn fill(&self, blob: &mut Blob) { 291 | assert_eq!(blob.num_axes(), 4, "Blob must be 4 dim."); 292 | let width = blob.width() as usize; 293 | let height = blob.height() as usize; 294 | assert_eq!(width, height, "Filter must be square"); 295 | let f = (width as f64 / 2f64).ceil(); 296 | let c = T::from_f64((width - 1) as f64 / (2f64 * f)); 297 | let f = T::from_f64(f); 298 | let count = blob.count(); 299 | let data = blob.mutable_cpu_data(); 300 | for i in 0..count { 301 | let mut x = T::from_usize(i % width); 302 | let mut y = T::from_usize((i / width) % height); 303 | x /= f; 304 | x -= c; 305 | let mut xx = T::from_i32(1); 306 | xx -= T::fabs(x); 307 | y /= f; 308 | y -= c; 309 | let mut yy = T::from_i32(1); 310 | yy -= T::fabs(y); 311 | 312 | xx *= yy; 313 | unsafe { *data.get_unchecked_mut(i) = xx; } 314 | } 315 | 316 | assert_eq!(self.filler_layer.get_sparse(), -1, "Sparsity not supported by this Filler."); 317 | } 318 | } 319 | 320 | 321 | /// Get a specific filler from the specification given in FillerParameter. 322 | pub fn get_filler(param: &FillerParameter) -> Box> { 323 | let ty = param.get_field_type(); 324 | match ty { 325 | "constant" => Box::new(ConstantFiller::new(param)), 326 | "gaussian" => Box::new(GaussianFiller::new(param)), 327 | "positive_unitball" => Box::new(PositiveUnitballFiller::new(param)), 328 | "uniform" => Box::new(UniformFiller::new(param)), 329 | "xavier" => Box::new(XavierFiller::new(param)), 330 | "msra" => Box::new(MSRAFiller::new(param)), 331 | "bilinear" => Box::new(BilinearFiller::new(param)), 332 | _ => panic!("Unknown filler name: {:?}", ty), 333 | } 334 | } 335 | -------------------------------------------------------------------------------- /src/layer.rs: -------------------------------------------------------------------------------- 1 | use std::boxed::Box; 2 | use std::cell::{RefCell, Ref, RefMut}; 3 | use std::rc::Rc; 4 | 5 | use protobuf::Clear; 6 | 7 | use crate::common::{Caffe, CaffeBrew}; 8 | use crate::blob::{Blob, BlobType, BlobMemRef}; 9 | use crate::proto::caffe; 10 | use crate::util::math_functions::{CaffeNum, caffe_set}; 11 | 12 | 13 | /// A typedef for **shared_ptr** of `Blob`. 14 | pub type SharedBlob = Rc>>; 15 | 16 | /// A typedef of **vector of blob**. 17 | pub type BlobVec = Vec>; 18 | 19 | /// A typedef for **shared_ptr** of `Layer`. 20 | pub type SharedLayer = Rc>>; 21 | 22 | /// A typedef of **vector of layer**. 23 | pub type LayerVec = Vec>; 24 | 25 | /// Helper function to create a `SharedBlob` object. 26 | #[inline] 27 | pub fn make_shared_blob(blob: Blob) -> SharedBlob { 28 | Rc::new(RefCell::new(blob)) 29 | } 30 | 31 | /// Helper function to create a `SharedLayer` object. 
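/// A minimal usage sketch (`boxed_impl` is a hypothetical, already-boxed
/// `CaffeLayer` implementation; the blob vectors are assumed to exist):
///
/// ``` ignore
/// let layer = Layer::new(boxed_impl);
/// let shared = make_shared_layer(layer);
/// shared.borrow_mut().setup(&bottom, &top);
/// ```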
32 | #[inline] 33 | pub fn make_shared_layer(layer: Layer) -> SharedLayer { 34 | Rc::new(RefCell::new(layer)) 35 | } 36 | 37 | 38 | #[derive(Clone, Default)] 39 | pub struct LayerImpl { 40 | /// The protobuf that stores the layer parameters 41 | pub layer_param: caffe::LayerParameter, 42 | /// The phase: TRAIN or TEST 43 | pub phase: caffe::Phase, 44 | /// The vector that stores the learnable parameters as a set of blobs. 45 | pub blobs: BlobVec, 46 | /// Vector indicating whether to compute the diff of each param blob. 47 | pub param_propagate_down: Vec, 48 | /// The vector that indicates whether each top blob has a non-zero weight in 49 | /// the objective function. 50 | pub loss: Vec, 51 | } 52 | 53 | impl LayerImpl { 54 | pub fn new(param: &caffe::LayerParameter) -> Self { 55 | let mut layer = LayerImpl { 56 | layer_param: param.clone(), 57 | ..Default::default() 58 | }; 59 | 60 | // Set phase and copy blobs (if there are any). 61 | layer.phase = param.get_phase(); 62 | if !layer.layer_param.get_blobs().is_empty() { 63 | layer.blobs.reserve(layer.layer_param.get_blobs().len()); 64 | for x in layer.layer_param.get_blobs() { 65 | let mut blob = Blob::new(); 66 | blob.set_from_proto(x, true); 67 | let blob = Rc::new(RefCell::new(blob)); 68 | layer.blobs.push(blob); 69 | } 70 | } 71 | 72 | layer 73 | } 74 | } 75 | 76 | 77 | pub trait CaffeLayer { 78 | type DataType: BlobType; 79 | 80 | fn get_impl(&self) -> &LayerImpl; 81 | 82 | fn get_impl_mut(&mut self) -> &mut LayerImpl; 83 | 84 | fn layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec) { 85 | def_layer_setup(self, bottom, top); 86 | } 87 | 88 | fn reshape(&mut self, bottom: &BlobVec, top: &BlobVec); 89 | 90 | fn to_proto(&self, param: &mut caffe::LayerParameter, write_diff: bool) { 91 | def_to_proto(self, param, write_diff); 92 | } 93 | 94 | fn layer_type(&self) -> &'static str; 95 | 96 | fn exact_num_bottom_blobs(&self) -> i32 { 97 | -1 98 | } 99 | 100 | fn min_bottom_blobs(&self) -> i32 { 101 | -1 102 | } 103 | 104 | fn max_bottom_blobs(&self) -> i32 { 105 | -1 106 | } 107 | 108 | fn exact_num_top_blobs(&self) -> i32 { 109 | -1 110 | } 111 | 112 | fn min_top_blobs(&self) -> i32 { 113 | -1 114 | } 115 | 116 | fn max_top_blobs(&self) -> i32 { 117 | -1 118 | } 119 | 120 | fn equal_num_bottom_top_blobs(&self) -> bool { 121 | false 122 | } 123 | 124 | fn auto_top_blobs(&self) -> bool { 125 | false 126 | } 127 | 128 | fn allow_force_backward(&self, _bottom_index: usize) -> bool { 129 | true 130 | } 131 | 132 | fn forward_cpu(&mut self, bottom: &BlobVec, top: &BlobVec); 133 | 134 | fn forward_gpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 135 | self.forward_cpu(bottom, top); 136 | } 137 | 138 | fn backward_cpu(&mut self, top: &BlobVec, propagate_down: &Vec, 139 | bottom: &BlobVec); 140 | 141 | fn backward_gpu(&mut self, top: &BlobVec, propagate_down: &Vec, 142 | bottom: &BlobVec) { 143 | self.backward_cpu(top, propagate_down, bottom); 144 | } 145 | 146 | fn check_blob_counts(&self, bottom: &BlobVec, top: &BlobVec) { 147 | def_check_blob_counts(self, bottom, top); 148 | } 149 | } 150 | 151 | #[inline] 152 | pub fn def_layer_setup(_this: &mut Caffe, _bottom: &BlobVec, _top: &BlobVec) 153 | where 154 | T: BlobType, 155 | Caffe: CaffeLayer + ?Sized {} 156 | 157 | pub fn def_to_proto(this: &Caffe, param: &mut caffe::LayerParameter, write_diff: bool) 158 | where 159 | T: BlobType, 160 | Caffe: CaffeLayer + ?Sized { 161 | param.clear(); 162 | param.clone_from(&this.get_impl().layer_param); 163 | param.clear_blobs(); 164 | for blob in 
&this.get_impl().blobs { 165 | RefCell::borrow(blob.as_ref()).to_proto(param.mut_blobs().push_default(), write_diff); 166 | } 167 | } 168 | 169 | pub fn def_check_blob_counts(this: &Caffe, bottom: &BlobVec, top: &BlobVec) 170 | where 171 | T: BlobType, 172 | Caffe: CaffeLayer + ?Sized { 173 | if this.exact_num_bottom_blobs() >= 0 { 174 | let num = this.exact_num_bottom_blobs(); 175 | check_eq!(num, bottom.len() as i32, "{} Layer takes {} bottom blob(s) as input.", 176 | this.layer_type(), num); 177 | } 178 | if this.min_bottom_blobs() >= 0 { 179 | let num = this.min_bottom_blobs(); 180 | check_le!(num, bottom.len() as i32, "{} Layer takes at least {} bottom blob(s) as input.", 181 | this.layer_type(), num); 182 | } 183 | if this.max_bottom_blobs() >= 0 { 184 | let num = this.max_bottom_blobs(); 185 | check_ge!(num, bottom.len() as i32, "{} Layer takes at most {} bottom blob(s) as input.", 186 | this.layer_type(), num); 187 | } 188 | if this.exact_num_top_blobs() >= 0 { 189 | let num = this.exact_num_top_blobs(); 190 | check_eq!(num, top.len() as i32, "{} Layer produces {} top blob(s) as output.", 191 | this.layer_type(), num); 192 | } 193 | if this.min_top_blobs() >= 0 { 194 | let num = this.min_top_blobs(); 195 | check_le!(num, top.len() as i32, "{} Layer produces at least {} top blob(s) as output.", 196 | this.layer_type(), num); 197 | } 198 | if this.max_top_blobs() >= 0 { 199 | let num = this.max_top_blobs(); 200 | check_ge!(num, top.len() as i32, "{} Layer produces at most {} top blob(s) as output.", 201 | this.layer_type(), num); 202 | } 203 | if this.equal_num_bottom_top_blobs() { 204 | check_eq!(bottom.len(), top.len(), 205 | "{} Layer produces one top blob as output for each bottom blob input.", 206 | this.layer_type()); 207 | } 208 | } 209 | 210 | 211 | pub struct Layer { 212 | layer: Box>, 213 | } 214 | 215 | impl Layer { 216 | pub fn new(layer: Box>) -> Self { 217 | Layer { 218 | layer 219 | } 220 | } 221 | 222 | /// Implements common layer setup functionality. 223 | /// * `bottom`: the pre-shaped input blobs 224 | /// * `top`: the allocated but unshaped output blobs, to be shaped by Reshape 225 | /// 226 | /// Checks that the number of bottom and top blobs is correct. 227 | /// Calls [`layer_setup`][layer_setup] to do special layer setup for individual layer types, 228 | /// followed by [`reshape`][reshape] to set up sizes of top blobs and internal buffers. 229 | /// Sets up the loss weight multiplier blobs for any non-zero loss weights. 230 | /// This method may not be overridden. 
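/// In outline, `setup` runs four steps in order (a doc sketch of the body below):
///
/// ``` ignore
/// layer.check_blob_counts(bottom, top); // validate blob-count constraints
/// layer.layer_setup(bottom, top);       // per-type initialization
/// layer.reshape(bottom, top);           // size top blobs and internal buffers
/// layer.set_loss_weights(top);          // install loss-weight multipliers
/// ```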
231 | /// 232 | /// [layer_setup]: caffe_rs::Layer::layer_setup 233 | /// [reshape]: caffe_rs::Layer::reshape 234 | pub fn setup(&mut self, bottom: &BlobVec, top: &BlobVec) { 235 | self.layer.check_blob_counts(bottom, top); 236 | self.layer_setup(bottom, top); 237 | self.reshape(bottom, top); 238 | self.set_loss_weights(top); 239 | } 240 | 241 | pub fn layer_setup(&mut self, bottom: &BlobVec, top: &BlobVec) { 242 | self.layer.layer_setup(bottom, top); 243 | } 244 | 245 | pub fn reshape(&mut self, bottom: &BlobVec, top: &BlobVec) { 246 | self.layer.reshape(bottom, top); 247 | } 248 | 249 | pub fn forward(&mut self, bottom: &BlobVec, top: &BlobVec) -> T { 250 | let mut loss = T::from_f32(0f32); 251 | self.reshape(bottom, top); 252 | 253 | match Caffe::mode() { 254 | CaffeBrew::CPU => { 255 | // CPU mode 256 | self.layer.forward_cpu(bottom, top); 257 | for top_id in 0..top.len() { 258 | if self.loss(top_id).is_zero() { 259 | continue; 260 | } 261 | 262 | let blob = RefCell::borrow(top[top_id].as_ref()); 263 | let count = blob.count(); 264 | let BlobMemRef { data, diff } = blob.cpu_mem_ref(); 265 | loss += T::caffe_cpu_dot(count as i32, data, diff); 266 | } 267 | } 268 | CaffeBrew::GPU => { 269 | self.layer.forward_gpu(bottom, top); 270 | unimplemented!(); 271 | } 272 | } 273 | 274 | loss 275 | } 276 | 277 | pub fn backward(&mut self, top: &BlobVec, propagate_down: &Vec, bottom: &BlobVec) { 278 | match Caffe::mode() { 279 | CaffeBrew::CPU => { 280 | // CPU mode 281 | self.layer.backward_cpu(top, propagate_down, bottom); 282 | } 283 | CaffeBrew::GPU => { 284 | // GPU mode 285 | self.layer.backward_gpu(top, propagate_down, bottom); 286 | } 287 | } 288 | } 289 | 290 | pub fn blobs(&self) -> &BlobVec { 291 | &self.layer.get_impl().blobs 292 | } 293 | 294 | pub fn blobs_mut(&mut self) -> &mut BlobVec { 295 | &mut self.layer.get_impl_mut().blobs 296 | } 297 | 298 | pub fn layer_param(&self) -> &caffe::LayerParameter { 299 | &self.layer.get_impl().layer_param 300 | } 301 | 302 | pub fn to_proto(&self, param: &mut caffe::LayerParameter, write_diff: bool) { 303 | self.layer.to_proto(param, write_diff); 304 | } 305 | 306 | pub fn loss(&self, top_index: usize) -> T { 307 | let loss = &self.layer.get_impl().loss; 308 | if loss.len() > top_index { 309 | loss[top_index] 310 | } else { 311 | Default::default() 312 | } 313 | } 314 | 315 | pub fn set_loss(&mut self, top_index: usize, value: T) { 316 | let loss = &mut self.layer.get_impl_mut().loss; 317 | if loss.len() <= top_index { 318 | loss.resize(top_index + 1, Default::default()); 319 | } 320 | 321 | loss[top_index] = value; 322 | } 323 | 324 | pub fn layer_type(&self) -> &'static str { 325 | self.layer.layer_type() 326 | } 327 | 328 | pub fn exact_num_bottom_blobs(&self) -> i32 { 329 | self.layer.exact_num_bottom_blobs() 330 | } 331 | 332 | pub fn min_bottom_blobs(&self) -> i32 { 333 | self.layer.min_bottom_blobs() 334 | } 335 | 336 | pub fn max_bottom_blobs(&self) -> i32 { 337 | self.layer.max_bottom_blobs() 338 | } 339 | 340 | pub fn exact_num_top_blobs(&self) -> i32 { 341 | self.layer.exact_num_top_blobs() 342 | } 343 | 344 | pub fn min_top_blobs(&self) -> i32 { 345 | self.layer.min_top_blobs() 346 | } 347 | 348 | pub fn max_top_blobs(&self) -> i32 { 349 | self.layer.max_top_blobs() 350 | } 351 | 352 | pub fn equal_num_bottom_top_blobs(&self) -> bool { 353 | self.layer.equal_num_bottom_top_blobs() 354 | } 355 | 356 | pub fn auto_top_blobs(&self) -> bool { 357 | self.layer.auto_top_blobs() 358 | } 359 | 360 | pub fn allow_force_backward(&self, 
bottom_index: usize) -> bool { 361 | self.layer.allow_force_backward(bottom_index) 362 | } 363 | 364 | pub fn param_propagate_down(&self, param_id: usize) -> bool { 365 | let prop_down = &self.layer.get_impl().param_propagate_down; 366 | if prop_down.len() > param_id { 367 | prop_down[param_id] 368 | } else { 369 | false 370 | } 371 | } 372 | 373 | pub fn set_param_propagate_down(&mut self, param_id: usize, value: bool) { 374 | let prop_down = &mut self.layer.get_impl_mut().param_propagate_down; 375 | if prop_down.len() <= param_id { 376 | prop_down.resize(param_id + 1, true); 377 | } 378 | 379 | prop_down[param_id] = value; 380 | } 381 | 382 | fn set_loss_weights(&mut self, top: &BlobVec) { 383 | let num_loss_weights = self.layer.get_impl().layer_param.get_loss_weight().len(); 384 | if num_loss_weights == 0usize { 385 | return; 386 | } 387 | 388 | check_eq!(top.len(), num_loss_weights, "loss_weight must be unspecified or specified once per top blob."); 389 | 390 | for top_id in 0..top.len() { 391 | let loss_weight = self.layer.get_impl().layer_param.get_loss_weight()[top_id]; 392 | if loss_weight == 0f32 { 393 | continue; 394 | } 395 | 396 | self.set_loss(top_id, T::from_f32(loss_weight)); 397 | let mut blob = top[top_id].borrow_mut(); 398 | let count = blob.count(); 399 | let loss_multiplier = blob.mutable_cpu_diff(); 400 | caffe_set(count, T::from_f32(loss_weight), loss_multiplier); 401 | } 402 | } 403 | } 404 | -------------------------------------------------------------------------------- /src/data_transformer.rs: -------------------------------------------------------------------------------- 1 | use crate::common::{CaffeRng, Caffe}; 2 | use crate::blob::{BlobType, Blob}; 3 | use crate::proto::caffe::{TransformationParameter, Phase, BlobProto, Datum}; 4 | use crate::util::io::read_proto_from_binary_file_or_die; 5 | use crate::util::rng::caffe_rng_rand; 6 | 7 | 8 | pub struct DataTransformer { 9 | param: TransformationParameter, 10 | rng: Option, 11 | phase: Phase, 12 | data_mean: Blob, 13 | mean_values: Vec 14 | } 15 | 16 | impl DataTransformer { 17 | pub fn new(param: &TransformationParameter, phase: Phase) -> Self { 18 | let mut this = DataTransformer { 19 | param: param.clone(), 20 | rng: None, 21 | phase, 22 | data_mean: Default::default(), 23 | mean_values: Default::default(), 24 | }; 25 | 26 | // Check if we want to use mean_file 27 | if param.has_mean_file() { 28 | assert!(param.get_mean_value().is_empty(), "Cannot specify mean_file and mean_value at the same time"); 29 | let mean_file = param.get_mean_file(); 30 | if Caffe::root_solver() { 31 | info!("Loading mean file from: {:?}", mean_file); 32 | } 33 | 34 | let mut blob_proto = BlobProto::new(); 35 | read_proto_from_binary_file_or_die(mean_file, &mut blob_proto); 36 | this.data_mean.set_from_proto(&blob_proto, true); 37 | } 38 | // Check if we want to use mean_value 39 | if !param.get_mean_value().is_empty() { 40 | assert!(!param.has_mean_file(), "Cannot specify mean_file and mean_value at the same time"); 41 | let mean_value = param.get_mean_value(); 42 | this.mean_values.reserve(mean_value.len()); 43 | for &c in mean_value { 44 | this.mean_values.push(T::from_f32(c)); 45 | } 46 | } 47 | 48 | this 49 | } 50 | 51 | pub fn init_rand(&mut self) { 52 | let needs_rand = self.param.get_mirror() || 53 | (self.phase == Phase::TRAIN && self.param.get_crop_size() != 0); 54 | if needs_rand { 55 | let rng_seed = caffe_rng_rand(); 56 | self.rng = Some(CaffeRng::new_with_seed(rng_seed as u64)); 57 | } else { 58 | self.rng = None; 59 | 
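// Deterministic path: with no mirroring and no TRAIN-phase crop, the transform consumes no randomness.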
} 60 | } 61 | 62 | pub fn transform_datum(&mut self, datum: &Datum, transformed_data: &mut [T]) { 63 | let data = datum.get_data(); 64 | let datum_channels = datum.get_channels(); 65 | let datum_height = datum.get_height(); 66 | let datum_width = datum.get_width(); 67 | 68 | let crop_size = self.param.get_crop_size() as i32; 69 | let scale = T::from_f32(self.param.get_scale()); 70 | let do_mirror = self.param.get_mirror() && self.rand(2) != 0; 71 | let has_mean_file = self.param.has_mean_file(); 72 | let has_uint8 = !data.is_empty(); 73 | let has_mean_values = !self.mean_values.is_empty(); 74 | 75 | check_gt!(datum_channels, 0); 76 | check_ge!(datum_height, crop_size); 77 | check_ge!(datum_width, crop_size); 78 | 79 | if has_mean_file { 80 | assert_eq!(datum_channels, self.data_mean.channels()); 81 | assert_eq!(datum_height, self.data_mean.height()); 82 | assert_eq!(datum_width, self.data_mean.width()); 83 | } 84 | if has_mean_values { 85 | assert!(self.mean_values.len() == 1 || self.mean_values.len() == datum_channels as usize, 86 | "Specify either 1 mean_value or as many as channels: {:?}", datum_channels); 87 | if datum_channels > 1 && self.mean_values.len() == 1 { 88 | // Replicate the mean_value for simplicity 89 | let v = self.mean_values[0]; 90 | for _c in 1..datum_channels { 91 | self.mean_values.push(v); 92 | } 93 | } 94 | } 95 | 96 | let mut height = datum_height; 97 | let mut width = datum_width; 98 | let mut h_off = 0; 99 | let mut w_off = 0; 100 | if crop_size != 0 { 101 | height = crop_size; 102 | width = crop_size; 103 | // We only do random crop when we do training. 104 | if self.phase == Phase::TRAIN { 105 | h_off = self.rand(datum_height - crop_size + 1); 106 | w_off = self.rand(datum_width - crop_size + 1); 107 | } else { 108 | h_off = (datum_height - crop_size) / 2; 109 | w_off = (datum_width - crop_size) / 2; 110 | } 111 | } 112 | 113 | let mut mean = None; 114 | if has_mean_file { 115 | mean = Some(self.data_mean.cpu_data()); 116 | } 117 | for c in 0..datum_channels { 118 | for h in 0..height { 119 | for w in 0..width { 120 | let data_index = (c * datum_height + h_off + h) * datum_width + w_off + w; 121 | let top_index = if do_mirror { 122 | (c * height + h) * width + (width - 1 - w) 123 | } else { 124 | (c * height + h) * width + w 125 | }; 126 | let data_index = data_index as usize; 127 | let top_index = top_index as usize; 128 | let mut datum_element = if has_uint8 { 129 | T::from_i32(data[data_index] as i32) 130 | } else { 131 | T::from_f32(datum.get_float_data()[data_index]) 132 | }; 133 | 134 | if has_mean_file { 135 | datum_element -= mean.unwrap()[data_index]; 136 | datum_element *= scale; 137 | transformed_data[top_index] = datum_element; 138 | } else { 139 | if has_mean_values { 140 | datum_element -= self.mean_values[c as usize]; 141 | datum_element *= scale; 142 | transformed_data[top_index] = datum_element; 143 | } else { 144 | datum_element *= scale; 145 | transformed_data[top_index] = datum_element; 146 | } 147 | } 148 | } 149 | } 150 | } 151 | } 152 | 153 | pub fn transform_datum_blob(&mut self, datum: &Datum, transformed_blob: &mut Blob) { 154 | // If datum is encoded, decode and transform the cv::image. 
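// (Upstream C++ Caffe decodes the datum into a cv::Mat via OpenCV at this point;
// this CPU-only port currently handles only raw, non-encoded datum buffers.)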
155 | if datum.get_encoded() { 156 | todo!("OpenCV"); 157 | assert!(false, "Encoded datum requires OpenCV"); 158 | } else { 159 | if self.param.get_force_color() || self.param.get_force_gray() { 160 | error!("force_color and force_gray only for encoded datum"); 161 | } 162 | } 163 | 164 | let crop_size = self.param.get_crop_size() as i32; 165 | let datum_channels = datum.get_channels(); 166 | let datum_height = datum.get_height(); 167 | let datum_width = datum.get_width(); 168 | 169 | // Check dimensions. 170 | let channels = transformed_blob.channels(); 171 | let height = transformed_blob.height(); 172 | let width = transformed_blob.width(); 173 | let num = transformed_blob.num(); 174 | 175 | assert_eq!(channels, datum_channels); 176 | check_le!(height, datum_height); 177 | check_le!(width, datum_width); 178 | check_ge!(num, 1); 179 | 180 | if crop_size != 0 { 181 | assert_eq!(crop_size, height); 182 | assert_eq!(crop_size, width); 183 | } else { 184 | assert_eq!(datum_height, height); 185 | assert_eq!(datum_width, width); 186 | } 187 | 188 | let transformed_data = transformed_blob.mutable_cpu_data(); 189 | self.transform_datum(datum, transformed_data); 190 | } 191 | 192 | pub fn transform_datum_vec(&mut self, datum_vector: &Vec, transformed_blob: &mut Blob) { 193 | let datum_num = datum_vector.len(); 194 | let num = transformed_blob.num(); 195 | let channels = transformed_blob.channels(); 196 | let height = transformed_blob.height(); 197 | let width = transformed_blob.width(); 198 | 199 | assert!(datum_num > 0, "There is no datum to add"); 200 | assert!(datum_num <= num as usize, "The size of datum_vector must be no greater than transformed_blob->num()"); 201 | let shape = vec![1, channels, height, width]; 202 | let mut uni_blob = Blob::with_shape(&shape); 203 | for item_id in 0..datum_num { 204 | let offset = transformed_blob.offset(item_id as i32, 0, 0, 0); 205 | let data = transformed_blob.cpu_data_shared().offset(offset); 206 | uni_blob.set_cpu_data(&data); 207 | self.transform_datum_blob(&datum_vector[item_id], &mut uni_blob); 208 | } 209 | } 210 | 211 | pub fn transform_blob(&mut self, input_blob: &mut Blob, transformed_blob: &mut Blob) { 212 | let crop_size = self.param.get_crop_size() as i32; 213 | let input_num = input_blob.num(); 214 | let input_channels = input_blob.channels(); 215 | let input_height = input_blob.height(); 216 | let input_width = input_blob.width(); 217 | 218 | if transformed_blob.count() == 0 { 219 | // Initialize transformed_blob with the right shape. 
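// e.g. an input of shape (N, C, H, W) yields (N, C, crop, crop) when a crop
// size is set, and keeps (N, C, H, W) otherwise.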
220 | if crop_size != 0 { 221 | let shape = vec![input_num, input_channels, crop_size, crop_size]; 222 | transformed_blob.reshape(&shape); 223 | } else { 224 | let shape = vec![input_num, input_channels, input_height, input_width]; 225 | transformed_blob.reshape(&shape); 226 | } 227 | } 228 | 229 | let num = transformed_blob.num(); 230 | let channels = transformed_blob.channels(); 231 | let height = transformed_blob.height(); 232 | let width = transformed_blob.width(); 233 | let size = transformed_blob.count(); 234 | 235 | check_le!(input_num, num); 236 | assert_eq!(input_channels, channels); 237 | check_ge!(input_height, height); 238 | check_ge!(input_width, width); 239 | 240 | let scale = self.param.get_scale(); 241 | let do_mirror = self.param.get_mirror() && self.rand(2) != 0; 242 | let has_mean_file = self.param.has_mean_file(); 243 | let has_mean_values = !self.mean_values.is_empty(); 244 | 245 | let mut h_off = 0; 246 | let mut w_off = 0; 247 | if crop_size != 0 { 248 | assert_eq!(crop_size, height); 249 | assert_eq!(crop_size, width); 250 | // We only do random crop when we do training. 251 | let height_diff = input_height - crop_size; 252 | if self.phase == Phase::TRAIN { 253 | h_off = self.rand(height_diff + 1); 254 | w_off = self.rand(input_width - crop_size + 1); 255 | } else { 256 | h_off = (height_diff) / 2; 257 | w_off = (input_width - crop_size) / 2; 258 | }; 259 | } else { 260 | assert_eq!(input_height, height); 261 | assert_eq!(input_width, width); 262 | } 263 | 264 | // SAFETY: mutable slice data borrowed partial which is not accessed in later `input_blob` 265 | // immutable read. 266 | let input_data = unsafe { 267 | let data = input_blob.mutable_cpu_data(); 268 | std::slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) 269 | }; 270 | if has_mean_file { 271 | assert_eq!(input_channels, self.data_mean.channels()); 272 | assert_eq!(input_height, self.data_mean.height()); 273 | assert_eq!(input_width, self.data_mean.width()); 274 | for n in 0..input_num { 275 | let offset = input_blob.offset(n, 0, 0, 0); 276 | T::caffe_sub_assign(self.data_mean.count(), &mut input_data[offset as usize..], 277 | self.data_mean.cpu_data()); 278 | } 279 | } 280 | 281 | if has_mean_values { 282 | assert!(self.mean_values.len() == 1 || self.mean_values.len() == input_channels as usize, 283 | "Specify either 1 mean_value or as many as channels: {:?}", input_channels); 284 | if self.mean_values.len() == 1 { 285 | let mut alpha = T::default(); 286 | alpha -= self.mean_values[0]; 287 | T::caffe_add_scalar(input_blob.count(), alpha, input_data); 288 | } else { 289 | for n in 0..input_num { 290 | for c in 0..input_channels { 291 | let offset = input_blob.offset(n, c, 0, 0); 292 | let count = (input_height * input_width) as usize; 293 | let mut alpha = T::default(); 294 | alpha -= self.mean_values[c as usize]; 295 | T::caffe_add_scalar(count, alpha, &mut input_data[offset as usize..]); 296 | } 297 | } 298 | } 299 | } 300 | 301 | let transformed_data = transformed_blob.mutable_cpu_data(); 302 | for n in 0..input_num { 303 | let top_index_n = n * channels; 304 | let data_index_n = n * channels; 305 | for c in 0..channels { 306 | let top_index_c = (top_index_n + c) * height; 307 | let data_index_c = (data_index_n + c) * input_height + h_off; 308 | for h in 0..height { 309 | let top_index_h = (top_index_c + h) * width; 310 | let data_index_h = (data_index_c + h) * input_width + w_off; 311 | if do_mirror { 312 | let top_index_w = top_index_h + width - 1; 313 | for w in 0..width { 314 | 
transformed_data[(top_index_w - w) as usize] = input_data[(data_index_h + w) as usize];
315 |                     }
316 |                 } else {
317 |                     for w in 0..width {
318 |                         transformed_data[(top_index_h + w) as usize] = input_data[(data_index_h + w) as usize];
319 |                     }
320 |                 }
321 |             }
322 |         }
323 |     }
324 |     if scale != 1f32 {
325 |         info!("Scale: {}", scale);
326 |         T::caffe_scal(size as i32, T::from_f32(scale), transformed_data);
327 |     }
328 | }
329 | 
330 |     pub fn infer_blob_shape(&self, datum: &Datum) -> Vec {
331 |         if datum.get_encoded() {
332 |             assert!(false, "Encoded datum requires OpenCV; compile with USE_OPENCV.");
333 |         }
334 | 
335 |         let crop_size = self.param.get_crop_size() as i32;
336 |         let datum_channels = datum.get_channels();
337 |         let datum_height = datum.get_height();
338 |         let datum_width = datum.get_width();
339 |         // Check dimensions.
340 |         check_gt!(datum_channels, 0);
341 |         check_ge!(datum_height, crop_size);
342 |         check_ge!(datum_width, crop_size);
343 |         // Build BlobShape.
344 |         let height = if crop_size != 0 { crop_size } else { datum_height };
345 |         let width = if crop_size != 0 { crop_size } else { datum_width };
346 |         vec![1, datum_channels, height, width]
347 |     }
348 | 
349 |     pub fn infer_blob_shape_vec(&self, datum_vector: &Vec) -> Vec {
350 |         let num = datum_vector.len();
351 |         check_gt!(num, 0, "There is no datum in the vector");
352 |         // Use first datum in the vector to InferBlobShape.
353 |         let mut shape = self.infer_blob_shape(datum_vector.first().unwrap());
354 |         // Adjust num to the size of the vector.
355 |         shape[0] = num as i32;
356 |         shape
357 |     }
358 | 
359 |     fn rand(&mut self, n: i32) -> i32 {
360 |         assert!(self.rng.is_some());
361 |         assert!(n > 0);
362 |         // Keep the modulus unsigned: casting `next_u32()` to `i32` first can yield a
363 |         // negative value, and a negative `r % n` would produce negative crop offsets.
364 |         let r = self.rng.as_mut().unwrap().generator().next_u32();
365 |         (r % n as u32) as i32
366 |     }
367 | }
368 | 
--------------------------------------------------------------------------------
/src/layers/batch_norm_layer.rs:
--------------------------------------------------------------------------------
1 | use std::rc::Rc;
2 | 
3 | use cblas::Transpose;
4 | 
5 | use crate::blob::{BlobType, Blob};
6 | use crate::layer::{LayerImpl, CaffeLayer, BlobVec, def_layer_setup, SharedBlob, make_shared_blob};
7 | use crate::proto::caffe::{LayerParameter, Phase};
8 | use crate::util::math_functions::{caffe_set, caffe_copy};
9 | 
10 | 
11 | pub struct BatchNormLayer {
12 |     layer: LayerImpl,
13 |     mean: Blob,
14 |     variance: Blob,
15 |     temp: Blob,
16 |     x_norm: Blob,
17 |     use_global_stats: bool,
18 |     moving_average_fraction: T,
19 |     channels: i32,
20 |     eps: T,
21 | 
22 |     // extra temporary variables are used to carry out sums/broadcasting using BLAS
23 |     batch_sum_multiplier: Blob,
24 |     num_by_chans: Blob,
25 |     spatial_sum_multiplier: Blob,
26 | }
27 | 
28 | impl BatchNormLayer {
29 |     pub fn new(param: &LayerParameter) -> Self {
30 |         Self {
31 |             layer: LayerImpl::new(param),
32 |             mean: Blob::new(),
33 |             variance: Blob::new(),
34 |             temp: Blob::new(),
35 |             x_norm: Blob::new(),
36 |             use_global_stats: false,
37 |             moving_average_fraction: T::default(),
38 |             channels: 0,
39 |             eps: T::default(),
40 |             batch_sum_multiplier: Blob::new(),
41 |             num_by_chans: Blob::new(),
42 |             spatial_sum_multiplier: Blob::new(),
43 |         }
44 |     }
45 | 
46 |     fn backward_cpu_impl(&mut self, top_diff: &[T], bottom_diff: &mut [T], num: i32, spatial_dim: i32) {
47 |         if self.use_global_stats {
48 |             T::caffe_div(self.temp.count(), top_diff, self.temp.cpu_data(), bottom_diff);
49 |             return;
50 |         }
51 | 
52 |         let top_data = self.x_norm.cpu_data();
53 |         // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
54 |         //
55
| // dE(Y)/dX = 56 | // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y) 57 | // ./ sqrt(var(X) + eps) 58 | // 59 | // where \cdot and ./ are hadamard product and elementwise division, 60 | // respectively, dE/dY is the top diff, and mean/var/sum are all computed 61 | // along all dimensions except the channels dimension. In the above 62 | // equation, the operations allow for expansion (i.e. broadcast) along all 63 | // dimensions except the channels dimension where required. 64 | 65 | // sum(dE/dY \cdot Y) 66 | T::caffe_mul(self.temp.count(), top_data, top_diff, bottom_diff); 67 | T::caffe_cpu_gemv(Transpose::None, self.channels * num, spatial_dim, T::from_i32(1), 68 | bottom_diff, self.spatial_sum_multiplier.cpu_data(), T::default(), 69 | self.num_by_chans.mutable_cpu_data()); 70 | T::caffe_cpu_gemv(Transpose::Ordinary, num, self.channels, T::from_i32(1), 71 | self.num_by_chans.cpu_data(), self.batch_sum_multiplier.cpu_data(), 72 | T::default(), self.mean.mutable_cpu_data()); 73 | 74 | // reshape (broadcast) the above 75 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, num, self.channels, 1, T::from_i32(1), 76 | self.batch_sum_multiplier.cpu_data(), self.mean.cpu_data(), T::default(), 77 | self.num_by_chans.mutable_cpu_data()); 78 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, self.channels * num, spatial_dim, 1, 79 | T::from_i32(1), self.num_by_chans.cpu_data(), 80 | self.spatial_sum_multiplier.cpu_data(), T::default(), bottom_diff); 81 | 82 | // sum(dE/dY \cdot Y) \cdot Y 83 | T::caffe_mul_assign(self.temp.count(), bottom_diff, top_data); 84 | 85 | // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y 86 | T::caffe_cpu_gemv(Transpose::None, self.channels * num, spatial_dim, T::from_i32(1), 87 | top_diff, self.spatial_sum_multiplier.cpu_data(), T::default(), 88 | self.num_by_chans.mutable_cpu_data()); 89 | T::caffe_cpu_gemv(Transpose::Ordinary, num, self.channels, T::from_i32(1), 90 | self.num_by_chans.cpu_data(), self.batch_sum_multiplier.cpu_data(), T::default(), 91 | self.mean.mutable_cpu_data()); 92 | 93 | // reshape (broadcast) the above to make 94 | // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y 95 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, num, self.channels, 1, T::from_i32(1), 96 | self.batch_sum_multiplier.cpu_data(), self.mean.cpu_data(), T::default(), 97 | self.num_by_chans.mutable_cpu_data()); 98 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, self.channels * num, spatial_dim, 1, 99 | T::from_i32(1), self.num_by_chans.cpu_data(), 100 | self.spatial_sum_multiplier.cpu_data(), T::from_i32(1), bottom_diff); 101 | 102 | // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y 103 | T::caffe_cpu_axpby(self.temp.count() as i32, T::from_i32(1), top_diff, 104 | T::from_f64(-1f64 / (num * spatial_dim) as f64), bottom_diff); 105 | 106 | // note: self.temp still contains sqrt(var(X)+eps), computed during the forward pass. 107 | // SAFETY: `caffe_div` do an element-wise op of the slice items separately. 
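        // This aliasing is sound only because `caffe_div` reads a[i]/b[i] and
        // writes y[i] strictly index-by-index, so each shared `bottom_diff`
        // element is read before the same element is overwritten.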
108 | let im_bottom_diff = unsafe { 109 | std::slice::from_raw_parts(bottom_diff.as_ptr(), bottom_diff.len()) 110 | }; 111 | T::caffe_div(self.temp.count(), im_bottom_diff, self.temp.cpu_data(), bottom_diff); 112 | } 113 | } 114 | 115 | impl CaffeLayer for BatchNormLayer { 116 | type DataType = T; 117 | 118 | fn get_impl(&self) -> &LayerImpl { 119 | &self.layer 120 | } 121 | 122 | fn get_impl_mut(&mut self) -> &mut LayerImpl { 123 | &mut self.layer 124 | } 125 | 126 | fn layer_setup(&mut self, bottom: &BlobVec, _top: &BlobVec) { 127 | let param = self.layer.layer_param.get_batch_norm_param(); 128 | self.moving_average_fraction = T::from_f32(param.get_moving_average_fraction()); 129 | self.use_global_stats = self.layer.phase == Phase::TEST; 130 | if param.has_use_global_stats() { 131 | self.use_global_stats = param.get_use_global_stats(); 132 | } 133 | 134 | let b0 = &bottom[0]; 135 | if b0.as_ref().borrow().num_axes() == 1 { 136 | self.channels = 1; 137 | } else { 138 | self.channels = b0.as_ref().borrow().shape_idx(1); 139 | } 140 | self.eps = T::from_f32(param.get_eps()); 141 | 142 | if !self.layer.blobs.is_empty() { 143 | info!("Skipping parameter initialization."); 144 | } else { 145 | self.layer.blobs.reserve(3); 146 | let mut sz = vec![self.channels]; 147 | self.layer.blobs.push(make_shared_blob(Blob::with_shape(&sz))); 148 | self.layer.blobs.push(make_shared_blob(Blob::with_shape(&sz))); 149 | sz[0] = 1; 150 | self.layer.blobs.push(make_shared_blob(Blob::with_shape(&sz))); 151 | for blob in &self.layer.blobs { 152 | let mut blob = blob.borrow_mut(); 153 | let count = blob.count(); 154 | caffe_set(count, T::default(), blob.mutable_cpu_data()); 155 | } 156 | } 157 | 158 | // Mask statistics from optimization by setting local learning rates 159 | // for mean, variance, and the bias correction to zero. 
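        // The three parameter blobs hold: blobs[0] the running-mean sum,
        // blobs[1] the running-variance sum, and blobs[2] a single-element
        // normalization factor for the two sums (see `forward_cpu`).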
160 |         for i in 0..self.layer.blobs.len() {
161 |             if self.layer.layer_param.get_param().len() == i {
162 |                 let fixed_param_spec = self.layer.layer_param.mut_param().push_default();
163 |                 fixed_param_spec.set_lr_mult(0f32);
164 |             } else {
165 |                 assert_eq!(self.layer.layer_param.get_param()[i].get_lr_mult(), 0f32,
166 |                            "Cannot configure batch normalization statistics as layer parameters.");
167 |             }
168 |         }
169 |     }
170 | 
171 |     fn reshape(&mut self, bottom: &BlobVec, top: &BlobVec) {
172 |         let b0 = bottom[0].as_ref().borrow();
173 |         if b0.num_axes() >= 1 {
174 |             assert_eq!(b0.shape_idx(1), self.channels);
175 |         }
176 |         top[0].borrow_mut().reshape_like(&*b0);
177 | 
178 |         let mut sz = vec![self.channels];
179 |         self.mean.reshape(&sz);
180 |         self.variance.reshape(&sz);
181 |         self.temp.reshape_like(&*b0);   // temp holds per-element (X-EX)^2, so it needs the full bottom shape
182 |         self.x_norm.reshape_like(&*b0); // x_norm caches the full normalized output
183 |         sz[0] = b0.shape_idx(0);
184 |         self.batch_sum_multiplier.reshape(&sz);
185 | 
186 |         let spatial_dim = b0.count() as i32 / (self.channels * b0.shape_idx(0));
187 |         if self.spatial_sum_multiplier.num_axes() == 0 || self.spatial_sum_multiplier.shape_idx(0) != spatial_dim {
188 |             sz[0] = spatial_dim;
189 |             self.spatial_sum_multiplier.reshape(&sz);
190 |             let count = self.spatial_sum_multiplier.count();
191 |             let multiplier_data = self.spatial_sum_multiplier.mutable_cpu_data();
192 |             caffe_set(count, T::from_i32(1), multiplier_data);
193 |         }
194 | 
195 |         let num_by_chans = self.channels * b0.shape_idx(0);
196 |         if self.num_by_chans.num_axes() == 0 || self.num_by_chans.shape_idx(0) != num_by_chans {
197 |             sz[0] = num_by_chans;
198 |             self.num_by_chans.reshape(&sz);
199 |             let count = self.batch_sum_multiplier.count();
200 |             caffe_set(count, T::from_i32(1), self.batch_sum_multiplier.mutable_cpu_data());
201 |         }
202 |     }
203 | 
204 |     fn layer_type(&self) -> &'static str {
205 |         "BatchNorm"
206 |     }
207 | 
208 |     fn exact_num_bottom_blobs(&self) -> i32 {
209 |         1
210 |     }
211 | 
212 |     fn exact_num_top_blobs(&self) -> i32 {
213 |         1
214 |     }
215 | 
216 |     fn forward_cpu(&mut self, bottom: &BlobVec, top: &BlobVec) {
217 |         let mut t0 = top[0].borrow_mut();
218 |         let t0_count = t0.count();
219 | 
220 |         let (num, count, spatial_dim, top_data) = if !Rc::ptr_eq(&bottom[0], &top[0]) {
221 |             let b0 = bottom[0].as_ref().borrow();
222 |             let num = b0.shape_idx(0);
223 |             let count = b0.count();
224 |             let spatial_dim = count as i32 / (num * self.channels);
225 |             let top_data = t0.mutable_cpu_data();
226 |             caffe_copy(count, b0.cpu_data(), top_data);
227 |             (num, count, spatial_dim, top_data)
228 |         } else {
229 |             let num = t0.shape_idx(0);
230 |             let count = t0.count();
231 |             let spatial_dim = count as i32 / (num * self.channels);
232 |             (num, count, spatial_dim, t0.mutable_cpu_data())
233 |         };
234 | 
235 |         if self.use_global_stats {
236 |             // use the stored mean/variance estimates.
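            // blobs[2][0] accumulates the geometric series sum_i maf^i of the
            // moving-average fraction, so dividing the sums stored in
            // blobs[0]/blobs[1] by it recovers normalized running estimates.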
237 |             let scale_factor = self.layer.blobs[2].as_ref().borrow().cpu_data()[0];
238 |             let scale_factor = if scale_factor.is_zero() {
239 |                 scale_factor
240 |             } else {
241 |                 let mut r = T::from_i32(1);
242 |                 r /= scale_factor;
243 |                 r
244 |             };
245 |             let n = self.variance.count() as i32;
246 |             T::caffe_cpu_scale(n, scale_factor, self.layer.blobs[0].as_ref().borrow().cpu_data(),
247 |                                self.mean.mutable_cpu_data());
248 |             T::caffe_cpu_scale(n, scale_factor, self.layer.blobs[1].as_ref().borrow().cpu_data(),
249 |                                self.variance.mutable_cpu_data());
250 |         } else {
251 |             // compute mean
252 |             if Rc::ptr_eq(&bottom[0], &top[0]) {
253 |                 T::caffe_cpu_gemv(Transpose::None, self.channels * num, spatial_dim,
254 |                                   T::from_f64(1f64 / (num * spatial_dim) as f64), top_data,
255 |                                   self.spatial_sum_multiplier.cpu_data(), T::default(),
256 |                                   self.num_by_chans.mutable_cpu_data());
257 |             } else {
258 |                 let b0 = bottom[0].as_ref().borrow();
259 |                 T::caffe_cpu_gemv(Transpose::None, self.channels * num, spatial_dim,
260 |                                   T::from_f64(1f64 / (num * spatial_dim) as f64), b0.cpu_data(),
261 |                                   self.spatial_sum_multiplier.cpu_data(), T::default(),
262 |                                   self.num_by_chans.mutable_cpu_data());
263 |             }
264 |             T::caffe_cpu_gemv(Transpose::Ordinary, num, self.channels, T::from_i32(1),
265 |                               self.num_by_chans.cpu_data(), self.batch_sum_multiplier.cpu_data(), T::default(),
266 |                               self.mean.mutable_cpu_data());
267 |         }
268 | 
269 |         // subtract mean
270 |         T::caffe_cpu_gemm(Transpose::None, Transpose::None, num, self.channels, 1, T::from_i32(1),
271 |                           self.batch_sum_multiplier.cpu_data(), self.mean.cpu_data(), T::default(),
272 |                           self.num_by_chans.mutable_cpu_data());
273 |         T::caffe_cpu_gemm(Transpose::None, Transpose::None, self.channels * num, spatial_dim, 1,
274 |                           T::from_i32(-1), self.num_by_chans.cpu_data(), self.spatial_sum_multiplier.cpu_data(),
275 |                           T::from_i32(1), top_data);
276 | 
277 |         if !self.use_global_stats {
278 |             // compute variance using var(X) = E((X-EX)^2)
279 |             T::caffe_sqr(t0_count, top_data, self.temp.mutable_cpu_data()); // (X-EX)^2
280 |             T::caffe_cpu_gemv(Transpose::None, self.channels * num, spatial_dim,
281 |                               T::from_f64(1f64 / (num * spatial_dim) as f64), self.temp.cpu_data(),
282 |                               self.spatial_sum_multiplier.cpu_data(), T::default(),
283 |                               self.num_by_chans.mutable_cpu_data());
284 |             T::caffe_cpu_gemv(Transpose::Ordinary, num, self.channels, T::from_i32(1),
285 |                               self.num_by_chans.cpu_data(), self.batch_sum_multiplier.cpu_data(), T::default(),
286 |                               self.variance.mutable_cpu_data()); // E((X-EX)^2)
287 | 
288 |             // compute and save moving average
289 |             let mut blob = self.layer.blobs[2].borrow_mut();
290 |             let blob = blob.mutable_cpu_data();
291 |             blob[0] *= self.moving_average_fraction;
292 |             blob[0] += T::from_i32(1);
293 |             T::caffe_cpu_axpby(self.mean.count() as i32, T::from_i32(1), self.mean.cpu_data(),
294 |                                self.moving_average_fraction,
295 |                                self.layer.blobs[0].as_ref().borrow_mut().mutable_cpu_data());
296 |             let m = count as i32 / self.channels;
297 |             let bias_correction_factor = if m > 1 {
298 |                 let mut t = T::from_i32(m);
299 |                 t /= T::from_i32(m - 1);
300 |                 t
301 |             } else {
302 |                 T::from_i32(1)
303 |             };
304 |             T::caffe_cpu_axpby(self.variance.count() as i32, bias_correction_factor, self.variance.cpu_data(),
305 |                                self.moving_average_fraction,
306 |                                self.layer.blobs[1].as_ref().borrow_mut().mutable_cpu_data()); // variance average lives in blobs[1]
307 |         }
308 | 
309 |         // normalize variance
310 |         T::caffe_add_scalar(self.variance.count(), self.eps, self.variance.mutable_cpu_data());
311 |         {
312 |             // SAFETY: the `caffe_sqrt` does an
element-wise op on each item of slice separately. 313 | let mut_data = unsafe { 314 | let d = self.variance.mutable_cpu_data(); 315 | std::slice::from_raw_parts_mut(d.as_mut_ptr(), d.len()) 316 | }; 317 | T::caffe_sqrt(self.variance.count(), self.variance.cpu_data(), mut_data); 318 | } 319 | 320 | // replicate variance to input size 321 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, num, self.channels, 1, T::from_i32(1), 322 | self.batch_sum_multiplier.cpu_data(), self.variance.cpu_data(), T::default(), 323 | self.num_by_chans.mutable_cpu_data()); 324 | T::caffe_cpu_gemm(Transpose::None, Transpose::None, self.channels * num, spatial_dim, 1, 325 | T::from_i32(1), self.num_by_chans.cpu_data(), self.spatial_sum_multiplier.cpu_data(), 326 | T::default(), self.temp.mutable_cpu_data()); 327 | // SAFETY: the `caffe_div` do an element-wise op on each item of slice separately. 328 | let im_top_data = unsafe { 329 | std::slice::from_raw_parts(top_data.as_ptr(), top_data.len()) 330 | }; 331 | T::caffe_div(self.temp.count(), im_top_data, self.temp.cpu_data(), top_data); 332 | 333 | caffe_copy(self.x_norm.count(), top_data, self.x_norm.mutable_cpu_data()); 334 | } 335 | 336 | fn forward_gpu(&mut self, bottom: &BlobVec, top: &BlobVec) { 337 | no_gpu!(); 338 | } 339 | 340 | fn backward_cpu(&mut self, top: &BlobVec, _propagate_down: &Vec, bottom: &BlobVec) { 341 | if Rc::ptr_eq(&bottom[0], &top[0]) { 342 | let mut t0 = top[0].borrow_mut(); 343 | let count = self.x_norm.count(); 344 | caffe_copy(count, t0.cpu_diff(), self.x_norm.mutable_cpu_diff()); 345 | // SAFETY: `backward_cpu_impl()` does not access the `self.x_norm.cpu_diff()` memory. 346 | let top_diff = unsafe { 347 | let diff = self.x_norm.cpu_diff(); 348 | std::slice::from_raw_parts(diff.as_ptr(), diff.len()) 349 | }; 350 | let num = t0.shape()[0]; 351 | let spatial_dim = t0.count() as i32 / (t0.shape_idx(0) * self.channels); 352 | let bottom_diff = t0.mutable_cpu_diff(); 353 | self.backward_cpu_impl(top_diff, bottom_diff, num, spatial_dim); 354 | } else { 355 | let t0 = top[0].as_ref().borrow(); 356 | let mut b0 = bottom[0].borrow_mut(); 357 | let top_diff = t0.cpu_diff(); 358 | let num = b0.shape()[0]; 359 | let spatial_dim = b0.count() as i32 / (b0.shape_idx(0) * self.channels); 360 | let bottom_diff = b0.mutable_cpu_diff(); 361 | self.backward_cpu_impl(top_diff, bottom_diff, num, spatial_dim); 362 | } 363 | } 364 | 365 | fn backward_gpu(&mut self, top: &BlobVec, propagate_down: &Vec, 366 | bottom: &BlobVec) { 367 | no_gpu!(); 368 | } 369 | } 370 | 371 | register_layer_class!(BatchNorm); 372 | -------------------------------------------------------------------------------- /src/util/math_functions.rs: -------------------------------------------------------------------------------- 1 | use std::ops::{AddAssign, SubAssign, MulAssign, DivAssign, Neg, Add, Sub, Mul, Div}; 2 | 3 | use cblas::{Transpose, Layout, saxpy, daxpy, sasum, dasum, sdot, ddot, sscal, dscal, sgemm, sgemv, dgemm, dgemv, scopy, dcopy}; 4 | use float_next_after::NextAfter; 5 | use rand::distributions::{Uniform, Distribution, Bernoulli}; 6 | use rand::distributions::uniform::SampleUniform; 7 | use rand_distr::Normal; 8 | 9 | use super::mkl_alternate::*; 10 | use crate::util::rng::caffe_rng; 11 | 12 | 13 | pub trait CaffeNum: 14 | Copy + Sized + Default + PartialOrd + 15 | AddAssign + SubAssign + MulAssign + DivAssign + Neg + 16 | Add + Sub + Mul + Div + 17 | SampleUniform { 18 | fn is_zero(&self) -> bool; 19 | 20 | fn is_nan_v(&self) -> bool; 21 | 22 | fn from_f64(v: f64) 
-> Self; 23 | 24 | fn from_f32(v: f32) -> Self; 25 | 26 | fn from_i32(v: i32) -> Self; 27 | 28 | fn from_usize(v: usize) -> Self; 29 | 30 | fn to_f64(self) -> f64; 31 | 32 | fn to_f32(self) -> f32; 33 | 34 | fn to_i32(self) -> i32; 35 | 36 | fn to_usize(self) -> usize; 37 | 38 | fn sqrt(v: Self) -> Self; 39 | 40 | /// Return nature logarithm of the number. 41 | fn ln(v: Self) -> Self; 42 | 43 | /// Return `e^v`. 44 | fn exp(v: Self) -> Self; 45 | 46 | fn fabs(v: Self) -> Self; 47 | 48 | fn min(a: Self, b: Self) -> Self { 49 | if b < a { b } else { a } 50 | } 51 | 52 | fn max(a: Self, b: Self) -> Self { 53 | if a < b { b } else { a } 54 | } 55 | 56 | /// Function likes the C++ `std::nextafter`, provided the next representable value of `self` 57 | /// toward the `y` direction. 58 | fn next_toward(self, y: Self) -> Self; 59 | 60 | /// Function likes the C++ `std::numeric_limits::max()`. 61 | /// Returns the max value of Self Type. 62 | fn num_max() -> Self; 63 | 64 | // define the functions associated with Self 65 | 66 | fn caffe_axpy(n: i32, alpha: Self, x: &[Self], y: &mut [Self]); 67 | 68 | fn caffe_cpu_axpby(n: i32, alpha: Self, x: &[Self], beta: Self, y: &mut [Self]); 69 | 70 | fn caffe_cpu_asum(n: i32, x: &[Self]) -> Self; 71 | 72 | fn caffe_cpu_strided_dot(n: i32, x: &[Self], inc_x: i32, y: &[Self], inc_y: i32) -> Self; 73 | 74 | fn caffe_scal(n: i32, alpha: Self, x: &mut [Self]); 75 | 76 | fn caffe_cpu_dot(n: i32, x: &[Self], y: &[Self]) -> Self; 77 | 78 | fn caffe_cpu_scale(n: i32, alpha: Self, x: &[Self], y: &mut [Self]); 79 | 80 | fn caffe_add(n: usize, a: &[Self], b: &[Self], y: &mut [Self]); 81 | 82 | fn caffe_sub(n: usize, a: &[Self], b: &[Self], y: &mut [Self]); 83 | 84 | fn caffe_sub_assign(n: usize, y: &mut [Self], a: &[Self]); 85 | 86 | fn caffe_mul(n: usize, a: &[Self], b: &[Self], y: &mut [Self]); 87 | 88 | fn caffe_mul_assign(n: usize, y: &mut [Self], a: &[Self]); 89 | 90 | fn caffe_div(n: usize, a: &[Self], b: &[Self], y: &mut [Self]); 91 | 92 | fn caffe_add_scalar(n: usize, alpha: Self, y: &mut [Self]); 93 | 94 | fn caffe_powx(n: usize, a: &[Self], b: Self, y: &mut [Self]); 95 | 96 | fn caffe_sqr(n: usize, a: &[Self], y: &mut [Self]); 97 | 98 | fn caffe_sqrt(n: usize, a: &[Self], y: &mut [Self]); 99 | 100 | fn caffe_exp(n: usize, a: &[Self], y: &mut [Self]); 101 | 102 | fn caffe_log(n: usize, a: &[Self], y: &mut [Self]); 103 | 104 | fn caffe_abs(n: usize, a: &[Self], y: &mut [Self]); 105 | 106 | fn caffe_cpu_sign(n: usize, x: &[Self], y: &mut [Self]); 107 | 108 | fn caffe_cpu_sgnbit(n: usize, x: &[Self], y: &mut [Self]); 109 | 110 | fn caffe_cpu_fabs(n: usize, x: &[Self], y: &mut [Self]); 111 | 112 | fn caffe_cpu_gemm(trans_a: Transpose, trans_b: Transpose, m: i32, n: i32, k: i32, 113 | alpha: Self, a: &[Self], b: &[Self], beta: Self, c: &mut [Self]); 114 | 115 | fn caffe_cpu_gemv(trans_a: Transpose, m: i32, n: i32, alpha: Self, 116 | a: &[Self], x: &[Self], beta: Self, y: &mut [Self]); 117 | } 118 | 119 | impl CaffeNum for i32 { 120 | fn is_zero(&self) -> bool { 121 | *self == 0 122 | } 123 | 124 | fn is_nan_v(&self) -> bool { 125 | false 126 | } 127 | 128 | fn from_f64(v: f64) -> Self { 129 | v as i32 130 | } 131 | 132 | fn from_f32(v: f32) -> Self { 133 | v as i32 134 | } 135 | 136 | fn from_i32(v: i32) -> Self { 137 | v 138 | } 139 | 140 | fn from_usize(v: usize) -> Self { 141 | v as i32 142 | } 143 | 144 | fn to_f64(self) -> f64 { 145 | self as f64 146 | } 147 | 148 | fn to_f32(self) -> f32 { 149 | self as f32 150 | } 151 | 152 | fn to_i32(self) -> i32 { 153 | self 
154 | } 155 | 156 | fn to_usize(self) -> usize { 157 | self as usize 158 | } 159 | 160 | fn sqrt(v: Self) -> Self { 161 | (v as f64).sqrt() as i32 162 | } 163 | 164 | fn ln(v: Self) -> Self { 165 | (v as f64).ln() as i32 166 | } 167 | 168 | fn exp(v: Self) -> Self { 169 | (v as f64).exp() as i32 170 | } 171 | 172 | fn fabs(v: Self) -> Self { 173 | v.abs() 174 | } 175 | 176 | fn next_toward(self, y: Self) -> Self { 177 | if self == y { 178 | return self; 179 | } 180 | if self > y { 181 | self - 1 182 | } else { 183 | self + 1 184 | } 185 | } 186 | 187 | fn num_max() -> Self { 188 | i32::MAX 189 | } 190 | 191 | fn caffe_axpy(n: i32, alpha: Self, x: &[Self], y: &mut [Self]) { 192 | todo!() 193 | } 194 | 195 | fn caffe_cpu_axpby(n: i32, alpha: Self, x: &[Self], beta: Self, y: &mut [Self]) { 196 | todo!() 197 | } 198 | 199 | fn caffe_cpu_asum(n: i32, x: &[Self]) -> Self { 200 | todo!() 201 | } 202 | 203 | fn caffe_cpu_strided_dot(n: i32, x: &[Self], inc_x: i32, y: &[Self], inc_y: i32) -> Self { 204 | todo!() 205 | } 206 | 207 | fn caffe_scal(n: i32, alpha: Self, x: &mut [Self]) { 208 | todo!() 209 | } 210 | 211 | fn caffe_cpu_dot(n: i32, x: &[Self], y: &[Self]) -> Self { 212 | todo!() 213 | } 214 | 215 | fn caffe_cpu_scale(n: i32, alpha: Self, x: &[Self], y: &mut [Self]) { 216 | todo!() 217 | } 218 | 219 | fn caffe_add(n: usize, a: &[Self], b: &[Self], y: &mut [Self]) { 220 | todo!() 221 | } 222 | 223 | fn caffe_sub(n: usize, a: &[Self], b: &[Self], y: &mut [Self]) { 224 | todo!() 225 | } 226 | 227 | fn caffe_sub_assign(n: usize, y: &mut [Self], a: &[Self]) { 228 | todo!() 229 | } 230 | 231 | fn caffe_mul(n: usize, a: &[Self], b: &[Self], y: &mut [Self]) { 232 | todo!() 233 | } 234 | 235 | fn caffe_mul_assign(n: usize, y: &mut [Self], a: &[Self]) { 236 | todo!() 237 | } 238 | 239 | fn caffe_div(n: usize, a: &[Self], b: &[Self], y: &mut [Self]) { 240 | todo!() 241 | } 242 | 243 | fn caffe_add_scalar(n: usize, alpha: Self, y: &mut [Self]) { 244 | todo!() 245 | } 246 | 247 | fn caffe_powx(n: usize, a: &[Self], b: Self, y: &mut [Self]) { 248 | todo!() 249 | } 250 | 251 | fn caffe_sqr(n: usize, a: &[Self], y: &mut [Self]) { 252 | todo!() 253 | } 254 | 255 | fn caffe_sqrt(n: usize, a: &[Self], y: &mut [Self]) { 256 | todo!() 257 | } 258 | 259 | fn caffe_exp(n: usize, a: &[Self], y: &mut [Self]) { 260 | todo!() 261 | } 262 | 263 | fn caffe_log(n: usize, a: &[Self], y: &mut [Self]) { 264 | todo!() 265 | } 266 | 267 | fn caffe_abs(n: usize, a: &[Self], y: &mut [Self]) { 268 | todo!() 269 | } 270 | 271 | fn caffe_cpu_sign(n: usize, x: &[Self], y: &mut [Self]) { 272 | todo!() 273 | } 274 | 275 | fn caffe_cpu_sgnbit(n: usize, x: &[Self], y: &mut [Self]) { 276 | todo!() 277 | } 278 | 279 | fn caffe_cpu_fabs(n: usize, x: &[Self], y: &mut [Self]) { 280 | todo!() 281 | } 282 | 283 | fn caffe_cpu_gemm(trans_a: Transpose, trans_b: Transpose, m: i32, n: i32, k: i32, 284 | alpha: Self, a: &[Self], b: &[Self], beta: Self, c: &mut [Self]) { 285 | todo!() 286 | } 287 | 288 | fn caffe_cpu_gemv(trans_a: Transpose, m: i32, n: i32, alpha: Self, 289 | a: &[Self], x: &[Self], beta: Self, y: &mut [Self]) { 290 | todo!() 291 | } 292 | } 293 | 294 | impl CaffeNum for f32 { 295 | fn is_zero(&self) -> bool { 296 | *self == 0f32 297 | } 298 | 299 | fn is_nan_v(&self) -> bool { 300 | self.is_nan() 301 | } 302 | 303 | fn from_f64(v: f64) -> Self { 304 | v as f32 305 | } 306 | 307 | fn from_f32(v: f32) -> Self { 308 | v 309 | } 310 | 311 | fn from_i32(v: i32) -> Self { 312 | v as f32 313 | } 314 | 315 | fn from_usize(v: usize) -> 
Self { 316 | v as f32 317 | } 318 | 319 | fn to_f64(self) -> f64 { 320 | self as f64 321 | } 322 | 323 | fn to_f32(self) -> f32 { 324 | self 325 | } 326 | 327 | fn to_i32(self) -> i32 { 328 | self as i32 329 | } 330 | 331 | fn to_usize(self) -> usize { 332 | self as usize 333 | } 334 | 335 | fn sqrt(v: Self) -> Self { 336 | v.sqrt() 337 | } 338 | 339 | fn ln(v: Self) -> Self { 340 | v.ln() 341 | } 342 | 343 | fn exp(v: Self) -> Self { 344 | v.exp() 345 | } 346 | 347 | fn fabs(v: Self) -> Self { 348 | v.abs() 349 | } 350 | 351 | fn next_toward(self, y: Self) -> Self { 352 | self.next_after(y) 353 | } 354 | 355 | fn num_max() -> Self { 356 | f32::MAX 357 | } 358 | 359 | // Impls functions for Self 360 | 361 | fn caffe_axpy(n: i32, alpha: f32, x: &[f32], y: &mut [f32]) { 362 | unsafe { saxpy(n, alpha, x, 1, y, 1); } 363 | } 364 | 365 | fn caffe_cpu_axpby(n: i32, alpha: Self, x: &[Self], beta: Self, y: &mut [Self]) { 366 | cblas_saxpby(n, alpha, x, 1, beta, y, 1); 367 | } 368 | 369 | fn caffe_cpu_asum(n: i32, x: &[f32]) -> f32 { 370 | unsafe { sasum(n, x, 1) } 371 | } 372 | 373 | fn caffe_cpu_strided_dot(n: i32, x: &[f32], inc_x: i32, y: &[f32], inc_y: i32) -> f32 { 374 | unsafe { sdot(n, x, inc_x, y, inc_y) } 375 | } 376 | 377 | fn caffe_scal(n: i32, alpha: f32, x: &mut [f32]) { 378 | unsafe { sscal(n, alpha, x, 1) } 379 | } 380 | 381 | fn caffe_cpu_dot(n: i32, x: &[f32], y: &[f32]) -> f32 { 382 | Self::caffe_cpu_strided_dot(n, x, 1, y, 1) 383 | } 384 | 385 | fn caffe_cpu_scale(n: i32, alpha: Self, x: &[Self], y: &mut [Self]) { 386 | unsafe { 387 | scopy(n, x, 1, y, 1); 388 | sscal(n, alpha, y, 1); 389 | } 390 | } 391 | 392 | fn caffe_add(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 393 | vs_add(n, a, b, y); 394 | } 395 | 396 | fn caffe_sub(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 397 | vs_sub(n, a, b, y); 398 | } 399 | 400 | fn caffe_sub_assign(n: usize, y: &mut [Self], a: &[Self]) { 401 | vs_sub_assign(n, y, a); 402 | } 403 | 404 | fn caffe_mul(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 405 | vs_mul(n, a, b, y); 406 | } 407 | 408 | fn caffe_mul_assign(n: usize, y: &mut [f32], a: &[f32]) { 409 | vs_mul_assign(n, y, a); 410 | } 411 | 412 | fn caffe_div(n: usize, a: &[f32], b: &[f32], y: &mut [f32]) { 413 | vs_div(n, a, b, y); 414 | } 415 | 416 | fn caffe_add_scalar(n: usize, alpha: Self, y: &mut [Self]) { 417 | assert!(y.len() >= n); 418 | for i in 0..n { 419 | // SAFETY: assert y.len >= n 420 | unsafe { *y.get_unchecked_mut(i) += alpha; } 421 | } 422 | } 423 | 424 | fn caffe_powx(n: usize, a: &[f32], b: f32, y: &mut [f32]) { 425 | vs_powx(n, a, b, y); 426 | } 427 | 428 | fn caffe_sqr(n: usize, a: &[f32], y: &mut [f32]) { 429 | vs_sqr(n, a, y); 430 | } 431 | 432 | fn caffe_sqrt(n: usize, a: &[f32], y: &mut [f32]) { 433 | vs_sqrt(n, a, y); 434 | } 435 | 436 | fn caffe_exp(n: usize, a: &[f32], y: &mut [f32]) { 437 | vs_exp(n, a, y); 438 | } 439 | 440 | fn caffe_log(n: usize, a: &[f32], y: &mut [f32]) { 441 | vs_ln(n, a, y); 442 | } 443 | 444 | fn caffe_abs(n: usize, a: &[f32], y: &mut [f32]) { 445 | vs_abs(n, a, y); 446 | } 447 | 448 | fn caffe_cpu_sign(n: usize, x: &[f32], y: &mut [f32]) { 449 | vs_sign(n, x, y); 450 | } 451 | 452 | fn caffe_cpu_sgnbit(n: usize, x: &[f32], y: &mut [f32]) { 453 | vs_sgn_bit(n, x, y); 454 | } 455 | 456 | fn caffe_cpu_fabs(n: usize, x: &[f32], y: &mut [f32]) { 457 | vs_fabs(n, x, y); 458 | } 459 | 460 | fn caffe_cpu_gemm(trans_a: Transpose, trans_b: Transpose, m: i32, n: i32, k: i32, 461 | alpha: Self, a: &[Self], b: &[Self], beta: Self, c: &mut 
[Self]) { 462 | let lda = if trans_a == Transpose::None { k } else { m }; 463 | let ldb = if trans_b == Transpose::None { n } else { k }; 464 | unsafe { sgemm(Layout::RowMajor, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, n); } 465 | } 466 | 467 | fn caffe_cpu_gemv(trans_a: Transpose, m: i32, n: i32, alpha: Self, 468 | a: &[Self], x: &[Self], beta: Self, y: &mut [Self]) { 469 | unsafe { sgemv(Layout::RowMajor, trans_a, m, n, alpha, a, n, x, 1, beta, y, 1); } 470 | } 471 | } 472 | 473 | impl CaffeNum for f64 { 474 | fn is_zero(&self) -> bool { 475 | *self == 0f64 476 | } 477 | 478 | fn is_nan_v(&self) -> bool { 479 | self.is_nan() 480 | } 481 | 482 | fn from_f64(v: f64) -> Self { 483 | v 484 | } 485 | 486 | fn from_f32(v: f32) -> Self { 487 | v as f64 488 | } 489 | 490 | fn from_i32(v: i32) -> Self { 491 | v as f64 492 | } 493 | 494 | fn from_usize(v: usize) -> Self { 495 | v as f64 496 | } 497 | 498 | fn to_f64(self) -> f64 { 499 | self 500 | } 501 | 502 | fn to_f32(self) -> f32 { 503 | self as f32 504 | } 505 | 506 | fn to_i32(self) -> i32 { 507 | self as i32 508 | } 509 | 510 | fn to_usize(self) -> usize { 511 | self as usize 512 | } 513 | 514 | fn sqrt(v: Self) -> Self { 515 | v.sqrt() 516 | } 517 | 518 | fn ln(v: Self) -> Self { 519 | v.ln() 520 | } 521 | 522 | fn exp(v: Self) -> Self { 523 | v.exp() 524 | } 525 | 526 | fn fabs(v: Self) -> Self { 527 | v.abs() 528 | } 529 | 530 | fn next_toward(self, y: Self) -> Self { 531 | self.next_after(y) 532 | } 533 | 534 | fn num_max() -> Self { 535 | f64::MAX 536 | } 537 | 538 | // Impls Self functions. 539 | 540 | fn caffe_axpy(n: i32, alpha: f64, x: &[f64], y: &mut [f64]) { 541 | unsafe { daxpy(n, alpha, x, 1, y, 1); } 542 | } 543 | 544 | fn caffe_cpu_axpby(n: i32, alpha: Self, x: &[Self], beta: Self, y: &mut [Self]) { 545 | cblas_daxpby(n, alpha, x, 1, beta, y, 1); 546 | } 547 | 548 | fn caffe_cpu_asum(n: i32, x: &[f64]) -> f64 { 549 | unsafe { dasum(n, x, 1) } 550 | } 551 | 552 | fn caffe_cpu_strided_dot(n: i32, x: &[f64], inc_x: i32, y: &[f64], inc_y: i32) -> f64 { 553 | unsafe { ddot(n, x, inc_x, y, inc_y) } 554 | } 555 | 556 | fn caffe_scal(n: i32, alpha: f64, x: &mut [f64]) { 557 | unsafe { dscal(n, alpha, x, 1) } 558 | } 559 | 560 | fn caffe_cpu_dot(n: i32, x: &[f64], y: &[f64]) -> f64 { 561 | Self::caffe_cpu_strided_dot(n, x, 1, y, 1) 562 | } 563 | 564 | fn caffe_cpu_scale(n: i32, alpha: Self, x: &[Self], y: &mut [Self]) { 565 | unsafe { 566 | dcopy(n, x, 1, y, 1); 567 | dscal(n, alpha, y, 1); 568 | } 569 | } 570 | 571 | fn caffe_add(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 572 | vd_add(n, a, b, y); 573 | } 574 | 575 | fn caffe_sub(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 576 | vd_sub(n, a, b, y); 577 | } 578 | 579 | fn caffe_sub_assign(n: usize, y: &mut [Self], a: &[Self]) { 580 | vd_sub_assign(n, y, a); 581 | } 582 | 583 | fn caffe_mul(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 584 | vd_mul(n, a, b, y); 585 | } 586 | 587 | fn caffe_mul_assign(n: usize, y: &mut [f64], a: &[f64]) { 588 | vd_mul_assign(n, y, a); 589 | } 590 | 591 | fn caffe_div(n: usize, a: &[f64], b: &[f64], y: &mut [f64]) { 592 | vd_div(n, a, b, y); 593 | } 594 | 595 | fn caffe_add_scalar(n: usize, alpha: f64, y: &mut [f64]) { 596 | assert!(y.len() >= n); 597 | for i in 0..n { 598 | // SAFETY: assert y.len >= n 599 | unsafe { *y.get_unchecked_mut(i) += alpha; } 600 | } 601 | } 602 | 603 | fn caffe_powx(n: usize, a: &[f64], b: f64, y: &mut [f64]) { 604 | vd_powx(n, a, b, y); 605 | } 606 | 607 | fn caffe_sqr(n: usize, a: &[f64], y: 
pub fn caffe_copy<T: CaffeNum>(n: usize, x: &[T], y: &mut [T]) {
    if x.as_ptr() != y.as_ptr() {
        debug_assert!(x.len() >= n && y.len() >= n);

        y[..n].copy_from_slice(&x[..n]);
    }
}

pub fn caffe_set<T: CaffeNum>(n: usize, alpha: T, y: &mut [T]) {
    assert!(y.len() >= n);
    if alpha.is_zero() {
        // SAFETY: the assert above guarantees the slice holds at least `n`
        // elements, and an all-zero byte pattern is a valid zero value for
        // every `CaffeNum` type.
        unsafe { std::ptr::write_bytes(y.as_mut_ptr(), 0u8, n); }
    } else {
        for x in &mut y[0..n] {
            *x = alpha;
        }
    }
}

pub fn caffe_next_after<T: CaffeNum>(b: T) -> T {
    b.next_toward(T::num_max())
}

pub fn caffe_rng_uniform<T: CaffeNum + rand::distributions::uniform::SampleUniform>(n: usize, a: T, b: T, r: &mut [T]) {
    assert!(a <= b);
    assert!(n <= r.len());
    let random_distribution = Uniform::new(a, caffe_next_after(b));
    let rng = caffe_rng();
    let mut rng = rng.borrow_mut();
    for i in 0..n {
        // SAFETY: the assert above guarantees that index `i` is within the slice range.
        unsafe { *r.get_unchecked_mut(i) = random_distribution.sample(rng.generator()) }
    }
}

pub fn caffe_rng_gaussian<T: CaffeNum>(n: usize, a: T, sigma: T, r: &mut [T]) {
    let a = a.to_f64();
    let sigma = sigma.to_f64();
    assert!(sigma > 0f64);
    assert!(n <= r.len());
    let random_distribution = Normal::new(a, sigma).unwrap();
    let rng = caffe_rng();
    let mut rng = rng.borrow_mut();
    for i in 0..n {
        // SAFETY: the assert above guarantees that index `i` is within the slice range.
        unsafe { *r.get_unchecked_mut(i) = T::from_f64(random_distribution.sample(rng.generator())); }
    }
}

pub fn caffe_rng_bernoulli_i32<T: CaffeNum>(n: usize, p: T, r: &mut [i32]) {
    let p = p.to_f64();
    assert!(p >= 0f64 && p <= 1f64);
    assert!(n <= r.len());
    let random_distribution = Bernoulli::new(p).unwrap();
    let rng = caffe_rng();
    let mut rng = rng.borrow_mut();
    for i in 0..n {
        // SAFETY: the assert above guarantees that index `i` is within the slice range.
        unsafe { *r.get_unchecked_mut(i) = random_distribution.sample(rng.generator()) as i32; }
    }
}
pub fn caffe_rng_bernoulli_u32<T: CaffeNum>(n: usize, p: T, r: &mut [u32]) {
    let p = p.to_f64();
    assert!(p >= 0f64 && p <= 1f64);
    assert!(n <= r.len());
    let random_distribution = Bernoulli::new(p).unwrap();
    let rng = caffe_rng();
    let mut rng = rng.borrow_mut();
    for i in 0..n {
        // SAFETY: the assert above guarantees that index `i` is within the slice range.
        unsafe { *r.get_unchecked_mut(i) = random_distribution.sample(rng.generator()) as u32; }
    }
}
--------------------------------------------------------------------------------
/src/blob.rs:
--------------------------------------------------------------------------------
use std::any::TypeId;
use std::borrow::Borrow;
use std::boxed::Box;
use std::cell::{RefCell, Ref, RefMut};
use std::rc::Rc;
use std::sync::{Arc, Mutex};

use crate::proto::caffe::{BlobProto, BlobShape};
use crate::synced_mem::{SyncedMemory, MemShared, ArcSyncedMemory};
use crate::util::math_functions::{CaffeNum, caffe_copy};


/// A marker trait to be used in the type bound of `Blob`. It is explicitly marked as `unsafe` and
/// should currently only be implemented for `f32` and `f64` (with a partial impl for `i32`).
pub unsafe trait BlobType: CaffeNum + std::fmt::Debug + 'static {}

unsafe impl BlobType for i32 {}

unsafe impl BlobType for f32 {}

unsafe impl BlobType for f64 {}


#[derive(Copy, Clone)]
pub struct BlobMemRef<'a, T> {
    pub data: &'a [T],
    pub diff: &'a [T],
}

pub struct BlobMemRefMut<'a, T> {
    pub data: &'a mut [T],
    pub diff: &'a mut [T],
}


/// A wrapper around [`SyncedMemory`][SyncedMemory] holders serving as the basic computational unit
/// through which `Layer`, `Net` and `Solver` interact.
///
/// [SyncedMemory]: caffe_rs::synced_mem::SyncedMemory
#[derive(Default)]
pub struct Blob<T: BlobType> {
    data: Option<Rc<RefCell<SyncedMemory<T>>>>,
    diff: Option<Rc<RefCell<SyncedMemory<T>>>>,
    // shape_data: Option<Rc<RefCell<SyncedMemory<i32>>>>,
    shape: Vec<i32>,
    count: usize,
    capacity: usize,
}

pub const MAX_BLOB_AXES: i32 = 32;

impl<T> Blob<T> where T: BlobType {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_shape<Q: Borrow<[i32]> + ?Sized>(shape: &Q) -> Self {
        let mut blob = Blob::new();
        blob.reshape(shape);
        blob
    }

    #[inline]
    pub fn num_axes(&self) -> i32 {
        self.shape.len() as i32
    }

    #[inline]
    pub fn shape(&self) -> &Vec<i32> {
        &self.shape
    }

    // pub fn gpu_shape(&self) -> &[i32] {}

    #[inline]
    pub fn shape_idx(&self, index: i32) -> i32 {
        self.shape[self.canonical_axis_index(index)]
    }

    #[inline]
    pub fn count(&self) -> usize {
        self.count
    }

    /// Compute the volume of a slice; i.e., the product of dimensions among a range of axes.
    ///
    /// * `start`: The first axis to include in the slice.
    /// * `end`: The first axis to exclude from the slice.
    pub fn count_range(&self, start: usize, end: usize) -> i32 {
        if start == end {
            1
        } else {
            self.shape[start..end].iter().product()
        }
    }

    /// Compute the volume of a slice spanning from a particular first axis (`start`) to the final axis.
    #[inline]
    pub fn count_range_to_end(&self, start: usize) -> i32 {
        self.count_range(start, self.shape.len())
    }
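    // Illustrative note (not in the original source): for a blob of shape
    // [2, 3, 4], `count()` is 24, `count_range(1, 3)` is 3 * 4 = 12, and
    // `count_range_to_end(2)` is 4. Layers use these products to treat a blob
    // as an "outer count x inner count" matrix around a chosen axis.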
    pub fn canonical_axis_index(&self, index: i32) -> usize {
        let axes = self.num_axes();
        check_ge!(index, -axes, "axis {:?} out of range for {:?}-D Blob with shape {:?}",
                  index, axes, self.shape_string());
        check_lt!(index, axes, "axis {:?} out of range for {:?}-D Blob with shape {:?}",
                  index, axes, self.shape_string());

        if index < 0 {
            (index + axes) as usize
        } else {
            index as usize
        }
    }

    pub fn shape_string(&self) -> String {
        let capacity = self.shape.len() * 4;
        let mut s = String::with_capacity(capacity);
        for &x in &self.shape {
            s.push_str(x.to_string().as_str());
            s.push(' ');
        }
        s.push('(');
        s.push_str(self.count.to_string().as_str());
        s.push(')');

        s
    }

    pub fn offset(&self, n: i32, c: i32, h: i32, w: i32) -> i32 {
        check_ge!(n, 0);
        check_le!(n, self.shape_idx(0));
        check_ge!(c, 0);
        check_le!(c, self.shape_idx(1));
        check_ge!(h, 0);
        check_le!(h, self.shape_idx(2));
        check_ge!(w, 0);
        check_le!(w, self.shape_idx(3));

        ((n * self.shape_idx(1) + c) * self.shape_idx(2) + h) * self.shape_idx(3) + w
    }

    pub fn offset_idx(&self, indices: &Vec<i32>) -> i32 {
        check_le!(indices.len(), self.shape.len());

        let mut offset = 0;
        let mut idx: usize = 0;
        let len = indices.len();
        for &x in &self.shape {
            offset *= x;
            if len > idx {
                let v = indices[idx];
                check_ge!(v, 0);
                check_lt!(v, x);
                offset += v;
                idx += 1;
            }
        }

        offset
    }
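    // Worked example (not in the original source): for shape [N=2, C=3, H=4, W=5],
    // `offset(1, 2, 3, 4)` evaluates to ((1 * 3 + 2) * 4 + 3) * 5 + 4 = 119,
    // the last element of the 2 * 3 * 4 * 5 = 120-element blob, i.e. the
    // row-major (NCHW) linearization of the index tuple.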
    pub fn reshape<Q: Borrow<[i32]> + ?Sized>(&mut self, shape: &Q) {
        let shape = shape.borrow();
        check_le!(shape.len() as i32, MAX_BLOB_AXES);

        let mut count = 1;
        for &x in shape {
            check_ge!(x, 0); // maybe should constrain with x>0?
            if count != 0 {
                check_le!(x, i32::MAX / count, "blob size exceeds INT_MAX");
            }

            count *= x;
        }

        let count = count as usize;
        self.count = count;
        self.shape = shape.to_vec();

        if count > self.capacity {
            self.capacity = count;
            self.data = Some(Rc::new(RefCell::new(SyncedMemory::new(count))));
            self.diff = Some(Rc::new(RefCell::new(SyncedMemory::new(count))));
        }
    }

    pub fn reshape_like(&mut self, other: &Blob<T>) {
        self.reshape(other.shape());
    }

    pub fn reshape_with(&mut self, shape: &BlobShape) {
        check_le!(shape.get_dim().len() as i32, MAX_BLOB_AXES);

        let mut shape_vec = Vec::with_capacity(shape.get_dim().len());
        for &x in shape.get_dim() {
            shape_vec.push(x as i32);
        }
        self.reshape(&shape_vec);
    }

    pub fn cpu_data(&self) -> &[T] {
        let (ptr, count) = self.data.as_ref().unwrap().borrow_mut().cpu_data_raw();
        unsafe { std::slice::from_raw_parts(ptr, count) }
    }

    pub fn cpu_data_shared(&self) -> MemShared<T> {
        self.data.as_ref().unwrap().borrow_mut().cpu_data_shared()
    }

    // pub fn gpu_data(&mut self) -> &[T] {}

    pub fn cpu_diff(&self) -> &[T] {
        if let Some(ref ptr) = self.diff {
            let (ptr, count) = (*ptr).borrow_mut().cpu_data_raw();
            unsafe { std::slice::from_raw_parts(ptr, count) }
        } else {
            panic!("diff memory not init");
        }
    }

    pub fn cpu_diff_shared(&self) -> MemShared<T> {
        self.diff.as_ref().unwrap().borrow_mut().cpu_data_shared()
    }

    // pub fn gpu_diff(&mut self) -> &[T] {}

    pub fn cpu_mem_ref(&self) -> BlobMemRef<T> {
        let (data_ptr, data_count) = self.data.as_ref().unwrap().borrow_mut().cpu_data_raw();
        let (diff_ptr, diff_count) = self.diff.as_ref().unwrap().borrow_mut().cpu_data_raw();
        BlobMemRef {
            data: unsafe { std::slice::from_raw_parts(data_ptr, data_count) },
            diff: unsafe { std::slice::from_raw_parts(diff_ptr, diff_count) },
        }
    }

    pub fn mutable_cpu_data(&mut self) -> &mut [T] {
        if let Some(ref mut ptr) = self.data {
            let (ptr, count) = (*ptr).borrow_mut().mutable_cpu_data_raw();
            unsafe { std::slice::from_raw_parts_mut(ptr, count) }
        } else {
            panic!("data memory not init");
        }
    }

    pub fn mutable_cpu_diff(&mut self) -> &mut [T] {
        if let Some(ref mut ptr) = self.diff {
            let (ptr, count) = (*ptr).borrow_mut().mutable_cpu_data_raw();
            unsafe { std::slice::from_raw_parts_mut(ptr, count) }
        } else {
            panic!("diff memory not init");
        }
    }

    pub fn mutable_cpu_mem_ref(&mut self) -> BlobMemRefMut<T> {
        let (data_ptr, data_count) = self.data.as_ref().unwrap().borrow_mut().mutable_cpu_data_raw();
        let (diff_ptr, diff_count) = self.diff.as_ref().unwrap().borrow_mut().mutable_cpu_data_raw();
        BlobMemRefMut {
            data: unsafe { std::slice::from_raw_parts_mut(data_ptr, data_count) },
            diff: unsafe { std::slice::from_raw_parts_mut(diff_ptr, diff_count) },
        }
    }

    pub fn set_cpu_data(&mut self, data: &MemShared<T>) {
        if let Some(ref mut ptr) = self.data {
            let data_count = (*ptr).as_ref().borrow().count();
            if data_count != self.count {
                self.data = Some(Rc::new(RefCell::new(SyncedMemory::new(self.count))));
                self.diff = Some(Rc::new(RefCell::new(SyncedMemory::new(self.count))));
            }
            self.data.as_ref().unwrap().borrow_mut().set_cpu_data(data);
        } else {
            panic!("data memory not init");
        }
    }
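    // Usage sketch (illustrative, not from the original source):
    //
    //     let mut blob: Blob<f32> = Blob::with_shape(&[2, 3]);
    //     blob.mutable_cpu_data().fill(1.0);      // lazily allocates CPU memory
    //     assert_eq!(blob.cpu_data().len(), 6);
    //
    // Note that `reshape` only reallocates when the new `count` exceeds the
    // high-water `capacity`, so shrinking and growing back reuses the buffers.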
    pub fn share_data(&mut self, other: &Blob<T>) {
        check_eq!(self.count, other.count());

        if let Some(ref ptr) = other.data {
            self.data = Some(Rc::clone(ptr));
        } else {
            panic!("data memory of other not init");
        }
    }

    #[inline]
    pub fn data_at(&self, n: i32, c: i32, h: i32, w: i32) -> T {
        self.cpu_data()[self.offset(n, c, h, w) as usize]
    }

    #[inline]
    pub fn diff_at(&self, n: i32, c: i32, h: i32, w: i32) -> T {
        self.cpu_diff()[self.offset(n, c, h, w) as usize]
    }

    #[inline]
    pub fn data_at_idx(&self, index: &Vec<i32>) -> T {
        self.cpu_data()[self.offset_idx(index) as usize]
    }

    #[inline]
    pub fn diff_at_idx(&self, index: &Vec<i32>) -> T {
        self.cpu_diff()[self.offset_idx(index) as usize]
    }

    pub fn share_diff(&mut self, other: &Blob<T>) {
        check_eq!(self.count, other.count());

        if let Some(ref ptr) = other.diff {
            self.diff = Some(Rc::clone(ptr));
        } else {
            panic!("diff memory of other not init");
        }
    }

    pub fn copy_from(&mut self, source: &Blob<T>, copy_diff: bool, reshape: bool) {
        if (self.count != source.count) || (self.shape != source.shape) {
            if reshape {
                self.reshape_like(source);
            } else {
                panic!("Trying to copy blobs of different sizes.");
            }
        }

        if copy_diff {
            caffe_copy(self.count, source.cpu_diff(), self.mutable_cpu_diff());
        } else {
            caffe_copy(self.count, source.cpu_data(), self.mutable_cpu_data());
        }
    }

    // deprecated
    pub fn legacy_shape(&self, index: i32) -> i32 {
        check_le!(self.num_axes(), 4);
        check_lt!(index, 4);
        check_ge!(index, -4);

        if index >= self.num_axes() || index < -self.num_axes() {
            return 1;
        }
        self.shape_idx(index)
    }

    #[inline]
    pub fn num(&self) -> i32 {
        self.legacy_shape(0)
    }

    #[inline]
    pub fn channels(&self) -> i32 {
        self.legacy_shape(1)
    }

    #[inline]
    pub fn height(&self) -> i32 {
        self.legacy_shape(2)
    }

    #[inline]
    pub fn width(&self) -> i32 {
        self.legacy_shape(3)
    }

    pub fn shape_equals(&self, other: &BlobProto) -> bool {
        if other.has_num() || other.has_channels() || other.has_height() || other.has_width() {
            return self.num_axes() <= 4 &&
                self.legacy_shape(-4) == other.get_num() &&
                self.legacy_shape(-3) == other.get_channels() &&
                self.legacy_shape(-2) == other.get_height() &&
                self.legacy_shape(-1) == other.get_width();
        }

        let other_shape = other.get_shape().get_dim();
        let mut shape_vec = Vec::with_capacity(other_shape.len());
        for &x in other_shape {
            shape_vec.push(x as i32);
        }

        self.shape == shape_vec
    }
    pub fn set_from_proto(&mut self, proto: &BlobProto, reshape: bool) {
        if reshape {
            let mut shape = Vec::new();
            if proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width() {
                shape.reserve(4);
                shape.push(proto.get_num());
                shape.push(proto.get_channels());
                shape.push(proto.get_height());
                shape.push(proto.get_width());
            } else {
                let other_shape = proto.get_shape().get_dim();
                shape.reserve(other_shape.len());
                for &x in other_shape {
                    shape.push(x as i32);
                }
            }
            self.reshape(&shape);
        } else {
            assert!(self.shape_equals(proto), "shape mismatch (reshape not set)");
        }

        {
            // copy data
            let count = self.count;
            let data_vec = self.mutable_cpu_data();
            if !proto.get_double_data().is_empty() {
                let f64_data = proto.get_double_data();
                check_eq!(count, f64_data.len());

                for i in 0..count {
                    data_vec[i] = T::from_f64(f64_data[i]);
                }
            } else {
                let f32_data = proto.get_data();
                check_eq!(count, f32_data.len());

                for i in 0..count {
                    data_vec[i] = T::from_f32(f32_data[i]);
                }
            }
        }
        {
            // check if copy diff
            let count = self.count;
            if !proto.get_double_diff().is_empty() {
                let f64_diff = proto.get_double_diff();
                check_eq!(count, f64_diff.len());

                let diff_vec = self.mutable_cpu_diff();
                for i in 0..count {
                    diff_vec[i] = T::from_f64(f64_diff[i]);
                }
            } else if !proto.get_diff().is_empty() {
                let f32_diff = proto.get_diff();
                check_eq!(count, f32_diff.len());

                let diff_vec = self.mutable_cpu_diff();
                for i in 0..count {
                    diff_vec[i] = T::from_f32(f32_diff[i]);
                }
            }
        }
    }

    pub fn update(&mut self) {
        let count = self.count as i32;
        let mem_ref = self.mutable_cpu_mem_ref();
        T::caffe_axpy(count, T::from_f32(-1.0f32), mem_ref.diff, mem_ref.data);
    }
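    // Illustrative note (not in the original source): `update` is the plain
    // gradient step `data := data - diff`, expressed as `axpy(n, -1, diff, data)`
    // so it dispatches to the matching BLAS routine for the element type.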
    fn asum_cpu(mem: &Option<Rc<RefCell<SyncedMemory<T>>>>, count: i32) -> T {
        mem.as_ref().map_or(T::default(), |ptr| {
            let data = (*ptr).as_ref().borrow();
            data.try_map_cpu_data(|slice| T::caffe_cpu_asum(count, slice))
                .unwrap_or_default()
        })
    }

    pub fn asum_data(&self) -> T {
        Self::asum_cpu(&self.data, self.count as i32)
    }

    pub fn asum_diff(&self) -> T {
        Self::asum_cpu(&self.diff, self.count as i32)
    }

    fn scale_cpu(mem: &Option<Rc<RefCell<SyncedMemory<T>>>>, count: i32, scale_factor: T) {
        if let Some(ref ptr) = mem {
            let mut data = (*ptr).borrow_mut();
            data.try_map_cpu_mut_data(|slice| T::caffe_scal(count, scale_factor, slice));
        }
    }

    pub fn scale_data(&mut self, scale_factor: T) {
        Self::scale_cpu(&self.data, self.count as i32, scale_factor);
    }

    pub fn scale_diff(&mut self, scale_factor: T) {
        Self::scale_cpu(&self.diff, self.count as i32, scale_factor);
    }

    fn sumsq_cpu(mem: &Option<Rc<RefCell<SyncedMemory<T>>>>, count: i32) -> T {
        mem.as_ref().map_or(T::default(), |ptr| {
            let data = (*ptr).as_ref().borrow();
            data.try_map_cpu_data(|slice| T::caffe_cpu_dot(count, slice, slice))
                .unwrap_or_default()
        })
    }

    pub fn sumsq_data(&self) -> T {
        Self::sumsq_cpu(&self.data, self.count as i32)
    }

    pub fn sumsq_diff(&self) -> T {
        Self::sumsq_cpu(&self.diff, self.count as i32)
    }

    pub fn to_proto(&self, proto: &mut BlobProto, write_diff: bool) {
        proto.clear_shape();
        {
            let mut shape_dim = Vec::with_capacity(self.shape.len());
            for &i in &self.shape {
                shape_dim.push(i as i64);
            }
            proto.mut_shape().set_dim(shape_dim);
        }

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            proto_write_f32_data(proto, self.cpu_data());
            if write_diff {
                proto_write_f32_diff(proto, self.cpu_diff());
            }
        } else if TypeId::of::<T>() == TypeId::of::<f64>() {
            proto_write_f64_data(proto, self.cpu_data());
            if write_diff {
                proto_write_f64_diff(proto, self.cpu_diff());
            }
        }
    }
}

fn proto_write_f32_data<T: BlobType>(proto: &mut BlobProto, data: &[T]) {
    proto.clear_data();
    proto.clear_diff();
    {
        for &i in data {
            proto.mut_data().push(i.to_f32());
        }
    }
}

fn proto_write_f32_diff<T: BlobType>(proto: &mut BlobProto, diff: &[T]) {
    for &i in diff {
        proto.mut_diff().push(i.to_f32());
    }
}

fn proto_write_f64_data<T: BlobType>(proto: &mut BlobProto, data: &[T]) {
    proto.clear_double_data();
    proto.clear_double_diff();
    {
        for &i in data {
            proto.mut_double_data().push(i.to_f64());
        }
    }
}

fn proto_write_f64_diff<T: BlobType>(proto: &mut BlobProto, diff: &[T]) {
    for &i in diff {
        proto.mut_double_diff().push(i.to_f64());
    }
}
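// Round-trip sketch (illustrative, not part of the original source): data
// written with `to_proto` can be restored with `set_from_proto`. The accessors
// follow the generated rust-protobuf API used above.
#[cfg(test)]
mod proto_roundtrip_example {
    use super::*;

    #[test]
    fn blob_survives_proto_round_trip() {
        let mut src: Blob<f32> = Blob::with_shape(&[2, 2]);
        src.mutable_cpu_data().copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);

        let mut proto = BlobProto::new();
        src.to_proto(&mut proto, false);

        let mut dst: Blob<f32> = Blob::new();
        dst.set_from_proto(&proto, true);
        assert_eq!(dst.cpu_data(), src.cpu_data());
    }
}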
/// A thread-safe version of [`Blob`][Blob].
///
/// [Blob]: caffe_rs::blob::Blob
#[derive(Default)]
pub struct ArcBlob<T: BlobType> {
    data: Option<Arc<Mutex<ArcSyncedMemory<T>>>>,
    diff: Option<Arc<Mutex<ArcSyncedMemory<T>>>>,
    shape: Vec<i32>,
    count: usize,
    capacity: usize,
}

impl<T: BlobType> ArcBlob<T> {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn from(mut blob: Blob<T>) -> Result<ArcBlob<T>, Blob<T>> {
        // First take sole ownership of the `Rc` handles; if either is still
        // shared, restore the blob unchanged and hand it back to the caller.
        let data = blob.data.map(Rc::try_unwrap);
        let diff = blob.diff.map(Rc::try_unwrap);
        if data.as_ref().map_or(false, |r| r.is_err()) ||
            diff.as_ref().map_or(false, |r| r.is_err()) {
            blob.data = data.map(|r| r.map_or_else(|rc| rc, |cell| Rc::new(cell)));
            blob.diff = diff.map(|r| r.map_or_else(|rc| rc, |cell| Rc::new(cell)));
            return Result::Err(blob);
        }

        // Both handles are unique now; try to convert the plain `SyncedMemory`
        // into its thread-safe counterpart, rebuilding the blob on failure.
        let data = data.map(|r| ArcSyncedMemory::from(r.ok().unwrap().into_inner()));
        let diff = diff.map(|r| ArcSyncedMemory::from(r.ok().unwrap().into_inner()));
        if data.as_ref().map_or(false, |r| r.is_err()) ||
            diff.as_ref().map_or(false, |r| r.is_err()) {
            blob.data = data.map(|r| Rc::new(RefCell::new(match r {
                // A just-created `ArcSyncedMemory` is not shared yet, so
                // converting it back is expected to succeed.
                Ok(arc) => arc.into_mem().ok().unwrap(),
                Err(mem) => mem,
            })));
            blob.diff = diff.map(|r| Rc::new(RefCell::new(match r {
                Ok(arc) => arc.into_mem().ok().unwrap(),
                Err(mem) => mem,
            })));
            return Result::Err(blob);
        }

        let data = data.map(|r| Arc::new(Mutex::new(r.ok().unwrap())));
        let diff = diff.map(|r| Arc::new(Mutex::new(r.ok().unwrap())));
        let arc_blob = ArcBlob {
            data,
            diff,
            shape: blob.shape,
            count: blob.count,
            capacity: blob.capacity,
        };
        Result::Ok(arc_blob)
    }

    pub fn into_blob(mut self) -> Result<Blob<T>, ArcBlob<T>> {
        // Mirror image of `from`: take sole ownership of the `Arc` handles
        // first, restoring them unchanged if either is still shared.
        let data = self.data.map(Arc::try_unwrap);
        let diff = self.diff.map(Arc::try_unwrap);
        if data.as_ref().map_or(false, |r| r.is_err()) ||
            diff.as_ref().map_or(false, |r| r.is_err()) {
            self.data = data.map(|r| r.map_or_else(|arc| arc, |m| Arc::new(m)));
            self.diff = diff.map(|r| r.map_or_else(|arc| arc, |m| Arc::new(m)));
            return Result::Err(self);
        }

        // Convert the thread-safe memory back into plain `SyncedMemory`.
        let data = data.map(|r| r.ok().unwrap().into_inner().unwrap().into_mem());
        let diff = diff.map(|r| r.ok().unwrap().into_inner().unwrap().into_mem());
        if data.as_ref().map_or(false, |r| r.is_err()) ||
            diff.as_ref().map_or(false, |r| r.is_err()) {
            self.data = data.map(|r| Arc::new(Mutex::new(match r {
                // Converting back is expected to succeed: this memory was an
                // `ArcSyncedMemory` a moment ago.
                Ok(mem) => ArcSyncedMemory::from(mem).ok().unwrap(),
                Err(arc) => arc,
            })));
            self.diff = diff.map(|r| Arc::new(Mutex::new(match r {
                Ok(mem) => ArcSyncedMemory::from(mem).ok().unwrap(),
                Err(arc) => arc,
            })));
            return Result::Err(self);
        }

        let data = data.map(|d| Rc::new(RefCell::new(d.ok().unwrap())));
        let diff = diff.map(|d| Rc::new(RefCell::new(d.ok().unwrap())));
        let blob = Blob {
            data,
            diff,
            shape: self.shape,
            count: self.count,
            capacity: self.capacity,
        };
        Result::Ok(blob)
    }
}
--------------------------------------------------------------------------------
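A usage sketch for `ArcBlob` (illustrative, not from the original source; it assumes `ArcSyncedMemory` is `Send`): a `Blob` that exclusively owns its memory can be promoted to an `ArcBlob`, moved across a thread boundary, and demoted back for single-threaded layer code.

    use std::thread;

    let blob: Blob<f32> = Blob::with_shape(&[4]);
    let arc_blob = ArcBlob::from(blob).ok().expect("memory still shared");
    let worker = thread::spawn(move || arc_blob);
    let blob: Blob<f32> = worker.join().unwrap().into_blob().ok().unwrap();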