├── .gitignore ├── .travis.yml ├── Cargo.toml ├── src ├── main.rs ├── lib.rs ├── context.rs └── float.rs ├── README.md ├── tests ├── unary_operations.rs ├── binary_operations.rs └── unorganized.rs ├── benches └── benches.rs └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled files 2 | *.o 3 | *.so 4 | *.rlib 5 | *.dll 6 | 7 | # Executables 8 | *.exe 9 | 10 | # Generated by Cargo 11 | /target/ 12 | 13 | # Cargo lock file 14 | Cargo.lock 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Any copyright is dedicated to the Public Domain. 2 | # http://creativecommons.org/publicdomain/zero/1.0/ 3 | 4 | language: rust 5 | 6 | rust: 7 | - nightly 8 | - beta 9 | - stable 10 | 11 | script: 12 | - | 13 | cargo build && 14 | cargo test && 15 | cargo bench && 16 | cargo run && 17 | cargo doc 18 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | # This Source Code Form is subject to the terms of the Mozilla Public 2 | # License, v. 2.0. If a copy of the MPL was not distributed with this 3 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 4 | 5 | # TODO Comprehensive configurations. 6 | # TODO Setup contributor license agreement. 7 | 8 | [package] 9 | name = "autograd" 10 | version = "0.1.0" 11 | authors = ["Kibeom Kim "] 12 | 13 | [[bin]] 14 | name = "autograd_bin" 15 | test = false 16 | doc = false 17 | 18 | [dependencies] 19 | num = "*" 20 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | // Any copyright is dedicated to the Public Domain.
2 | // http://creativecommons.org/publicdomain/zero/1.0/ 3 | 4 | #![feature(thread_local)] 5 | #![feature(std_misc)] 6 | #![feature(alloc)] 7 | 8 | #[macro_use] 9 | extern crate autograd; 10 | 11 | use autograd::Context; 12 | 13 | fn main() { 14 | // Initialize Autograd context with type f32 and capacity 100. 15 | let context = new_autograd_context!(f32, 100); 16 | 17 | // Initializes input variables. 18 | let x1 = context.new_variable(1.5); 19 | let x2 = context.new_variable(2.0); 20 | 21 | // Computes a math expression. 22 | let y = (x1 * x2) + x1 + 5.0; 23 | println!("y == {}", y.value); 24 | 25 | // Computes gradient with respect to y. 26 | context.differentiate(y); 27 | println!("dx1 == {}", context.get_derivative(x1)); 28 | println!("dx2 == {}", context.get_derivative(x2)); 29 | } 30 | 31 | // Output 32 | // y == 9.5 33 | // dx1 == 3 34 | // dx2 == 1.5 35 | // 36 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // This Source Code Form is subject to the terms of the Mozilla Public 2 | // License, v. 2.0. If a copy of the MPL was not distributed with this 3 | // file, You can obtain one at http://mozilla.org/MPL/2.0/. 4 | 5 | // ! 6 | // # Autograd 7 | // 8 | 9 | // TODO Using expression template will give better performance. 10 | // e.g., http://www.met.reading.ac.uk/clouds/adept/ 11 | 12 | // TODO Multi-threading support. 13 | 14 | // TODO Add Valgrind, ASAN, and TSAN tests. 15 | 16 | // TODO Respect Rust column wrapping guide. 17 | // https://github. 18 | // com/rust-lang/rust-guidelines/blob/master/style/whitespace.md 19 | 20 | #![crate_name = "autograd"] 21 | #![crate_type = "rlib"] 22 | 23 | #![feature(thread_local)] 24 | 25 | extern crate num; 26 | 27 | mod context; 28 | mod float; 29 | 30 | pub use context::Context; 31 | // TODO ideally Private traits shouldn't be exported.
32 | pub use context::ContextCratePrivate; 33 | pub use context::ContextModulePrivate; 34 | 35 | pub use float::Float; 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Discontinued. The author switched to TensorFlow for personal projects.** 2 | 3 | # Autograd [![Build Status](https://travis-ci.org/kkimdev/autograd.svg?branch=master)](https://travis-ci.org/kkimdev/autograd) 4 | Rust automatic differentiation library to compute gradient values. It is mainly for nonlinear optimization and machine learning. **It is still in the alpha stage.** 5 | 6 | ## Example 7 | ~~~rust 8 | #![feature(thread_local)] 9 | #![feature(std_misc)] 10 | #![feature(alloc)] 11 | 12 | #[macro_use] 13 | extern crate autograd; 14 | 15 | use autograd::Context; 16 | 17 | fn main() { 18 | // Initialize Autograd context with type f32 and capacity 100. 19 | let context = new_autograd_context!(f32, 100); 20 | 21 | // Initializes input variables. 22 | let x1 = context.new_variable(1.5); 23 | let x2 = context.new_variable(2.0); 24 | 25 | // Computes a math expression. 26 | let y = (x1 * x2) + x1 + 5.0; 27 | println!("y == {}", y.value); 28 | 29 | // Computes gradient with respect to y. 30 | context.differentiate(y); 31 | println!("dx1 == {}", context.get_derivative(x1)); 32 | println!("dx2 == {}", context.get_derivative(x2)); 33 | } 34 | 35 | /* Output 36 | y == 9.5 37 | dx1 == 3 38 | dx2 == 1.5 39 | */ 40 | 41 | ~~~ 42 | -------------------------------------------------------------------------------- /tests/unary_operations.rs: -------------------------------------------------------------------------------- 1 | /* Any copyright is dedicated to the Public Domain.
2 | * http://creativecommons.org/publicdomain/zero/1.0/ */ 3 | 4 | #![feature(thread_local)] 5 | #![feature(std_misc)] 6 | #![feature(alloc)] 7 | 8 | extern crate num; 9 | #[macro_use] 10 | extern crate autograd; 11 | 12 | use autograd::Context; 13 | use num::Float; 14 | use std::ops::Neg; 15 | 16 | macro_rules! unary_operation_test { 17 | ($name:ident, $x:expr, $y:expr, $dx:expr) => ( 18 | #[test] 19 | fn $name() { 20 | let context = new_autograd_context!(f32, 1000); 21 | let x = context.new_variable($x); 22 | 23 | let y = x.$name(); 24 | assert_eq!(y.value, $y); 25 | 26 | context.differentiate(y); 27 | assert_eq!(context.get_derivative(x), $dx); 28 | assert_eq!(context.get_derivative(y), 1.); 29 | } 30 | ) 31 | } 32 | 33 | // TODO What if we want multiple tests for a unary operation? 34 | // We can't use concat_idents! for function name yet though. 35 | // https://github.com/rust-lang/rust/issues/12249 36 | unary_operation_test!(cos, 0., 1., 0.); 37 | unary_operation_test!(neg, 1.5, -1.5, -1.); 38 | unary_operation_test!(sqrt, 16., 4., 0.125); 39 | unary_operation_test!(exp, 0., 1., 1.); 40 | -------------------------------------------------------------------------------- /tests/binary_operations.rs: -------------------------------------------------------------------------------- 1 | /* Any copyright is dedicated to the Public Domain. 2 | * http://creativecommons.org/publicdomain/zero/1.0/ */ 3 | 4 | #![feature(thread_local)] 5 | #![feature(std_misc)] 6 | #![feature(alloc)] 7 | 8 | #[macro_use] 9 | extern crate autograd; 10 | 11 | use autograd::Context; 12 | use std::ops::Add; 13 | use std::ops::Mul; 14 | use std::ops::Div; 15 | 16 | // TODO Write binary operation tests with constants involved. 17 | // There is a Rust bug regarding this though. https://github.com/rust-lang/rust/issues/19035 18 | macro_rules!
binary_operation_test { 19 | ($name:ident, $x1:expr, $x2:expr, $y:expr, $dx1:expr, $dx2:expr) => ( 20 | #[test] 21 | fn $name() { 22 | let context = new_autograd_context!(f32, 1000); 23 | let x1 = context.new_variable($x1); 24 | let x2 = context.new_variable($x2); 25 | 26 | let y = x1.$name(x2); 27 | assert_eq!(y.value, $y); 28 | 29 | context.differentiate(y); 30 | assert_eq!(context.get_derivative(x1), $dx1); 31 | assert_eq!(context.get_derivative(x2), $dx2); 32 | assert_eq!(context.get_derivative(y), 1.); 33 | } 34 | ) 35 | } 36 | 37 | binary_operation_test!(add, 1.5, 2.5, 4., 1., 1.); 38 | binary_operation_test!(mul, 1.5, 2.5, 3.75, 2.5, 1.5); 39 | binary_operation_test!(div, 2.0, 1.0, 2.0, 1.0, -2.0); // d(x1/x2)/dx1 = 1/x2 = 1, d(x1/x2)/dx2 = -x1/x2^2 = -2 40 | -------------------------------------------------------------------------------- /benches/benches.rs: -------------------------------------------------------------------------------- 1 | /* Any copyright is dedicated to the Public Domain. 2 | * http://creativecommons.org/publicdomain/zero/1.0/ */ 3 | 4 | #![feature(thread_local)] 5 | #![feature(std_misc)] 6 | #![feature(alloc)] 7 | #![feature(test)] 8 | #![feature(convert)] 9 | 10 | extern crate num; 11 | extern crate test; 12 | #[macro_use] 13 | extern crate autograd; 14 | 15 | static BENCH_SIZE : usize = 1024; 16 | 17 | fn norm(inputs: &[T]) -> T where T : num::Float { 18 | let mut inputs_iter = inputs.iter(); 19 | let first_input = *inputs_iter.next().unwrap(); 20 | let mut sum = first_input * first_input; 21 | for &input in inputs_iter { 22 | sum = sum + (input * input); 23 | } 24 | sum.sqrt() 25 | } 26 | 27 | #[bench] 28 | fn compute_norm_f32(bencher: &mut test::Bencher) { 29 | bencher.iter(|| { 30 | use num::Float; 31 | 32 | let mut x = Vec::::with_capacity(BENCH_SIZE); 33 | 34 | for _ in (0..BENCH_SIZE) { 35 | x.push(2.0); 36 | } 37 | 38 | let y = norm(x.as_slice()); 39 | assert_eq!(y, 2.0 * (BENCH_SIZE as f32).sqrt()); 40 | y 41 | }); 42 | } 43 | 44 | #[bench] 45 | fn
compute_norm_autograd_f32(bencher: &mut test::Bencher) { 46 | bencher.iter(|| { 47 | use num::Float; 48 | use autograd::Context; 49 | 50 | let context = new_autograd_context!(f32, 8 * BENCH_SIZE); 51 | let mut x = Vec::with_capacity(BENCH_SIZE); 52 | 53 | for _ in (0..BENCH_SIZE) { 54 | x.push(context.new_variable(2.0)); 55 | } 56 | 57 | let y = norm(x.as_slice()); 58 | assert_eq!(y.value, 2.0 * (BENCH_SIZE as f32).sqrt()); 59 | 60 | context.differentiate(y); 61 | context.get_derivative(x[0]) 62 | }); 63 | } 64 | -------------------------------------------------------------------------------- /tests/unorganized.rs: -------------------------------------------------------------------------------- 1 | /* Any copyright is dedicated to the Public Domain. 2 | * http://creativecommons.org/publicdomain/zero/1.0/ */ 3 | 4 | #![feature(thread_local)] 5 | #![feature(std_misc)] 6 | #![feature(alloc)] 7 | 8 | #[macro_use] 9 | extern crate autograd; 10 | 11 | use autograd::Context; 12 | 13 | #[test] 14 | fn single_thread_multiple_run() { 15 | for _ in 0..10 { 16 | // TODO make it longer to run. 17 | let context = new_autograd_context!(f32, 1000); 18 | let x1 = context.new_variable(1.5); 19 | let x2 = context.new_variable(2.5); 20 | 21 | let y = x1 * x2; 22 | assert_eq!(y.value, 3.75); 23 | 24 | context.differentiate(y); 25 | assert_eq!(context.get_derivative(x1), 2.5); 26 | assert_eq!(context.get_derivative(x2), 1.5); 27 | assert_eq!(context.get_derivative(y), 1.); 28 | } 29 | } 30 | 31 | // TODO std::sync::TaskPool is removed. https://github.com/rust-lang/rust/pull/22783 32 | // Find a way to re-enable this test. 33 | 34 | // #[test] 35 | // fn multi_thread_multiple_run() { 36 | // // TODO Insert purposeful sleep for each task?
37 | // let task_pool = std::sync::TaskPool::new(10); 38 | // let semaphore = std::sync::Arc::new(std::sync::Semaphore::new(-999)); 39 | // for _ in 0..1000 { 40 | // let semaphore_task = semaphore.clone(); 41 | // task_pool.execute(move || { 42 | // single_thread_multiple_run(); 43 | // semaphore_task.release(); 44 | // }); 45 | // } 46 | // semaphore.access(); 47 | // } 48 | 49 | #[test] 50 | #[should_panic(expected = "This Context instance is in use now. Note that a Context instance is allowed per construction location and per thread. Consequently, it cannot be recursively constructed unless it is destructed. This is a limitation caused by the thread local static variables usages in the current implementation.")] 51 | #[allow(unused_variables)] 52 | fn context_lock() { 53 | fn recursive_context(n: usize) { 54 | if n > 0 { 55 | let context = new_autograd_context!(f32, 10); 56 | recursive_context(n - 1); 57 | } 58 | } 59 | recursive_context(2); 60 | } 61 | 62 | #[test] 63 | #[allow(unused_variables)] 64 | fn context_capacity_expression_argument() { 65 | let a = 200; 66 | let b = 10; 67 | let context = new_autograd_context!(f32, a/b); 68 | } 69 | 70 | #[test] 71 | fn capacity_full() { 72 | let context = new_autograd_context!(f32, 100); 73 | for _ in 0..99 { 74 | context.new_variable(1.); 75 | } 76 | let y = context.new_variable(1.); 77 | context.differentiate(y) 78 | } 79 | 80 | // TODO Test capacity overflow check at context.differentiate(...). But it will be double-panic since it's also panicking at Drop. 81 | 82 | #[test] 83 | #[should_panic(expected = "There are more recorded variables, 101, than its capacity, 100. Memory is corrupted.
Please consider using bigger capacity.")] 84 | fn capacity_overflow_drop() { 85 | let context = new_autograd_context!(f32, 100); 86 | for _ in 0..101 { 87 | context.new_variable(1.); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/context.rs: -------------------------------------------------------------------------------- 1 | // This Source Code Form is subject to the terms of the Mozilla Public 2 | // License, v. 2.0. If a copy of the MPL was not distributed with this 3 | // file, You can obtain one at http://mozilla.org/MPL/2.0/. 4 | 5 | //! Recording context for autograd: it stores every recorded operation entry (adjoint, lhs index, rhs index) and computes derivatives by accumulating those entries in reverse order. 6 | 7 | use num; 8 | use std; 9 | 10 | 11 | // TODO #[inline] where appropriate. 12 | 13 | pub trait Context: ContextCratePrivate + std::marker::Sized where InternalFloat: num::Float { 14 | // public functions 15 | 16 | fn new_variable(&self, value: InternalFloat) -> super::float::Float { 17 | use super::float::FloatCratePrivate; 18 | 19 | FloatCratePrivate::new(value, Self::get_new_variable_index()) 20 | } 21 | 22 | fn differentiate(&self, float: super::float::Float) { 23 | use super::float::FloatCratePrivate; 24 | 25 | // TODO The current implementation is not performant and dirty. 26 | unsafe { 27 | assert!(*Self::get_recorded_variables_count() <= self.capacity(), 28 | "There are more recorded variables, {}, than its capacity, {}. Memory is \ 29 | corrupted.
Please consider using bigger capacity.", 30 | *Self::get_recorded_variables_count(), 31 | self.capacity()); 32 | 33 | for i in (0..(*Self::get_recorded_variables_count())) { 34 | *Self::get_result_derivatives().offset(i as isize) = num::traits::Zero::zero(); 35 | } 36 | 37 | *Self::get_result_derivatives().offset(float.float_get_index() as isize) = 38 | num::traits::One::one(); 39 | for i in (0..(*Self::get_recorded_entries_count())).rev() { 40 | let lhs_index = *Self::get_lhs_indices().offset(i as isize); 41 | let rhs_index = *Self::get_rhs_indices().offset(i as isize); 42 | *Self::get_result_derivatives().offset(rhs_index as isize) = 43 | *Self::get_result_derivatives().offset(rhs_index as isize) + 44 | (*Self::get_result_derivatives().offset(lhs_index as isize) * 45 | *Self::get_adjoints().offset(i as isize)); 46 | } 47 | } 48 | } 49 | 50 | fn get_derivative(&self, float: super::float::Float) -> InternalFloat { 51 | use super::float::FloatCratePrivate; 52 | 53 | let float_index_offset = float.float_get_index() as isize; 54 | unsafe { *Self::get_result_derivatives().offset(float_index_offset) } 55 | } 56 | } 57 | 58 | pub trait ContextCratePrivate: ContextModulePrivate where InternalFloat: num::Float { 59 | fn unary_operation(adjoint: InternalFloat, rhs_index: usize) -> usize { 60 | let lhs_index = Self::get_new_variable_index(); 61 | let recorded_entries_count_offset = Self::get_new_entry_index() as isize; 62 | unsafe { 63 | *Self::get_adjoints().offset(recorded_entries_count_offset) = adjoint; 64 | *Self::get_lhs_indices().offset(recorded_entries_count_offset) = lhs_index; 65 | *Self::get_rhs_indices().offset(recorded_entries_count_offset) = rhs_index; 66 | } 67 | lhs_index 68 | } 69 | 70 | fn binary_operation(adjoints: &[InternalFloat; 2], rhs_indices: &[usize; 2]) -> usize { 71 | let lhs_index = Self::get_new_variable_index(); 72 | let recorded_entries_count_offset_1 = Self::get_new_entry_index() as isize; 73 | let recorded_entries_count_offset_2 =
Self::get_new_entry_index() as isize; 74 | 75 | unsafe { 76 | // Each binary operation records two entries for one new variable, which is why the entry arrays are allocated for 2 * capacity entries. TODO is indexing inefficient? 77 | *Self::get_adjoints().offset(recorded_entries_count_offset_1) = adjoints[0]; 78 | *Self::get_lhs_indices().offset(recorded_entries_count_offset_1) = lhs_index; 79 | *Self::get_rhs_indices().offset(recorded_entries_count_offset_1) = rhs_indices[0]; 80 | 81 | *Self::get_adjoints().offset(recorded_entries_count_offset_2) = adjoints[1]; 82 | *Self::get_lhs_indices().offset(recorded_entries_count_offset_2) = lhs_index; 83 | *Self::get_rhs_indices().offset(recorded_entries_count_offset_2) = rhs_indices[1]; 84 | } 85 | lhs_index 86 | } 87 | } 88 | 89 | pub trait ContextModulePrivate where InternalFloat: num::Float { 90 | fn get_new_variable_index() -> usize { 91 | let count = Self::get_recorded_variables_count(); 92 | let index = *count; 93 | *count += 1; 94 | index 95 | } 96 | 97 | fn get_new_entry_index() -> usize { 98 | let count = Self::get_recorded_entries_count(); 99 | let index = *count; 100 | *count += 1; 101 | index 102 | } 103 | 104 | fn capacity(&self) -> usize; 105 | 106 | fn get_recorded_variables_count() -> &'static mut usize; 107 | fn get_recorded_entries_count() -> &'static mut usize; 108 | // TODO use 'static lifetime instead? 109 | fn get_adjoints<'a>() -> &'a mut *mut InternalFloat; 110 | fn get_lhs_indices<'a>() -> &'a mut *mut usize; 111 | fn get_rhs_indices<'a>() -> &'a mut *mut usize; 112 | fn get_result_derivatives<'a>() -> &'a mut *mut InternalFloat; 113 | } 114 | 115 | #[macro_export] 116 | macro_rules!
new_autograd_context { 117 | ($InternalFloat:ty, $capacity:expr) => ( 118 | { 119 | struct ContextImpl { 120 | capacity: usize, 121 | _mutex_guard: std::sync::MutexGuard<'static, ()>, 122 | } 123 | 124 | impl $crate::Context<$InternalFloat> for ContextImpl { 125 | } 126 | 127 | impl $crate::ContextCratePrivate<$InternalFloat> for ContextImpl { 128 | } 129 | 130 | impl $crate::ContextModulePrivate<$InternalFloat> for ContextImpl { 131 | fn capacity(&self) -> usize { 132 | self.capacity 133 | } 134 | 135 | fn get_recorded_variables_count() -> &'static mut usize { 136 | #[thread_local] 137 | static mut ptr : usize = 0; 138 | unsafe { 139 | &mut ptr 140 | } 141 | } 142 | fn get_recorded_entries_count() -> &'static mut usize { 143 | #[thread_local] 144 | static mut ptr : usize = 0; 145 | unsafe { 146 | &mut ptr 147 | } 148 | } 149 | fn get_adjoints<'a>() -> &'a mut*mut $InternalFloat { 150 | #[thread_local] 151 | static mut ptr : *mut $InternalFloat = 0 as *mut $InternalFloat; 152 | unsafe { 153 | &mut ptr 154 | } 155 | } 156 | fn get_lhs_indices<'a>() -> &'a mut*mut usize { 157 | #[thread_local] 158 | static mut ptr : *mut usize = 0 as *mut usize; 159 | unsafe { 160 | &mut ptr 161 | } 162 | } 163 | fn get_rhs_indices<'a>() -> &'a mut*mut usize{ 164 | #[thread_local] 165 | static mut ptr : *mut usize = 0 as *mut usize; 166 | unsafe { 167 | &mut ptr 168 | } 169 | } 170 | 171 | fn get_result_derivatives<'a>() -> &'a mut*mut $InternalFloat { 172 | #[thread_local] 173 | static mut ptr : *mut $InternalFloat = 0 as *mut $InternalFloat; 174 | unsafe { 175 | &mut ptr 176 | } 177 | } 178 | } 179 | 180 | impl ContextImpl { 181 | fn new(capacity: usize) -> Self { 182 | let context; 183 | 184 | #[thread_local] 185 | static _MUTEX : std::sync::StaticMutex = std::sync::MUTEX_INIT; 186 | match _MUTEX.try_lock() { 187 | Ok(mutex_guard) => context = ContextImpl{capacity: capacity, _mutex_guard: mutex_guard}, 188 | Err(std::sync::TryLockError::WouldBlock) => panic!("This Context instance
is in use now. Note that a Context instance is allowed per construction location and per thread. Consequently, it cannot be recursively constructed unless it is destructed. This is a limitation caused by the thread local static variables usages in the current implementation."), 189 | Err(std::sync::TryLockError::Poisoned(poison_error)) => panic!("{:?}", poison_error), 190 | } 191 | 192 | // Adjoints and lhs/rhs indices are indexed per recorded entry, and a binary operation records two entries per new variable, so those arrays hold 2 * capacity entries; result derivatives are indexed per variable and hold capacity. 193 | // TODO use checked_mul, computed before `context` is created (its Drop frees these buffers), e.g. capacity.checked_mul(2).and_then(|n| n.checked_mul(std::mem::size_of::())).expect("capacity overflow"); 194 | let usize_size = 2 * capacity * std::mem::size_of::(); 195 | let t_size = capacity * std::mem::size_of::<$InternalFloat>(); 196 | // std::rt::heap::allocate(t_size, mem::min_align_of::()) 197 | 198 | unsafe { 199 | use $crate::ContextModulePrivate; 200 | 201 | *Self::get_recorded_variables_count() = 0; 202 | *Self::get_recorded_entries_count() = 0; 203 | *Self::get_adjoints() = std::rt::heap::allocate(2 * t_size, std::mem::align_of::<$InternalFloat>()) as *mut $InternalFloat; 204 | *Self::get_lhs_indices() = std::rt::heap::allocate(usize_size, std::mem::align_of::()) as *mut usize; 205 | *Self::get_rhs_indices() = std::rt::heap::allocate(usize_size, std::mem::align_of::()) as *mut usize; 206 | // TODO we don't have to allocate get_result_derivatives now, isn't it? 207 | *Self::get_result_derivatives() = std::rt::heap::allocate(t_size, std::mem::align_of::<$InternalFloat>()) as *mut $InternalFloat; 208 | } 209 | context 210 | } 211 | } 212 | 213 | impl std::ops::Drop for ContextImpl { 214 | fn drop(&mut self) { 215 | use $crate::ContextModulePrivate; 216 | 217 | // TODO Ideally we should detect overflow at the time it overflows. Similar to stack overflow problem. mprotect?
218 | assert!(*Self::get_recorded_variables_count() <= (ContextModulePrivate::<$InternalFloat>::capacity(self)), 219 | "There are more recorded variables, {}, than its capacity, {}. Memory is corrupted.
Please consider using bigger capacity.", 220 | *Self::get_recorded_variables_count(), (ContextModulePrivate::<$InternalFloat>::capacity(self))); 221 | 222 | // TODO Do we want to reset these? 223 | // *Context::<$InternalFloat>::get_recorded_variables_count(None::) = 0; 224 | // *Context::<$InternalFloat>::get_adjoints(None::) = 0; 225 | // Sizes must match the allocations in new(): the per-entry arrays hold 2 * capacity entries. 226 | let usize_size = 2 * self.capacity * std::mem::size_of::(); 227 | let t_size = self.capacity * std::mem::size_of::<$InternalFloat>(); 228 | 229 | unsafe { 230 | std::rt::heap::deallocate(*Self::get_adjoints() as *mut u8, 2 * t_size, std::mem::align_of::<$InternalFloat>()); 231 | std::rt::heap::deallocate(*Self::get_lhs_indices() as *mut u8, usize_size, std::mem::align_of::()); 232 | std::rt::heap::deallocate(*Self::get_rhs_indices() as *mut u8, usize_size, std::mem::align_of::()); 233 | std::rt::heap::deallocate(*Self::get_result_derivatives() as *mut u8, t_size, std::mem::align_of::<$InternalFloat>()); 234 | } 235 | } 236 | } 237 | 238 | ContextImpl::new($capacity) 239 | } 240 | ) 241 | } 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License, version 2.0 2 | 3 | 1. Definitions 4 | 5 | 1.1. "Contributor" 6 | 7 | means each individual or legal entity that creates, contributes to the 8 | creation of, or owns Covered Software. 9 | 10 | 1.2. "Contributor Version" 11 | 12 | means the combination of the Contributions of others (if any) used by a 13 | Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | 17 | means Covered Software of a particular Contributor. 18 | 19 | 1.4. "Covered Software" 20 | 21 | means Source Code Form to which the initial Contributor has attached the 22 | notice in Exhibit A, the Executable Form of such Source Code Form, and 23 | Modifications of such Source Code Form, in each case including portions 24 | thereof.
25 | 26 | 1.5. "Incompatible With Secondary Licenses" 27 | means 28 | 29 | a. that the initial Contributor has attached the notice described in 30 | Exhibit B to the Covered Software; or 31 | 32 | b. that the Covered Software was made available under the terms of 33 | version 1.1 or earlier of the License, but not also under the terms of 34 | a Secondary License. 35 | 36 | 1.6. "Executable Form" 37 | 38 | means any form of the work other than Source Code Form. 39 | 40 | 1.7. "Larger Work" 41 | 42 | means a work that combines Covered Software with other material, in a 43 | separate file or files, that is not Covered Software. 44 | 45 | 1.8. "License" 46 | 47 | means this document. 48 | 49 | 1.9. "Licensable" 50 | 51 | means having the right to grant, to the maximum extent possible, whether 52 | at the time of the initial grant or subsequently, any and all of the 53 | rights conveyed by this License. 54 | 55 | 1.10. "Modifications" 56 | 57 | means any of the following: 58 | 59 | a. any file in Source Code Form that results from an addition to, 60 | deletion from, or modification of the contents of Covered Software; or 61 | 62 | b. any new file in Source Code Form that contains any Covered Software. 63 | 64 | 1.11. "Patent Claims" of a Contributor 65 | 66 | means any patent claim(s), including without limitation, method, 67 | process, and apparatus claims, in any patent Licensable by such 68 | Contributor that would be infringed, but for the grant of the License, 69 | by the making, using, selling, offering for sale, having made, import, 70 | or transfer of either its Contributions or its Contributor Version. 71 | 72 | 1.12. "Secondary License" 73 | 74 | means either the GNU General Public License, Version 2.0, the GNU Lesser 75 | General Public License, Version 2.1, the GNU Affero General Public 76 | License, Version 3.0, or any later versions of those licenses. 77 | 78 | 1.13. 
"Source Code Form" 79 | 80 | means the form of the work preferred for making modifications. 81 | 82 | 1.14. "You" (or "Your") 83 | 84 | means an individual or a legal entity exercising rights under this 85 | License. For legal entities, "You" includes any entity that controls, is 86 | controlled by, or is under common control with You. For purposes of this 87 | definition, "control" means (a) the power, direct or indirect, to cause 88 | the direction or management of such entity, whether by contract or 89 | otherwise, or (b) ownership of more than fifty percent (50%) of the 90 | outstanding shares or beneficial ownership of such entity. 91 | 92 | 93 | 2. License Grants and Conditions 94 | 95 | 2.1. Grants 96 | 97 | Each Contributor hereby grants You a world-wide, royalty-free, 98 | non-exclusive license: 99 | 100 | a. under intellectual property rights (other than patent or trademark) 101 | Licensable by such Contributor to use, reproduce, make available, 102 | modify, display, perform, distribute, and otherwise exploit its 103 | Contributions, either on an unmodified basis, with Modifications, or 104 | as part of a Larger Work; and 105 | 106 | b. under Patent Claims of such Contributor to make, use, sell, offer for 107 | sale, have made, import, and otherwise transfer either its 108 | Contributions or its Contributor Version. 109 | 110 | 2.2. Effective Date 111 | 112 | The licenses granted in Section 2.1 with respect to any Contribution 113 | become effective for each Contribution on the date the Contributor first 114 | distributes such Contribution. 115 | 116 | 2.3. Limitations on Grant Scope 117 | 118 | The licenses granted in this Section 2 are the only rights granted under 119 | this License. No additional rights or licenses will be implied from the 120 | distribution or licensing of Covered Software under this License. 121 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 122 | Contributor: 123 | 124 | a. 
for any code that a Contributor has removed from Covered Software; or 125 | 126 | b. for infringements caused by: (i) Your and any other third party's 127 | modifications of Covered Software, or (ii) the combination of its 128 | Contributions with other software (except as part of its Contributor 129 | Version); or 130 | 131 | c. under Patent Claims infringed by Covered Software in the absence of 132 | its Contributions. 133 | 134 | This License does not grant any rights in the trademarks, service marks, 135 | or logos of any Contributor (except as may be necessary to comply with 136 | the notice requirements in Section 3.4). 137 | 138 | 2.4. Subsequent Licenses 139 | 140 | No Contributor makes additional grants as a result of Your choice to 141 | distribute the Covered Software under a subsequent version of this 142 | License (see Section 10.2) or under the terms of a Secondary License (if 143 | permitted under the terms of Section 3.3). 144 | 145 | 2.5. Representation 146 | 147 | Each Contributor represents that the Contributor believes its 148 | Contributions are its original creation(s) or it has sufficient rights to 149 | grant the rights to its Contributions conveyed by this License. 150 | 151 | 2.6. Fair Use 152 | 153 | This License is not intended to limit any rights You have under 154 | applicable copyright doctrines of fair use, fair dealing, or other 155 | equivalents. 156 | 157 | 2.7. Conditions 158 | 159 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in 160 | Section 2.1. 161 | 162 | 163 | 3. Responsibilities 164 | 165 | 3.1. Distribution of Source Form 166 | 167 | All distribution of Covered Software in Source Code Form, including any 168 | Modifications that You create or to which You contribute, must be under 169 | the terms of this License. You must inform recipients that the Source 170 | Code Form of the Covered Software is governed by the terms of this 171 | License, and how they can obtain a copy of this License. 
You may not 172 | attempt to alter or restrict the recipients' rights in the Source Code 173 | Form. 174 | 175 | 3.2. Distribution of Executable Form 176 | 177 | If You distribute Covered Software in Executable Form then: 178 | 179 | a. such Covered Software must also be made available in Source Code Form, 180 | as described in Section 3.1, and You must inform recipients of the 181 | Executable Form how they can obtain a copy of such Source Code Form by 182 | reasonable means in a timely manner, at a charge no more than the cost 183 | of distribution to the recipient; and 184 | 185 | b. You may distribute such Executable Form under the terms of this 186 | License, or sublicense it under different terms, provided that the 187 | license for the Executable Form does not attempt to limit or alter the 188 | recipients' rights in the Source Code Form under this License. 189 | 190 | 3.3. Distribution of a Larger Work 191 | 192 | You may create and distribute a Larger Work under terms of Your choice, 193 | provided that You also comply with the requirements of this License for 194 | the Covered Software. If the Larger Work is a combination of Covered 195 | Software with a work governed by one or more Secondary Licenses, and the 196 | Covered Software is not Incompatible With Secondary Licenses, this 197 | License permits You to additionally distribute such Covered Software 198 | under the terms of such Secondary License(s), so that the recipient of 199 | the Larger Work may, at their option, further distribute the Covered 200 | Software under the terms of either this License or such Secondary 201 | License(s). 202 | 203 | 3.4. 
Notices 204 | 205 | You may not remove or alter the substance of any license notices 206 | (including copyright notices, patent notices, disclaimers of warranty, or 207 | limitations of liability) contained within the Source Code Form of the 208 | Covered Software, except that You may alter any license notices to the 209 | extent required to remedy known factual inaccuracies. 210 | 211 | 3.5. Application of Additional Terms 212 | 213 | You may choose to offer, and to charge a fee for, warranty, support, 214 | indemnity or liability obligations to one or more recipients of Covered 215 | Software. However, You may do so only on Your own behalf, and not on 216 | behalf of any Contributor. You must make it absolutely clear that any 217 | such warranty, support, indemnity, or liability obligation is offered by 218 | You alone, and You hereby agree to indemnify every Contributor for any 219 | liability incurred by such Contributor as a result of warranty, support, 220 | indemnity or liability terms You offer. You may include additional 221 | disclaimers of warranty and limitations of liability specific to any 222 | jurisdiction. 223 | 224 | 4. Inability to Comply Due to Statute or Regulation 225 | 226 | If it is impossible for You to comply with any of the terms of this License 227 | with respect to some or all of the Covered Software due to statute, 228 | judicial order, or regulation then You must: (a) comply with the terms of 229 | this License to the maximum extent possible; and (b) describe the 230 | limitations and the code they affect. Such description must be placed in a 231 | text file included with all distributions of the Covered Software under 232 | this License. Except to the extent prohibited by statute or regulation, 233 | such description must be sufficiently detailed for a recipient of ordinary 234 | skill to be able to understand it. 235 | 236 | 5. Termination 237 | 238 | 5.1. 
The rights granted under this License will terminate automatically if You 239 | fail to comply with any of its terms. However, if You become compliant, 240 | then the rights granted under this License from a particular Contributor 241 | are reinstated (a) provisionally, unless and until such Contributor 242 | explicitly and finally terminates Your grants, and (b) on an ongoing 243 | basis, if such Contributor fails to notify You of the non-compliance by 244 | some reasonable means prior to 60 days after You have come back into 245 | compliance. Moreover, Your grants from a particular Contributor are 246 | reinstated on an ongoing basis if such Contributor notifies You of the 247 | non-compliance by some reasonable means, this is the first time You have 248 | received notice of non-compliance with this License from such 249 | Contributor, and You become compliant prior to 30 days after Your receipt 250 | of the notice. 251 | 252 | 5.2. If You initiate litigation against any entity by asserting a patent 253 | infringement claim (excluding declaratory judgment actions, 254 | counter-claims, and cross-claims) alleging that a Contributor Version 255 | directly or indirectly infringes any patent, then the rights granted to 256 | You by any and all Contributors for the Covered Software under Section 257 | 2.1 of this License shall terminate. 258 | 259 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user 260 | license agreements (excluding distributors and resellers) which have been 261 | validly granted by You or Your distributors under this License prior to 262 | termination shall survive termination. 263 | 264 | 6. 
Disclaimer of Warranty 265 | 266 | Covered Software is provided under this License on an "as is" basis, 267 | without warranty of any kind, either expressed, implied, or statutory, 268 | including, without limitation, warranties that the Covered Software is free 269 | of defects, merchantable, fit for a particular purpose or non-infringing. 270 | The entire risk as to the quality and performance of the Covered Software 271 | is with You. Should any Covered Software prove defective in any respect, 272 | You (not any Contributor) assume the cost of any necessary servicing, 273 | repair, or correction. This disclaimer of warranty constitutes an essential 274 | part of this License. No use of any Covered Software is authorized under 275 | this License except under this disclaimer. 276 | 277 | 7. Limitation of Liability 278 | 279 | Under no circumstances and under no legal theory, whether tort (including 280 | negligence), contract, or otherwise, shall any Contributor, or anyone who 281 | distributes Covered Software as permitted above, be liable to You for any 282 | direct, indirect, special, incidental, or consequential damages of any 283 | character including, without limitation, damages for lost profits, loss of 284 | goodwill, work stoppage, computer failure or malfunction, or any and all 285 | other commercial damages or losses, even if such party shall have been 286 | informed of the possibility of such damages. This limitation of liability 287 | shall not apply to liability for death or personal injury resulting from 288 | such party's negligence to the extent applicable law prohibits such 289 | limitation. Some jurisdictions do not allow the exclusion or limitation of 290 | incidental or consequential damages, so this exclusion and limitation may 291 | not apply to You. 292 | 293 | 8. 
Litigation 294 | 295 | Any litigation relating to this License may be brought only in the courts 296 | of a jurisdiction where the defendant maintains its principal place of 297 | business and such litigation shall be governed by laws of that 298 | jurisdiction, without reference to its conflict-of-law provisions. Nothing 299 | in this Section shall prevent a party's ability to bring cross-claims or 300 | counter-claims. 301 | 302 | 9. Miscellaneous 303 | 304 | This License represents the complete agreement concerning the subject 305 | matter hereof. If any provision of this License is held to be 306 | unenforceable, such provision shall be reformed only to the extent 307 | necessary to make it enforceable. Any law or regulation which provides that 308 | the language of a contract shall be construed against the drafter shall not 309 | be used to construe this License against a Contributor. 310 | 311 | 312 | 10. Versions of the License 313 | 314 | 10.1. New Versions 315 | 316 | Mozilla Foundation is the license steward. Except as provided in Section 317 | 10.3, no one other than the license steward has the right to modify or 318 | publish new versions of this License. Each version will be given a 319 | distinguishing version number. 320 | 321 | 10.2. Effect of New Versions 322 | 323 | You may distribute the Covered Software under the terms of the version 324 | of the License under which You originally received the Covered Software, 325 | or under the terms of any subsequent version published by the license 326 | steward. 327 | 328 | 10.3. Modified Versions 329 | 330 | If you create software not governed by this License, and you want to 331 | create a new license for such software, you may create and use a 332 | modified version of this License if you rename the license and remove 333 | any references to the name of the license steward (except to note that 334 | such modified license differs from this License). 335 | 336 | 10.4. 
Distributing Source Code Form that is Incompatible With Secondary 337 | Licenses If You choose to distribute Source Code Form that is 338 | Incompatible With Secondary Licenses under the terms of this version of 339 | the License, the notice described in Exhibit B of this License must be 340 | attached. 341 | 342 | Exhibit A - Source Code Form License Notice 343 | 344 | This Source Code Form is subject to the 345 | terms of the Mozilla Public License, v. 346 | 2.0. If a copy of the MPL was not 347 | distributed with this file, You can 348 | obtain one at 349 | http://mozilla.org/MPL/2.0/. 350 | 351 | If it is not possible or desirable to put the notice in a particular file, 352 | then You may include the notice in a location (such as a LICENSE file in a 353 | relevant directory) where a recipient would be likely to look for such a 354 | notice. 355 | 356 | You may add additional accurate notices of copyright ownership. 357 | 358 | Exhibit B - "Incompatible With Secondary Licenses" Notice 359 | 360 | This Source Code Form is "Incompatible 361 | With Secondary Licenses", as defined by 362 | the Mozilla Public License, v. 2.0. 363 | 364 | -------------------------------------------------------------------------------- /src/float.rs: -------------------------------------------------------------------------------- 1 | // This Source Code Form is subject to the terms of the Mozilla Public 2 | // License, v. 2.0. If a copy of the MPL was not distributed with this 3 | // file, You can obtain one at http://mozilla.org/MPL/2.0/. 4 | 5 | use num; 6 | use std; 7 | 8 | // TODO derive Hash? 
9 | #[derive(Debug)] 10 | pub struct Float // Differentiable scalar: `value` is the primal value and `index` identifies it in context CT (operators obtain it from CT::unary_operation / CT::binary_operation). 11 | where InternalFloat: num::Float, 12 | CT: super::context::Context 13 | { 14 | pub value: InternalFloat, 15 | index: usize, 16 | phantom_context: std::marker::PhantomData, // ties the type to CT without storing a context value 17 | } 18 | 19 | // Partial num::Float implementation. sqrt, exp, sin and cos record their local derivative via CT::unary_operation; the is_*, classify and integer_decode queries read `value` directly; every other method is still unimplemented!() (TODO). 20 | impl num::Float for Float where InternalFloat: num::Float, CT: super::context::Context { 21 | fn nan() -> Self { 22 | unimplemented!(); 23 | } 24 | 25 | fn infinity() -> Self { 26 | unimplemented!(); 27 | } 28 | 29 | fn neg_infinity() -> Self { 30 | unimplemented!(); 31 | } 32 | 33 | fn neg_zero() -> Self { 34 | unimplemented!(); 35 | } 36 | 37 | fn min_value() -> Self { 38 | unimplemented!(); 39 | } 40 | 41 | fn max_value() -> Self { 42 | unimplemented!(); 43 | } 44 | 45 | fn min_positive_value() -> Self { 46 | unimplemented!(); 47 | } 48 | 49 | fn is_nan(self) -> bool { 50 | self.value.is_nan() 51 | } 52 | 53 | fn is_infinite(self) -> bool { 54 | self.value.is_infinite() 55 | } 56 | 57 | fn is_finite(self) -> bool { 58 | self.value.is_finite() 59 | } 60 | 61 | fn is_normal(self) -> bool { 62 | self.value.is_normal() 63 | } 64 | 65 | fn classify(self) -> std::num::FpCategory { 66 | self.value.classify() 67 | } 68 | 69 | fn floor(self) -> Self { 70 | unimplemented!(); 71 | } 72 | 73 | fn ceil(self) -> Self { 74 | unimplemented!(); 75 | } 76 | 77 | fn round(self) -> Self { 78 | unimplemented!(); 79 | } 80 | 81 | fn trunc(self) -> Self { 82 | unimplemented!(); 83 | } 84 | 85 | fn fract(self) -> Self { 86 | unimplemented!(); 87 | } 88 | 89 | fn abs(self) -> Self { 90 | unimplemented!(); 91 | } 92 | 93 | fn signum(self) -> Self { 94 | unimplemented!(); 95 | } 96 | 97 | fn is_sign_positive(self) -> bool { 98 | self.value.is_sign_positive() 99 | } 100 | 101 | fn is_sign_negative(self) -> bool { 102 | self.value.is_sign_negative() 103 | } 104 | 105 | #[allow(unused_variables)] 106 | fn mul_add(self, a: Self, b: Self) -> Self { 107 | unimplemented!(); 108 | } 109 | 110 | fn recip(self) -> Self { 111 |
unimplemented!(); 112 | } 113 | 114 | #[allow(unused_variables)] 115 | fn powi(self, n: i32) -> Self { 116 | unimplemented!(); 117 | } 118 | 119 | #[allow(unused_variables)] 120 | fn powf(self, n: Self) -> Self { 121 | unimplemented!(); 122 | } 123 | 124 | fn sqrt(self) -> Self { 125 | let two = InternalFloat::one() + InternalFloat::one(); 126 | let sqrt_value = self.value.sqrt(); 127 | Float { 128 | value: sqrt_value, 129 | index: CT::unary_operation((two * sqrt_value).recip(), self.index), // d/dx sqrt(x) = 1 / (2 sqrt(x)) 130 | phantom_context: std::marker::PhantomData, 131 | } 132 | } 133 | 134 | fn exp(self) -> Self { 135 | let exp_value = self.value.exp(); 136 | Float { 137 | value: exp_value, 138 | index: CT::unary_operation(exp_value, self.index), // d/dx exp(x) = exp(x) 139 | phantom_context: std::marker::PhantomData, 140 | } 141 | } 142 | 143 | fn exp2(self) -> Self { 144 | unimplemented!(); 145 | } 146 | 147 | fn ln(self) -> Self { 148 | unimplemented!(); 149 | } 150 | 151 | #[allow(unused_variables)] 152 | fn log(self, base: Self) -> Self { 153 | unimplemented!(); 154 | } 155 | 156 | fn log2(self) -> Self { 157 | unimplemented!(); 158 | } 159 | 160 | fn log10(self) -> Self { 161 | unimplemented!(); 162 | } 163 | 164 | #[allow(unused_variables)] 165 | fn max(self, other: Self) -> Self { 166 | unimplemented!(); 167 | } 168 | 169 | #[allow(unused_variables)] 170 | fn min(self, other: Self) -> Self { 171 | unimplemented!(); 172 | } 173 | 174 | #[allow(unused_variables)] 175 | fn abs_sub(self, other: Self) -> Self { 176 | unimplemented!(); 177 | } 178 | 179 | fn cbrt(self) -> Self { 180 | unimplemented!(); 181 | } 182 | 183 | #[allow(unused_variables)] 184 | fn hypot(self, other: Self) -> Self { 185 | unimplemented!(); 186 | } 187 | 188 | fn sin(self) -> Self { 189 | Float { 190 | value: self.value.sin(), 191 | index: CT::unary_operation(self.value.cos(), self.index), // d/dx sin(x) = cos(x) 192 | phantom_context: std::marker::PhantomData, 193 | } 194 | } 195 | 196 | fn cos(self) -> Self { 197 | Float { 198 | value: self.value.cos(), 199
| index: CT::unary_operation(-self.value.sin(), self.index), // d/dx cos(x) = -sin(x) 200 | phantom_context: std::marker::PhantomData, 201 | } 202 | } 203 | 204 | fn tan(self) -> Self { 205 | unimplemented!(); 206 | } 207 | 208 | fn asin(self) -> Self { 209 | unimplemented!(); 210 | } 211 | 212 | fn acos(self) -> Self { 213 | unimplemented!(); 214 | } 215 | 216 | fn atan(self) -> Self { 217 | unimplemented!(); 218 | } 219 | 220 | #[allow(unused_variables)] 221 | fn atan2(self, other: Self) -> Self { 222 | unimplemented!(); 223 | } 224 | 225 | fn sin_cos(self) -> (Self, Self) { 226 | unimplemented!(); 227 | } 228 | 229 | fn exp_m1(self) -> Self { 230 | unimplemented!(); 231 | } 232 | 233 | fn ln_1p(self) -> Self { 234 | unimplemented!(); 235 | } 236 | 237 | fn sinh(self) -> Self { 238 | unimplemented!(); 239 | } 240 | 241 | fn cosh(self) -> Self { 242 | unimplemented!(); 243 | } 244 | 245 | fn tanh(self) -> Self { 246 | unimplemented!(); 247 | } 248 | 249 | fn asinh(self) -> Self { 250 | unimplemented!(); 251 | } 252 | 253 | fn acosh(self) -> Self { 254 | unimplemented!(); 255 | } 256 | 257 | fn atanh(self) -> Self { 258 | unimplemented!(); 259 | } 260 | 261 | fn integer_decode(self) -> (u64, i16, i8) { 262 | self.value.integer_decode() 263 | } 264 | } 265 | 266 | impl num::Zero for Float where InternalFloat: num::Float, CT: super::context::Context { 267 | fn zero() -> Self { 268 | unimplemented!(); 269 | } 270 | fn is_zero(&self) -> bool { 271 | self.value.is_zero() 272 | } 273 | } 274 | 275 | impl num::One for Float where InternalFloat: num::Float, CT: super::context::Context { 276 | fn one() -> Self { 277 | unimplemented!(); 278 | } 279 | } 280 | 281 | impl num::Num for Float where InternalFloat: num::Float, CT: super::context::Context { 282 | type FromStrRadixErr = num::traits::ParseFloatError; 283 | 284 | #[allow(unused_variables)] 285 | fn from_str_radix(str: &str, radix: u32) -> Result { 286 | unimplemented!(); 287 | } 288 | } 289 | 290 | impl std::clone::Clone for Float where
InternalFloat: num::Float, CT: super::context::Context { 291 | fn clone(&self) -> Self { 292 | Float { 293 | value: self.value, 294 | index: self.index, // the clone shares the original's index in CT 295 | phantom_context: std::marker::PhantomData, 296 | } 297 | } 298 | } 299 | 300 | 301 | impl num::ToPrimitive for Float where InternalFloat: num::Float, CT: super::context::Context { 302 | fn to_i64(&self) -> Option { 303 | self.value.to_i64() 304 | } 305 | 306 | fn to_u64(&self) -> Option { 307 | self.value.to_u64() 308 | } 309 | // TODO implement the remaining optional ToPrimitive methods. 310 | } 311 | 312 | impl num::NumCast for Float where InternalFloat: num::Float, CT: super::context::Context { 313 | #[allow(unused_variables)] 314 | fn from(n: TP) -> Option 315 | where TP: num::ToPrimitive 316 | { 317 | panic!("We disallow constructing from a constant because it is unnecessary."); 318 | } 319 | } 320 | 321 | impl std::ops::Neg for Float where InternalFloat: num::Float, CT: super::context::Context { 322 | type Output = Float; 323 | fn neg(self) -> Float { 324 | Float { 325 | value: -self.value, 326 | index: CT::unary_operation(InternalFloat::one().neg(), self.index), // d/dx (-x) = -1 327 | phantom_context: std::marker::PhantomData, 328 | } 329 | } 330 | } 331 | 332 | // TODO adjoints of 1 can be optimized out, i.e., not multiplying. Should we? 333 | // TODO implement operations with underlying type.
334 | impl std::ops::Add> for Float where InternalFloat: num::Float, CT: super::context::Context { 335 | type Output = Float; 336 | fn add(self, other: Float) -> Float { 337 | Float { 338 | value: self.value + other.value, 339 | index: CT::binary_operation(&[num::One::one(), num::One::one()], 340 | &[self.index, other.index]), 341 | phantom_context: std::marker::PhantomData, 342 | } 343 | } 344 | } 345 | 346 | impl std::ops::Add for Float where InternalFloat: num::Float, CT: super::context::Context { 347 | type Output = Float; 348 | fn add(self, other: InternalFloat) -> Float { 349 | Float { 350 | value: self.value + other, 351 | index: CT::unary_operation(num::One::one(), self.index), 352 | phantom_context: std::marker::PhantomData, 353 | } 354 | } 355 | } 356 | 357 | impl std::ops::Sub> for Float where InternalFloat: num::Float, CT: super::context::Context { 358 | type Output = Float; 359 | fn sub(self, other: Float) -> Float { 360 | let one: InternalFloat = num::One::one(); 361 | Float { 362 | value: self.value - other.value, 363 | index: CT::binary_operation(&[one, -one], &[self.index, other.index]), 364 | phantom_context: std::marker::PhantomData, 365 | } 366 | } 367 | } 368 | 369 | impl std::ops::Sub for Float where InternalFloat: num::Float, CT: super::context::Context { 370 | type Output = Float; 371 | fn sub(self, other: InternalFloat) -> Float { 372 | Float { 373 | value: self.value - other, 374 | index: CT::unary_operation(num::One::one(), self.index), 375 | phantom_context: std::marker::PhantomData, 376 | } 377 | } 378 | } 379 | 380 | impl std::ops::Mul> for Float where InternalFloat: num::Float, CT: super::context::Context { 381 | type Output = Float; 382 | fn mul(self, other: Float) -> Float { 383 | Float { 384 | value: self.value * other.value, 385 | index: CT::binary_operation(&[other.value, self.value], &[self.index, other.index]), 386 | phantom_context: std::marker::PhantomData, 387 | } 388 | } 389 | } 390 | 391 | impl std::ops::Mul for Float 
where InternalFloat: num::Float, CT: super::context::Context { 392 | type Output = Float; 393 | fn mul(self, other: InternalFloat) -> Float { 394 | Float { 395 | value: self.value * other, 396 | index: CT::unary_operation(other, self.index), 397 | phantom_context: std::marker::PhantomData, 398 | } 399 | } 400 | } 401 | 402 | impl std::ops::Div> for Float where InternalFloat: num::Float, CT: super::context::Context { 403 | type Output = Float; 404 | fn div(self, other: Float) -> Float { 405 | Float { 406 | value: self.value / other.value, 407 | index: CT::binary_operation(&[other.value.recip(), 408 | -((self.value * self.value).recip())], 409 | &[self.index, other.index]), 410 | phantom_context: std::marker::PhantomData, 411 | } 412 | } 413 | } 414 | 415 | impl std::ops::Div for Float where InternalFloat: num::Float, CT: super::context::Context { 416 | type Output = Float; 417 | fn div(self, other: InternalFloat) -> Float { 418 | Float { 419 | value: self.value / other, 420 | index: CT::unary_operation(other.recip(), self.index), 421 | phantom_context: std::marker::PhantomData, 422 | } 423 | } 424 | } 425 | 426 | // TODO 1. Does it make sense to support % operator between two Floats ? 427 | // 2. If so, should we record other even though the multiplier is 0? 428 | impl std::ops::Rem> for Float where InternalFloat: num::Float, CT: super::context::Context { 429 | type Output = Float; 430 | fn rem(self, other: Float) -> Float { 431 | // TODO add this kind of assert everywhere. 432 | // assert!(( self . context as * const super::context::Context < InternalFloat > 433 | // ) == ( 434 | // other . 
context as * const super::context::Context < InternalFloat > 435 | // )); 436 | Float { 437 | value: self.value % other.value, 438 | index: CT::binary_operation(&[num::One::one(), num::Zero::zero()], 439 | &[self.index, other.index]), 440 | phantom_context: std::marker::PhantomData, 441 | } 442 | } 443 | } 444 | 445 | impl std::ops::Rem for Float where InternalFloat: num::Float, CT: super::context::Context { 446 | type Output = Float; 447 | fn rem(self, other: InternalFloat) -> Float { 448 | Float { 449 | value: self.value % other, 450 | index: CT::unary_operation(num::One::one(), self.index), 451 | phantom_context: std::marker::PhantomData, 452 | } 453 | } 454 | } 455 | 456 | macro_rules! impl_std_ops { 457 | ($InternalFloat:ty) => ( 458 | impl std::ops::Add> for $InternalFloat where CT: super::context::Context<$InternalFloat> { 459 | type Output = Float<$InternalFloat, CT>; 460 | fn add(self, other: Float<$InternalFloat, CT>) -> Float<$InternalFloat, CT> { 461 | other.add(self) 462 | } 463 | } 464 | 465 | impl std::ops::Sub> for $InternalFloat where CT: super::context::Context<$InternalFloat> { 466 | type Output = Float<$InternalFloat, CT>; 467 | fn sub(self, other: Float<$InternalFloat, CT>) -> Float<$InternalFloat, CT> { 468 | Float{value: self - other.value, 469 | index: CT::unary_operation( 470 | -<$InternalFloat as num::One>::one(), other.index), 471 | phantom_context: std::marker::PhantomData} 472 | } 473 | } 474 | 475 | impl std::ops::Mul> for $InternalFloat where CT: super::context::Context<$InternalFloat> { 476 | type Output = Float<$InternalFloat, CT>; 477 | fn mul(self, other: Float<$InternalFloat, CT>) -> Float<$InternalFloat, CT> { 478 | other.mul(self) 479 | } 480 | } 481 | 482 | impl std::ops::Div> for $InternalFloat where CT: super::context::Context<$InternalFloat> { 483 | type Output = Float<$InternalFloat, CT>; 484 | fn div(self, other: Float<$InternalFloat, CT>) -> Float<$InternalFloat, CT> { 485 | Float{value: self / other.value, 486 | index: 
CT::unary_operation( 487 | -((other.value * other.value).recip()), other.index), 488 | phantom_context: std::marker::PhantomData} 489 | } 490 | } 491 | 492 | impl std::ops::Rem> for $InternalFloat where CT: super::context::Context<$InternalFloat> { 493 | type Output = Float<$InternalFloat, CT>; 494 | #[allow(unused_variables)] 495 | fn rem(self, other: Float<$InternalFloat, CT>) -> Float<$InternalFloat, CT> { 496 | unimplemented!(); 497 | } 498 | } 499 | ) 500 | } 501 | 502 | // TODO We shouldn't hard code f32 and f64. Or at least, users should be able 503 | // to implement this for their choice of Float type without modifying this 504 | // library. 505 | impl_std_ops!(f32); 506 | impl_std_ops!(f64); 507 | 508 | impl std::cmp::PartialEq for Float where InternalFloat: num::Float, CT: super::context::Context { 509 | fn eq(&self, other: &Float) -> bool { 510 | self.value == other.value 511 | } 512 | } 513 | 514 | impl std::cmp::PartialOrd for Float where InternalFloat: num::Float, CT: super::context::Context { 515 | fn partial_cmp(&self, other: &Float) -> Option { 516 | self.value.partial_cmp(&other.value) 517 | } 518 | } 519 | 520 | impl std::marker::Copy for Float where InternalFloat: num::Float, CT: super::context::Context { 521 | } 522 | 523 | pub trait FloatCratePrivate where InternalFloat: num::Float, CT: super::context::Context { 524 | fn new(value: InternalFloat, index: usize) -> Self; 525 | fn float_get_index(&self) -> usize; 526 | } 527 | 528 | impl FloatCratePrivate for Float where InternalFloat: num::Float, CT: super::context::Context { 529 | fn new(value: InternalFloat, index: usize) -> Self { 530 | Float { 531 | value: value, 532 | index: index, 533 | phantom_context: std::marker::PhantomData, 534 | } 535 | } 536 | 537 | fn float_get_index(&self) -> usize { 538 | self.index 539 | } 540 | } 541 | --------------------------------------------------------------------------------