├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md └── src ├── lib.rs └── main.rs /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled files 2 | *.o 3 | *.so 4 | *.rlib 5 | *.dll 6 | 7 | # Executables 8 | *.exe 9 | 10 | # Generated by Cargo 11 | /target/ 12 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | 3 | name = "ad" 4 | version = "0.0.1" 5 | authors = ["Igor Babuschkin "] 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Igor Babuschkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rust-ad 2 | 3 | Automatic differentiation library for rust. 4 | Currently only supports first order forward AD. 5 | 6 | ## Example 7 | 8 | Calculate gradient of exp(x/y^2) at (1, 2): 9 | ```rust 10 | let result = ad::grad(|x| { Float::exp(x[0] / Float::powi(x[1], 2)) }, vec![1.0, 2.0]); 11 | println!("Out: {}", result); 12 | // Out: [0.321006, -0.321006] 13 | ``` 14 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(default_type_params)] 2 | 3 | use std::num::Float; 4 | use std::num::FloatMath; 5 | use std::num::NumCast; 6 | use std::num::FpCategory; 7 | use std::f64; 8 | 9 | #[allow(non_camel_case_types)] 10 | #[deriving(Show,Copy,Clone)] 11 | pub struct Num { 12 | pub val: f64, 13 | pub eps: f64, 14 | } 15 | 16 | impl Neg for Num { 17 | fn neg(self) -> Num { 18 | Num { val: -self.val, 19 | eps: -self.eps } 20 | } 21 | } 22 | 23 | impl Add for Num { 24 | fn add(self, _rhs: Num) -> Num { 25 | Num { val: self.val + _rhs.val, 26 | eps: self.eps + _rhs.eps } 27 | } 28 | } 29 | 30 | impl Sub for Num { 31 | fn sub(self, _rhs: Num) -> Num { 32 | Num { val: self.val - _rhs.val, 33 | eps: self.eps - _rhs.eps } 34 | } 35 | } 36 | 37 | impl Mul for Num { 38 | fn mul(self, _rhs: Num) -> Num { 39 | Num { val: self.val * _rhs.val, 40 | eps: self.eps * _rhs.val + self.val * _rhs.eps } 41 | } 42 | } 43 | 44 | impl Mul for f64 { 45 | fn mul(self, _rhs: Num) -> Num { 46 | Num { val: self * _rhs.val, 47 | eps: self * _rhs.eps } 48 | } 49 | } 50 | 51 | impl Mul for Num { 52 | fn mul(self, _rhs: f64) -> Num { 53 | Num { val: self.val * _rhs, 54 | eps: self.eps * _rhs } 55 | } 56 | } 57 | 58 | impl Div for Num { 59 | fn div(self, _rhs: Num) -> Num { 60 | Num { val: self.val / _rhs.val, 61 | eps: (self.eps * _rhs.val - self.val * _rhs.eps) 62 | / (_rhs.val * _rhs.val) } 63 | } 64 | } 65 | 66 | impl Rem for Num { 67 | fn rem(self, _rhs: Num) -> Num { 68 | panic!("Remainder not implemented") 69 | } 70 | } 71 | 72 | impl PartialEq for Num { 73 | fn eq(&self, _rhs: &Num) -> bool { 74 | self.val == _rhs.val 75 | } 76 | } 77 | 78 | impl PartialOrd for Num { 79 | fn partial_cmp(&self, other: &Num) -> Option { 80 | PartialOrd::partial_cmp(&self.val, &other.val) 81 | } 82 | } 83 | 84 | impl ToPrimitive for Num { 85 | fn to_i64(&self) -> Option { self.val.to_i64() } 86 | fn to_u64(&self) -> Option { self.val.to_u64() } 87 | fn to_int(&self) -> Option { self.val.to_int() } 88 | fn to_i8(&self) -> Option { self.val.to_i8() } 89 | fn to_i16(&self) -> Option { self.val.to_i16() } 90 | fn to_i32(&self) -> Option { self.val.to_i32() } 91 | fn to_uint(&self) -> Option { self.val.to_uint() } 92 | fn to_u8(&self) -> Option { self.val.to_u8() } 93 | fn to_u16(&self) -> Option { self.val.to_u16() } 94 | fn to_u32(&self) -> Option { self.val.to_u32() } 95 | fn to_f32(&self) -> Option { self.val.to_f32() } 96 | fn to_f64(&self) -> Option { self.val.to_f64() } 97 | } 98 | 99 | impl NumCast for Num { 100 | fn from (n: T) -> Option { 101 | let _val = n.to_f64(); 102 | match _val { 103 | Some(x) => Some(Num { val: x, eps: 0.0 }), 104 | None => None 105 | } 106 | } 107 | } 108 | 109 | impl Float for Num { 110 | fn nan() -> Num { Num { val: f64::NAN, eps: 0.0 } } 111 | fn infinity() -> Num { Num { val: f64::INFINITY, eps: 0.0 } } 112 | fn neg_infinity() -> Num { Num { val: f64::NEG_INFINITY, eps: 0.0 } } 113 | fn zero() -> Num { Num { val: 0.0, eps: 0.0 } } 114 | fn neg_zero() -> Num { Num { val: -0.0, eps: 0.0 } } 115 | fn one() -> Num { Num { val: 1.0, eps: 0.0 } } 116 | fn is_nan(self) -> bool { self.val.is_nan() || self.eps.is_nan() } 117 | fn is_infinite(self) -> bool { self.val.is_infinite() || self.eps.is_infinite() } 118 | fn is_finite(self) -> bool { self.val.is_finite() && self.eps.is_finite() } 119 | fn is_normal(self) -> bool { self.val.is_normal() && self.eps.is_normal() } 120 | fn classify(self) -> FpCategory { self.val.classify() } 121 | #[allow(unused_variables)] 122 | fn mantissa_digits(unused_self: Option) -> uint { f64::MANTISSA_DIGITS } 123 | #[allow(unused_variables)] 124 | fn digits(unused_self: Option) -> uint { f64::DIGITS } 125 | fn epsilon() -> Num { Num { val: f64::EPSILON, eps: 0.0 } } 126 | #[allow(unused_variables)] 127 | fn min_exp(unused_self: Option) -> int { f64::MIN_EXP } 128 | #[allow(unused_variables)] 129 | fn max_exp(unused_self: Option) -> int { f64::MAX_EXP } 130 | #[allow(unused_variables)] 131 | fn min_10_exp(unused_self: Option) -> int { f64::MIN_10_EXP } 132 | #[allow(unused_variables)] 133 | fn max_10_exp(unused_self: Option) -> int { f64::MAX_10_EXP } 134 | fn min_value() -> Num { Num { val: f64::MIN_VALUE, eps: 0.0 } } 135 | #[allow(unused_variables)] 136 | fn min_pos_value(unused_self: Option) -> Num { Num { val: f64::MIN_POS_VALUE, eps: 0.0 } } 137 | fn max_value() -> Num { Num { val: f64::MAX_VALUE, eps: 0.0 } } 138 | fn integer_decode(self) -> (u64, i16, i8) { self.val.integer_decode() } 139 | fn floor(self) -> Num { Num { val: self.val.floor(), eps: self.eps } } 140 | fn ceil(self) -> Num { Num { val: self.val.ceil(), eps: self.eps } } 141 | fn round(self) -> Num { Num { val: self.val.round(), eps: self.eps } } 142 | fn trunc(self) -> Num { Num { val: self.val.trunc(), eps: self.eps } } 143 | fn fract(self) -> Num { Num { val: self.val.fract(), eps: self.eps } } 144 | fn abs(self) -> Num { 145 | if self.val >= 0.0 { 146 | Num { val: self.val.abs(), eps: self.eps } 147 | } else { 148 | Num { val: self.val.abs(), eps: -self.eps } 149 | } 150 | } 151 | fn signum(self) -> Num { Num { val: self.val.signum(), eps: 0.0 } } 152 | fn is_positive(self) -> bool { self.val.is_positive() } 153 | fn is_negative(self) -> bool { self.val.is_negative() } 154 | fn mul_add(self, a: Num, b: Num) -> Num { 155 | self * a + b 156 | } 157 | fn recip(self) -> Num { Num { val: self.val.recip(), eps: -self.eps/(self.val * self.val) } } 158 | fn powi(self, n: i32) -> Num { 159 | Num { 160 | val: self.val.powi(n), 161 | eps: self.eps * n as f64 * self.val.powi(n - 1) 162 | } 163 | } 164 | fn powf(self, n: Num) -> Num { 165 | Num { 166 | val: Float::powf(self.val, n.val), 167 | eps: (Float::ln(self.val) * n.eps + n.val * self.eps / self.val) * Float::powf(self.val, n.val) 168 | } 169 | } 170 | fn sqrt2() -> Num { Num { val: f64::consts::SQRT2, eps: 0.0} } 171 | fn frac_1_sqrt2() -> Num { Num { val: f64::consts::FRAC_1_SQRT2, eps: 0.0} } 172 | fn sqrt(self) -> Num { Num { val: self.val.sqrt(), eps: self.eps * 0.5 * self.val.rsqrt() } } 173 | fn rsqrt(self) -> Num { Num { val: self.val.rsqrt(), eps: self.eps * -0.5 / self.val.sqrt().powi(3) } } 174 | fn pi() -> Num { Num { val: f64::consts::PI, eps: 0.0 } } 175 | fn two_pi() -> Num { Num { val: 2.0 * f64::consts::PI, eps: 0.0 } } 176 | fn frac_pi_2() -> Num { Num { val: f64::consts::FRAC_PI_2, eps: 0.0 } } 177 | fn frac_pi_3() -> Num { Num { val: f64::consts::FRAC_PI_3, eps: 0.0 } } 178 | fn frac_pi_4() -> Num { Num { val: f64::consts::FRAC_PI_4, eps: 0.0 } } 179 | fn frac_pi_6() -> Num { Num { val: f64::consts::FRAC_PI_6, eps: 0.0 } } 180 | fn frac_pi_8() -> Num { Num { val: f64::consts::FRAC_PI_8, eps: 0.0 } } 181 | fn frac_1_pi() -> Num { Num { val: f64::consts::FRAC_1_PI, eps: 0.0 } } 182 | fn frac_2_pi() -> Num { Num { val: f64::consts::FRAC_2_PI, eps: 0.0 } } 183 | fn frac_2_sqrtpi() -> Num { Num { val: f64::consts::FRAC_2_SQRTPI, eps: 0.0 } } 184 | fn e() -> Num { Num { val: f64::consts::E, eps: 0.0 } } 185 | fn log2_e() -> Num { Num { val: f64::consts::LOG2_E, eps: 0.0 } } 186 | fn log10_e() -> Num { Num { val: f64::consts::LOG10_E, eps: 0.0 } } 187 | fn ln_2() -> Num { Num { val: f64::consts::LN_2, eps: 0.0 } } 188 | fn ln_10() -> Num { Num { val: f64::consts::LN_10, eps: 0.0 } } 189 | fn exp(self) -> Num { Num { val: Float::exp(self.val), eps: self.eps * Float::exp(self.val) } } 190 | fn exp2(self) -> Num { Num { val: Float::exp2(self.val), eps: self.eps * Float::ln(2.0) * Float::exp(self.val) } } 191 | fn ln(self) -> Num { Num { val: Float::ln(self.val), eps: self.eps * self.val.recip() } } 192 | fn log(self, b: Num) -> Num { 193 | Num { 194 | val: Float::log(self.val, b.val), 195 | eps: -Float::ln(self.val) * b.eps / (b.val * Float::powi(Float::ln(b.val), 2)) + self.eps / (self.val * Float::ln(b.val)), 196 | } } 197 | fn log2(self) -> Num { Float::log(self, Num { val: 2.0, eps: 0.0 }) } 198 | fn log10(self) -> Num { Float::log(self, Num { val: 10.0, eps: 0.0 }) } 199 | fn to_degrees(self) -> Num { Num { val: Float::to_degrees(self.val), eps: 0.0 } } 200 | fn to_radians(self) -> Num { Num { val: Float::to_radians(self.val), eps: 0.0 } } 201 | } 202 | 203 | impl FloatMath for Num { 204 | fn ldexp(x: Num, exp: int) -> Num { Num { val: FloatMath::ldexp(x.val, exp), eps: FloatMath::ldexp(x.eps, exp) } } 205 | fn frexp(self) -> (Num, int) { 206 | let (x, exp) = FloatMath::frexp(self.val); 207 | (Num { val: x, eps: 0.0 }, exp) 208 | } 209 | fn next_after(self, other: Num) -> Num { Num { val: FloatMath::next_after(self.val, other.val), eps: 0.0 } } 210 | fn max(self, other: Num) -> Num { Num { val: FloatMath::max(self.val, other.val), eps: 0.0 } } 211 | fn min(self, other: Num) -> Num { Num { val: FloatMath::min(self.val, other.val), eps: 0.0 } } 212 | fn abs_sub(self, other: Num) -> Num { 213 | if self > other { 214 | Num { val: FloatMath::abs_sub(self.val, other.val), eps: (self - other).eps } 215 | } else { 216 | Num { val: 0.0, eps: 0.0 } 217 | } 218 | } 219 | fn cbrt(self) -> Num { Num { val: FloatMath::cbrt(self.val), eps: 1.0/3.0 * self.val.powf(-2.0/3.0) * self.eps } } 220 | fn hypot(self, other: Num) -> Num { 221 | Float::sqrt(Float::powi(self, 2) + Float::powi(other, 2)) 222 | } 223 | fn sin(self) -> Num { Num { val: FloatMath::sin(self.val), eps: self.eps * FloatMath::cos(self.val) } } 224 | fn cos(self) -> Num { Num { val: FloatMath::cos(self.val), eps: -self.eps * FloatMath::sin(self.val) } } 225 | fn tan(self) -> Num { 226 | let t = FloatMath::tan(self.val); 227 | Num { val: t, eps: self.eps * (t * t + 1.0) } 228 | } 229 | fn asin(self) -> Num { Num { val: FloatMath::asin(self.val), eps: self.eps / Float::sqrt(1.0 - Float::powi(self.val, 2)) } } 230 | fn acos(self) -> Num { Num { val: FloatMath::acos(self.val), eps: -self.eps / Float::sqrt(1.0 - Float::powi(self.val, 2)) } } 231 | fn atan(self) -> Num { Num { val: FloatMath::atan(self.val), eps: self.eps / Float::sqrt(Float::powi(self.val, 2) + 1.0) } } 232 | fn atan2(self, other: Num) -> Num { 233 | Num { 234 | val: FloatMath::atan2(self.val, other.val), 235 | eps: (other.val * self.eps - self.val * other.eps) / (Float::powi(self.val, 2) + Float::powi(other.val, 2)) 236 | } 237 | } 238 | fn sin_cos(self) -> (Num, Num) { 239 | let (s, c) = FloatMath::sin_cos(self.val); 240 | let sn = Num { val: s, eps: self.eps * c }; 241 | let cn = Num { val: c, eps: -self.eps * s }; 242 | (sn, cn) 243 | } 244 | fn exp_m1(self) -> Num { 245 | Num { val: FloatMath::exp_m1(self.val), eps: self.eps * Float::exp(self.val) } 246 | } 247 | fn ln_1p(self) -> Num { 248 | Num { val: FloatMath::ln_1p(self.val), eps: self.eps / (self.val + 1.0) } 249 | } 250 | fn sinh(self) -> Num { Num { val: FloatMath::sinh(self.val), eps: self.eps * FloatMath::cosh(self.val) } } 251 | fn cosh(self) -> Num { Num { val: FloatMath::cosh(self.val), eps: self.eps * FloatMath::sinh(self.val) } } 252 | fn tanh(self) -> Num { Num { val: FloatMath::tanh(self.val), eps: self.eps * (1.0 - Float::powi(FloatMath::tanh(self.val), 2)) } } 253 | fn asinh(self) -> Num { Num { val: FloatMath::asinh(self.val), eps: self.eps * (Float::powi(self.val, 2) + 1.0) } } 254 | fn acosh(self) -> Num { Num { val: FloatMath::acosh(self.val), eps: self.eps * (Float::powi(self.val, 2) - 1.0) } } 255 | fn atanh(self) -> Num { Num { val: FloatMath::atanh(self.val), eps: self.eps * (-Float::powi(self.val, 2) + 1.0) } } 256 | } 257 | 258 | /// Function for creating a constant from a float 259 | pub fn cst(x: f64) -> Num { 260 | Num { val: x, eps: 0.0 } 261 | } 262 | 263 | /// Evaluates the derivative of `func` at `x0` 264 | pub fn diff(func: |Num| -> Num, x0: f64) -> f64 { 265 | let x = Num { val: x0, eps: 1.0 }; 266 | func(x).eps 267 | } 268 | 269 | /// Evaluates the gradient of `func` at `x0` 270 | pub fn grad(func: |Vec| -> Num, x0: Vec) -> Vec { 271 | let mut params = Vec::new(); 272 | for x in x0.iter() { 273 | params.push(Num { val: *x, eps: 0.0 }); 274 | } 275 | 276 | let mut results = Vec::new(); 277 | 278 | for i in range(0u, params.len()) { 279 | params[i].eps = 1.0; 280 | results.push(func(params.clone()).eps); 281 | params[i].eps = 0.0; 282 | } 283 | results 284 | } 285 | 286 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | 2 | extern crate ad; 3 | 4 | use std::num::FloatMath; 5 | 6 | fn main() { 7 | let result = ad::diff(FloatMath::sin, 0.0); 8 | 9 | println!("Out: {}", result); 10 | } 11 | 12 | 13 | --------------------------------------------------------------------------------