├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── ideas.md ├── play ├── optimal-forgetting.hs └── print-forget.rs ├── revad.py └── src ├── chain.rs ├── lib.rs └── tape.rs /.gitignore: -------------------------------------------------------------------------------- 1 | # generic 2 | .#* 3 | \#*# 4 | .nfs???? 5 | [tT]humbs.db 6 | *~ 7 | *.dep 8 | *.log 9 | *.orig 10 | *.pid 11 | *.tmp 12 | 13 | # native 14 | a.out 15 | *.a 16 | *.aps 17 | *.dll 18 | *.dylib 19 | *.exe 20 | *.gcda 21 | *.gch 22 | *.gcno 23 | *.ipch 24 | *.lcov 25 | *.lib 26 | *.ncb 27 | *.o 28 | *.obj 29 | *.opensdf 30 | *.pch 31 | *.so 32 | *.sdf 33 | *.stackdump 34 | *.suo 35 | *.user 36 | 37 | # Fortran 38 | *.mod 39 | 40 | # Haskell 41 | .cabal-sandbox/ 42 | .stack-work/ 43 | dist/ 44 | cabal.config 45 | cabal.sandbox.config 46 | stack.yaml 47 | *.chi 48 | *.hcr 49 | *.hi 50 | 51 | # JavaScript 52 | node_modules/ 53 | 54 | # Python 55 | MANIFEST 56 | __pycache__ 57 | *.pyc 58 | 59 | # Rust 60 | target/ 61 | Cargo.lock 62 | 63 | # TeX 64 | *.aux 65 | *.bbl 66 | *.blg 67 | *.fdb_latexmk 68 | *.fls 69 | *.thm 70 | *.toc 71 | *Notes.bib 72 | 73 | # web 74 | .sass-cache 75 | *.css.map 76 | *.js.map 77 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "revad" 3 | version = "0.1.0" 4 | authors = ["Rufflewind"] 5 | description = "Reverse-mode automatic differentiation." 
6 | 7 | [dependencies] 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016-2017 Phil Ruffwind 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## revad: Reverse-mode automatic differentiation demo 2 | 3 | This is a demonstration of how gradients could be calculated using reverse-mode automatic differentiation. 
4 | 5 | ~~~rust 6 | let t = Tape::new(); 7 | let x = t.var(0.5); 8 | let y = t.var(4.2); 9 | let z = x * y + x.sin(); 10 | let grad = z.grad(); 11 | println!("z = {}", z.value()); // z = 2.579425538604203 12 | println!("∂z/∂x = {}", grad.wrt(x)); // ∂z/∂x = 5.077582561890373 13 | println!("∂z/∂y = {}", grad.wrt(y)); // ∂z/∂y = 0.5 14 | ~~~ 15 | 16 | This library is an experiment/demonstration/prototype and is therefore woefully incomplete. Feel free to use its [ideas](ideas.md) to build an actual AD library! 17 | 18 | ## Usage 19 | 20 | Add this to your [`Cargo.toml`](http://doc.crates.io/specifying-dependencies.html): 21 | 22 | ~~~toml 23 | [dependencies] 24 | revad = { git = "https://github.com/Rufflewind/revad" } 25 | ~~~ 26 | 27 | and add this line to the [root module](https://doc.rust-lang.org/book/crates-and-modules.html#basic-terminology-crates-and-modules) of your crate: 28 | 29 | ~~~rust 30 | extern crate revad; 31 | ~~~ 32 | -------------------------------------------------------------------------------- /ideas.md: -------------------------------------------------------------------------------- 1 | # Ideas 2 | 3 | The existing design is fairly efficient, but it lacks flexibility. It would be nice to achieve the goals below without sacrificing (too much) efficiency, but that may not be possible if we admit richer nodes into our graph (= allocation for each node! :/ ). We also want to keep the separation between the tape and the filled-in gradients, so we can re-use the tape. This is difficult when you want arbitrarily typed variables! 4 | 5 | ## Richly typed variables 6 | 7 | We don't want to assume everything is an amorphous vector of floats or something, since that would encourage users to write things in a type-unsafe way (or unpack everything into dumb vectors). 8 | 9 | Update: Maybe we can get away with a homogeneously typed tape, and then compose differently typed tapes together using structures (like how Iterators work)?
But how does a generic tape work? 10 | 11 | ## Better composability 12 | 13 | It would be nice to allow nesting of some sort. Right now, the whole process of AD-ing a function is monolithic: you can't embed a subgraph into a node. 14 | 15 | ## Reduce memory usage 16 | 17 | It'd be nice to have some “manual override” that allows you to reduce memory usage for a specific portion of the function that is known to be intensive. 18 | 19 | Also, it would be nice to implement the CTZ forgetting (and recomputation) strategy for a long repeated calculation, which I think is asymptotically optimal and can reduce both memory usage and recompute costs to logarithmic scaling. It'd be even better to see how this pattern can be applied to more general recursive/iterative control flow. 20 | 21 | ## Prevent variables of different tapes from being used together 22 | 23 | We can use [branded indices](https://github.com/bluss/indexing) for this. 24 | 25 | ## Abstract notion of adjoint functions 26 | 27 | Here is how reverse-mode AD works on an abstract mathematical level. 28 | 29 | Suppose we have a function: 30 | 31 | f : X -> Y 32 | 33 | which maps each point `x` in the input space `X` to a point `y` in the output space `Y`. If `f` is differentiable, then there exists an **adjoint** function: 34 | 35 | adj_f: (x : X) -> GY(f(x)) -> GX(x) 36 | 37 | that computes the gradient `gx: GX(x)` of `f` at the point `x : X` along the cotangent vector `gy: GY(f(x))`. Mathematically, `GX(x)` is the cotangent space of `X` at the point `x` (and similarly for `GY`). It is a vector space composed of every possible differential: 38 | 39 | {dx₁, dx₂, dx₃, …, plus all linear combinations of such} 40 | 41 | i.e. differentials are the basis vectors. Don't confuse this with the “differentials” in forward-mode AD, which are really tangent vectors! 42 | 43 | - For a tangent vector, the coefficients are differentials (`dx/dt`), whereas the basis vectors are directional derivatives (`∂/∂x`). 
44 | - For a cotangent vector, the coefficients are directional derivatives (`∂t/∂x`), whereas the basis vectors are differentials (`dx`). 45 | 46 | Notice that `GX` is parametrized by `x : X`. We can't encode this in Rust though, so we will instead pretend that `GX` is independent of `x : X`. 47 | -------------------------------------------------------------------------------- /play/optimal-forgetting.hs: -------------------------------------------------------------------------------- 1 | -- brute-force search for optimal trimming/"forgetting" strategies 2 | import Debug.Trace 3 | import Data.Foldable 4 | import Data.Word 5 | import Data.Bits 6 | 7 | -- budget: maximum amount of memory allowed to be used 8 | -- cost: number of recomputes 9 | 10 | recompute :: Int -> Word64 -> Maybe (Int, Word64) 11 | recompute budget cache = do 12 | let cost = countTrailingZeros cache 13 | let cache' = cache .|. (bit cost - 1) 14 | optTrim <- optimalTrim (budget - 1) (cache' `shiftR` 1) 15 | pure (cost, (optTrim `shiftL` 1) .|. 1) 16 | 17 | reverseSweepCost :: Int -> Word64 -> Maybe Int 18 | reverseSweepCost budget cache 19 | | cache == 0 = Just 0 20 | | cache .&. 1 /= 0 = reverseSweepCost budget (cache `shiftR` 1) 21 | | otherwise = do 22 | (cost, cache') <- recompute budget cache 23 | c2 <- reverseSweepCost budget cache' 24 | pure (cost + c2) 25 | 26 | getBits :: Show b => FiniteBits b => b -> [Int] 27 | getBits w = 28 | foldl' (\ l i -> if testBit w i then i : l else l) 29 | [] [0 .. 
finiteBitSize w - 1 - countLeadingZeros w] 30 | 31 | optimalTrim :: Int -> Word64 -> Maybe Word64 32 | optimalTrim budget cache 33 | | used <= budget = Just cache 34 | | otherwise = do 35 | let pt = weigh <$> possibleTrims (used - budget) cache 36 | case pt of 37 | [] -> Nothing 38 | _ -> Just (snd (minimum pt)) 39 | where 40 | used = popCount cache 41 | weigh c = (reverseSweepCost budget c, c) 42 | 43 | possibleTrims :: Int -> Word64 -> [Word64] 44 | possibleTrims 0 c = pure c 45 | possibleTrims n c = do 46 | i <- tail (getBits c) -- avoid unsetting the last bit 47 | possibleTrims (n - 1) (c `xor` bit i) 48 | 49 | fromStr :: String -> Word64 50 | fromStr s = go 0 s 51 | where go w "" = w 52 | go w (' ' : xs) = go (w `shiftL` 1) xs 53 | go w (_ : xs) = go ((w `shiftL` 1) + 1) xs 54 | 55 | toStr :: Word64 -> String 56 | toStr = reverse . go 57 | where go 0 = "" 58 | go w = (if testBit w 0 then 'x' else ' ') : go (w `shiftR` 1) 59 | 60 | reverseSweep :: Int -> Word64 -> IO () 61 | reverseSweep budget cache 62 | | cache == 0 = pure () 63 | | cache .&. 
1 /= 0 = do 64 | print ("sweeping", toStr cache, 0::Int) 65 | reverseSweep budget (cache `shiftR` 1) 66 | | otherwise = do 67 | case recompute budget cache of 68 | Nothing -> do 69 | print "impossible" 70 | Just (cost, cache') -> do 71 | print ("sweeping", toStr cache, cost) 72 | reverseSweep budget cache' 73 | 74 | main :: IO () 75 | main = do 76 | let b = 4 -- budget 77 | let initial = "x x x x" 78 | let t = case optimalTrim b (fromStr initial) of 79 | Just x -> x 80 | Nothing -> error "no solution" 81 | print ("result", toStr t) 82 | print ("cost", reverseSweepCost b t) 83 | reverseSweep b t 84 | -------------------------------------------------------------------------------- /play/print-forget.rs: -------------------------------------------------------------------------------- 1 | // print a pretty picture that shows which elements are cached and which are 2 | // forgotten 3 | // 4 | // here, we use a simple CTZ-based strategy (which I believe is optimal); the 5 | // count-trailing-zeros (CTZ) operation is used to generate the so-called 6 | // ruler sequence, which controls how long each new cache value can live 7 | // for (i.e. the height of each column in the diagram produced by this 8 | // program); funnily enough, CTZ *also* appears in the formula that determines 9 | // which index of the existing cache vector to annihilate so as to lead to the 10 | // same desired result. 11 | // 12 | // anyway, number trickery aside, this strategy has the property of 13 | // eliminating the caches that are further away from the "current" entry; 14 | // the result looks like the ticks on a log2-scale axis of a plot 15 | 16 | fn main() { 17 | 18 | let mut l = Vec::new(); 19 | 20 | let mut ruler: usize = 1; 21 | let mut ruler_max: i64 = -2; 22 | for i in 0 ..
64 { 23 | l.push(i); 24 | 25 | let j = ruler_max - ruler.trailing_zeros() as i64; 26 | if j <= 0 { 27 | ruler = 1; 28 | ruler_max += 1; 29 | } else { 30 | ruler += 1; 31 | l.remove(j as usize); 32 | } 33 | 34 | for u in 0 .. i + 1 { 35 | if l.contains(&u) { 36 | print!("x"); 37 | } else { 38 | print!(" "); 39 | } 40 | } 41 | println!(""); 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /revad.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | class Var: 4 | def __init__(self, value): 5 | self.value = value 6 | self.children = [] 7 | self.grad_value = None 8 | 9 | def grad(self): 10 | if self.grad_value is None: 11 | self.grad_value = sum(weight * var.grad() 12 | for weight, var in self.children) 13 | return self.grad_value 14 | 15 | def __add__(self, other): 16 | z = Var(self.value + other.value) 17 | self.children.append((1.0, z)) 18 | other.children.append((1.0, z)) 19 | return z 20 | 21 | def __mul__(self, other): 22 | z = Var(self.value * other.value) 23 | self.children.append((other.value, z)) 24 | other.children.append((self.value, z)) 25 | return z 26 | 27 | def sin(x): 28 | z = Var(math.sin(x.value)) 29 | x.children.append((math.cos(x.value), z)) 30 | return z 31 | 32 | x = Var(0.5) 33 | y = Var(4.2) 34 | z = x * y + sin(x) 35 | z.grad_value = 1.0 36 | 37 | assert abs(z.value - 2.579425538604203) <= 1e-15 38 | assert abs(x.grad() - (y.value + math.cos(x.value))) <= 1e-15 39 | assert abs(y.grad() - x.value) <= 1e-15 40 | -------------------------------------------------------------------------------- /src/chain.rs: -------------------------------------------------------------------------------- 1 | //! This module implements explicit algorithms for handling loops in 2 | //! reverse-mode AD code. 3 | //! 4 | //! `FullChain` is the naive approach: it basically records a snapshot at 5 | //! every step. 
`CtzChain` takes a smarter approach: it keeps only 6 | //! a small subset of the snapshots (roughly, a logarithmic amount) at the 7 | //! expense of requiring recomputations (also a logarithmic amount). 8 | //! 9 | //! Both of these chains are constructed from an iterator of *snapshots*, 10 | //! which are arbitrary objects that may be used to compute the adjoints. To 11 | //! compute an adjoint value, the adjoint function `j: J` (conceptually, the 12 | //! transposed Jacobian) is called in the following manner: 13 | //! 14 | //! ~~~text 15 | //! let gx = j(s, gy); 16 | //! ~~~ 17 | //! 18 | //! where `s: S` is the snapshot and `gx: G` and `gy: G` are the adjoint 19 | //! values. Unfortunately the `sweep_*` methods have slightly differing 20 | //! requirements on the signature of the adjoint function `J`: they differ in 21 | //! how the snapshot is to be passed in (immutable/mutable/by value). (I wish 22 | //! there were a way to avoid this.) 23 | //! 24 | //! When using a loop to compute `f(f(f(…)))`, we use a condensed version of 25 | //! `f` called the restoration function (`restore`) to recompute the 26 | //! snapshots, if necessary. A trivial definition of the restoration function 27 | //! is given by `|x| f(x)`, if we assume the snapshots are just the input 28 | //! values themselves. But you may not want to use the trivial definition if 29 | //! the restoration function can be implemented in a simpler way (as is the 30 | //! case if `f` is a linear function or is partly linear, in which case the 31 | //! derivative would be independent or partly independent of the input 32 | //! values). In any case, the restoration function `r` must satisfy the 33 | //! following condition with respect to the sequence of snapshots: if the 34 | //! snapshot `s2` follows `s1`, then `r(s1)` must return `s2`. 35 | //! 36 | //! Pictorially, a chain is used to model the following control flow: 37 | //! 38 | //! ~~~text 39 | //!
x0 --+-> x1 --+-> x2 --+-> x3 (plain values) 40 | //! | | | 41 | //! | | | 42 | //! v v v 43 | //! s0 s1 s2 (snapshots) 44 | //! | | | 45 | //! |j |j |j (adjoint function) 46 | //! v v v 47 | //! g0 <-+-- g1 <-+-- g2 <-+-- g3 (adjoint values) 48 | //! ----> ----> ----> 49 | //! r r r (restoration function) 50 | //! ~~~ 51 | //! 52 | //! The iterator is what produces the snapshots, so the chain implementations 53 | //! never actually see the input values `x: X` directly. 54 | 55 | /// Reifies an operation built from a loop. 56 | /// 57 | /// When called as a function, it will run the adjoint functions in reverse. 58 | pub struct FullChain<S, J> { 59 | snapshots: Vec<S>, 60 | adjoint: J, 61 | } 62 | 63 | impl<S, J> FullChain<S, J> { 64 | 65 | pub fn new<I>(snapshots: I, adjoint: J) -> Self where I: Iterator<Item = S> { 66 | FullChain { 67 | snapshots: snapshots.collect(), 68 | adjoint: adjoint, 69 | } 70 | } 71 | 72 | pub fn sweep<G>(&self, mut x: G) -> G where J: Fn(&S, G) -> G { 73 | for g in self.snapshots.iter().rev() { 74 | x = (self.adjoint)(g, x); 75 | } 76 | return x; 77 | } 78 | 79 | pub fn sweep_mut<G>(&mut self, mut x: G) -> G where J: FnMut(&mut S, G) -> G { 80 | for g in self.snapshots.iter_mut().rev() { 81 | x = (self.adjoint)(g, x); 82 | } 83 | return x; 84 | } 85 | 86 | pub fn sweep_once<G>(mut self, mut x: G) -> G where J: FnMut(S, G) -> G { 87 | loop { 88 | match self.snapshots.pop() { 89 | Some(g) => x = (self.adjoint)(g, x), 90 | None => return x, 91 | } 92 | } 93 | } 94 | } 95 | 96 | /// Lossily extend a vector of elements using an eviction strategy based on 97 | /// the count-trailing-zeros operation. The algorithm tends to evict items 98 | /// further away from the current item. The latest item is never evicted, and 99 | /// existing items are left untouched. Given `n` new items to be added, the 100 | /// number of items that will actually get added to the vector is exactly 101 | /// `ceil(log2(n)) + 1`.
102 | pub fn ctz_extend<T, I>(v: &mut Vec<(usize, T)>, i0: usize, xs: I) 103 | where I: Iterator<Item = T> { 104 | let mut ruler: usize = 1; 105 | let mut ruler_max = -2; 106 | let start = v.len(); 107 | for (i, x) in xs.enumerate() { 108 | v.push((i0 + i, x)); 109 | let j = ruler_max - ruler.trailing_zeros() as i64; 110 | if j <= 0 { 111 | ruler = 1; 112 | ruler_max += 1; 113 | } else { 114 | ruler += 1; 115 | v.remove(start + (j as usize)); 116 | } 117 | } 118 | } 119 | 120 | /// Maintains a partial tape using the count-trailing-zeros (CTZ) eviction 121 | /// strategy. This results in space usage that is logarithmic in the number 122 | /// of steps, while also keeping a logarithmic amount of recomputation time. 123 | /// 124 | /// FIXME: `sweep_mut` is not yet implemented. 125 | pub struct CtzChain<S, J, R> { 126 | snapshots: Vec<(usize, S)>, 127 | adjoint: J, 128 | restore: R, 129 | } 130 | 131 | impl<S, J, R> CtzChain<S, J, R> { 132 | pub fn new<I>(snapshots: I, adjoint: J, restore: R) -> Self 133 | where I: Iterator<Item = S> { 134 | let mut chain = CtzChain { 135 | snapshots: Vec::new(), 136 | adjoint: adjoint, 137 | restore: restore, 138 | }; 139 | ctz_extend(&mut chain.snapshots, 0, snapshots); 140 | chain 141 | } 142 | 143 | pub fn sweep<G>(&self, mut x: G) -> G 144 | where J: Fn(&S, G) -> G, R: Fn(&S) -> S { 145 | 146 | fn extend<S, R>(restored_snapshots: &mut Vec<(usize, S)>, 147 | mut num_missing: usize, i: usize, 148 | s_old: &mut Option<S>, restore: &R) 149 | where R: Fn(&S) -> S { 150 | ctz_extend(restored_snapshots, i + 1, Generator(|| { 151 | let mut s_new = match s_old { 152 | &mut None => { 153 | None 154 | }, 155 | &mut Some(ref mut s_old) => { 156 | num_missing -= 1; 157 | if num_missing == 0 { 158 | None 159 | } else { 160 | Some(restore(s_old)) 161 | } 162 | }, 163 | }; 164 | ::std::mem::swap(s_old, &mut s_new); 165 | s_new 166 | })); 167 | } 168 | 169 | let mut j = match self.snapshots.last() { 170 | None => return x, 171 | Some(&(i, _)) => i + 1, 172 | }; 173 | let mut restored_snapshots = Vec::new(); 174 |
for &(i, ref s) in self.snapshots.iter().rev() { 175 | loop { 176 | match restored_snapshots.pop() { 177 | Some((i, s)) => { 178 | let num_missing = j - i - 1; 179 | if num_missing != 0 { 180 | let mut s_old = Some((self.restore)(&s)); 181 | restored_snapshots.push((i, s)); 182 | extend(&mut restored_snapshots, num_missing, 183 | i, &mut s_old, &self.restore); 184 | } else { 185 | x = (self.adjoint)(&s, x); 186 | j -= 1; 187 | } 188 | }, 189 | None => { 190 | let num_missing = j - i - 1; 191 | if num_missing != 0 { 192 | let mut s_old = Some((self.restore)(&s)); 193 | extend(&mut restored_snapshots, num_missing, 194 | i, &mut s_old, &self.restore); 195 | } else { 196 | x = (self.adjoint)(&s, x); 197 | j -= 1; 198 | break; 199 | } 200 | }, 201 | } 202 | } 203 | } 204 | x 205 | } 206 | 207 | pub fn sweep_once<G>(mut self, mut x: G) -> G 208 | where J: FnMut(S, G) -> G, R: FnMut(&S) -> S { 209 | let mut i = match self.snapshots.last() { 210 | None => return x, 211 | Some(&(j, _)) => j + 1, 212 | }; 213 | loop { 214 | match self.snapshots.pop() { 215 | Some((j, s)) => { 216 | let mut num_missing = i - j - 1; 217 | if num_missing == 0 { 218 | x = (self.adjoint)(s, x); 219 | i -= 1; 220 | } else { 221 | let mut s_old = Some((self.restore)(&s)); 222 | self.snapshots.push((j, s)); 223 | let mut restore = self.restore; 224 | ctz_extend(&mut self.snapshots, j + 1, Generator(|| { 225 | let mut s_new = match &mut s_old { 226 | &mut None => { 227 | None 228 | }, 229 | &mut Some(ref mut s_old) => { 230 | num_missing -= 1; 231 | if num_missing == 0 { 232 | None 233 | } else { 234 | Some(restore(s_old)) 235 | } 236 | }, 237 | }; 238 | ::std::mem::swap(&mut s_old, &mut s_new); 239 | s_new 240 | })); 241 | self.restore = restore; 242 | } 243 | }, 244 | None => { 245 | assert_eq!(i, 0); 246 | return x; 247 | }, 248 | } 249 | } 250 | } 251 | } 252 | 253 | /// Wrap a `next` function into an `Iterator`.
254 | pub struct Generator<F>(pub F); 255 | 256 | impl<F, T> Iterator for Generator<F> where F: FnMut() -> Option<T> { 257 | type Item = T; 258 | fn next(&mut self) -> Option<T> { 259 | (self.0)() 260 | } 261 | } 262 | 263 | #[cfg(test)] 264 | mod tests { 265 | use super::*; 266 | 267 | static E: f64 = 1.01; 268 | static N: i32 = 100; 269 | static X0: f64 = 4.2; 270 | 271 | fn f(mut x: Vec<f64>) -> Vec<f64> { 272 | x[0] = x[0].powf(E); 273 | x 274 | } 275 | 276 | #[test] 277 | fn full_chain() { 278 | let x0 = vec![X0]; 279 | let expected = E.powi(N) * X0.powf(E.powi(N) - 1.0); 280 | 281 | let g = { 282 | let mut x = x0.clone(); 283 | let mut i = 0; 284 | FullChain::new(Generator(|| { 285 | if !(i < N) { 286 | return None; 287 | } 288 | i += 1; 289 | let mut x2 = f(x.clone()); 290 | ::std::mem::swap(&mut x2, &mut x); 291 | Some(x2) 292 | }), |x: &Vec<f64>, mut g: Vec<f64>| { 293 | g[0] *= E * x[0].powf(E - 1.0); 294 | g 295 | }) 296 | }.sweep(vec![1.0]); 297 | assert!((g[0] - expected).abs() < 1e-10); 298 | 299 | let g = { 300 | let mut x = x0.clone(); 301 | let mut i = 0; 302 | FullChain::new(Generator(|| { 303 | if !(i < N) { 304 | return None; 305 | } 306 | i += 1; 307 | let mut x2 = f(x.clone()); 308 | ::std::mem::swap(&mut x2, &mut x); 309 | Some(x2) 310 | }), |x: &mut Vec<f64>, mut g: Vec<f64>| { 311 | g[0] *= E * x[0].powf(E - 1.0); 312 | g 313 | }) 314 | }.sweep_mut(vec![1.0]); 315 | assert!((g[0] - expected).abs() < 1e-10); 316 | 317 | let g = { 318 | let mut x = x0.clone(); 319 | let mut i = 0; 320 | FullChain::new(Generator(|| { 321 | if !(i < N) { 322 | return None; 323 | } 324 | i += 1; 325 | let mut x2 = f(x.clone()); 326 | ::std::mem::swap(&mut x2, &mut x); 327 | Some(x2) 328 | }), |x: Vec<f64>, mut g: Vec<f64>| { 329 | g[0] *= E * x[0].powf(E - 1.0); 330 | g 331 | }) 332 | }.sweep_once(vec![1.0]); 333 | assert!((g[0] - expected).abs() < 1e-10); 334 | } 335 | 336 | #[test] 337 | fn ctz_chain() { 338 | let x0 = vec![X0]; 339 | let expected = E.powi(N) * X0.powf(E.powi(N) - 1.0); 340 | 341 | let g = { 342 |
let mut x = x0.clone(); 343 | let mut i = 0; 344 | CtzChain::new(Generator(|| { 345 | if !(i < N) { 346 | return None; 347 | } 348 | i += 1; 349 | let mut x2 = f(x.clone()); 350 | ::std::mem::swap(&mut x2, &mut x); 351 | Some(x2) 352 | }), |x: &Vec<f64>, mut g: Vec<f64>| { 353 | g[0] *= E * x[0].powf(E - 1.0); 354 | g 355 | }, |x: &Vec<f64>| f(x.clone())) 356 | }.sweep(vec![1.0]); 357 | assert!((g[0] - expected).abs() < 1e-10); 358 | 359 | let g = { 360 | let mut x = x0.clone(); 361 | let mut i = 0; 362 | CtzChain::new(Generator(|| { 363 | if !(i < N) { 364 | return None; 365 | } 366 | i += 1; 367 | let mut x2 = f(x.clone()); 368 | ::std::mem::swap(&mut x2, &mut x); 369 | Some(x2) 370 | }), |x: Vec<f64>, mut g: Vec<f64>| { 371 | g[0] *= E * x[0].powf(E - 1.0); 372 | g 373 | }, |x: &Vec<f64>| f(x.clone())) 374 | }.sweep_once(vec![1.0]); 375 | assert!((g[0] - expected).abs() < 1e-10); 376 | } 377 | } 378 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod chain; 2 | pub mod tape; 3 | -------------------------------------------------------------------------------- /src/tape.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | 3 | #[derive(Clone, Copy)] 4 | struct Node { 5 | weights: [f64; 2], 6 | deps: [usize; 2], 7 | } 8 | 9 | pub struct Tape { nodes: RefCell<Vec<Node>> } 10 | 11 | impl Tape { 12 | pub fn new() -> Self { 13 | Tape { nodes: RefCell::new(Vec::new()) } 14 | } 15 | 16 | pub fn var<'t>(&'t self, value: f64) -> Var<'t> { 17 | Var { 18 | tape: self, 19 | value: value, 20 | index: self.push0(), 21 | } 22 | } 23 | 24 | fn len(&self) -> usize { 25 | self.nodes.borrow().len() 26 | } 27 | 28 | fn push0(&self) -> usize { 29 | let mut nodes = self.nodes.borrow_mut(); 30 | let len = nodes.len(); 31 | nodes.push(Node { 32 | weights: [0.0, 0.0], 33 | deps: [len, len], 34 | }); 35 | len 36 | } 37 | 38 | fn
push1(&self, dep0: usize, weight0: f64) -> usize { 39 | let mut nodes = self.nodes.borrow_mut(); 40 | let len = nodes.len(); 41 | nodes.push(Node { 42 | weights: [weight0, 0.0], 43 | deps: [dep0, len], 44 | }); 45 | len 46 | } 47 | 48 | fn push2(&self, 49 | dep0: usize, weight0: f64, 50 | dep1: usize, weight1: f64) -> usize { 51 | let mut nodes = self.nodes.borrow_mut(); 52 | let len = nodes.len(); 53 | nodes.push(Node { 54 | weights: [weight0, weight1], 55 | deps: [dep0, dep1], 56 | }); 57 | len 58 | } 59 | } 60 | 61 | #[derive(Clone, Copy)] 62 | pub struct Var<'t> { 63 | tape: &'t Tape, 64 | index: usize, 65 | value: f64, 66 | } 67 | 68 | impl<'t> Var<'t> { 69 | pub fn value(&self) -> f64 { 70 | self.value 71 | } 72 | 73 | pub fn grad(&self) -> Grad { 74 | let len = self.tape.len(); 75 | let nodes = self.tape.nodes.borrow(); 76 | let mut derivs = vec![0.0; len]; 77 | derivs[self.index] = 1.0; 78 | for i in (0 .. len).rev() { 79 | let node = nodes[i]; 80 | let deriv = derivs[i]; 81 | for j in 0 .. 
2 { 82 | derivs[node.deps[j]] += node.weights[j] * deriv; 83 | } 84 | } 85 | Grad { derivs: derivs } 86 | } 87 | 88 | pub fn sin(self) -> Self { 89 | Var { 90 | tape: self.tape, 91 | value: self.value.sin(), 92 | index: self.tape.push1( 93 | self.index, self.value.cos(), 94 | ), 95 | } 96 | } 97 | } 98 | 99 | impl<'t> ::std::ops::Add for Var<'t> { 100 | type Output = Var<'t>; 101 | fn add(self, other: Var<'t>) -> Self::Output { 102 | assert_eq!(self.tape as *const Tape, other.tape as *const Tape); 103 | Var { 104 | tape: self.tape, 105 | value: self.value + other.value, 106 | index: self.tape.push2( 107 | self.index, 1.0, 108 | other.index, 1.0, 109 | ), 110 | } 111 | } 112 | } 113 | 114 | impl<'t> ::std::ops::Mul for Var<'t> { 115 | type Output = Var<'t>; 116 | fn mul(self, other: Var<'t>) -> Self::Output { 117 | assert_eq!(self.tape as *const Tape, other.tape as *const Tape); 118 | Var { 119 | tape: self.tape, 120 | value: self.value * other.value, 121 | index: self.tape.push2( 122 | self.index, other.value, 123 | other.index, self.value, 124 | ), 125 | } 126 | } 127 | } 128 | 129 | pub struct Grad { derivs: Vec<f64> } 130 | 131 | impl Grad { 132 | pub fn wrt<'t>(&self, var: Var<'t>) -> f64 { 133 | self.derivs[var.index] 134 | } 135 | } 136 | 137 | #[cfg(test)] 138 | mod tests { 139 | use super::Tape; 140 | 141 | #[test] 142 | fn x_times_y_plus_sin_x() { 143 | let t = Tape::new(); 144 | let x = t.var(0.5); 145 | let y = t.var(4.2); 146 | let z = x * y + x.sin(); 147 | let grad = z.grad(); 148 | assert!((z.value - 2.579425538604203).abs() <= 1e-15); 149 | assert!((grad.wrt(x) - (y.value + x.value.cos())).abs() <= 1e-15); 150 | assert!((grad.wrt(y) - x.value).abs() <= 1e-15); 151 | } 152 | } 153 | --------------------------------------------------------------------------------
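The CTZ eviction strategy that `src/chain.rs` implements in `ctz_extend` (and that `play/print-forget.rs` visualizes) can be sketched in standalone Python — in the spirit of `revad.py`, which ports the tape to Python. This is an illustrative sketch, not part of the crate; the `ctz` helper is defined here because it is not a Python builtin. It also checks the `ceil(log2(n)) + 1` retention count stated in the `ctz_extend` doc comment.

~~~python
import math

def ctz(n):
    # count trailing zeros of a positive integer
    return (n & -n).bit_length() - 1

def ctz_extend(cache, i0, items):
    # Lossily extend `cache` with (index, item) pairs, evicting entries
    # far from the current one, mirroring ctz_extend in src/chain.rs.
    ruler = 1
    ruler_max = -2
    start = len(cache)
    for i, x in enumerate(items):
        cache.append((i0 + i, x))
        j = ruler_max - ctz(ruler)
        if j <= 0:
            # reset the ruler; the retained count grows by one
            ruler = 1
            ruler_max += 1
        else:
            # evict an older entry; the retained count stays the same
            ruler += 1
            del cache[start + j]
    return cache

cache = ctz_extend([], 0, range(100))
# for n new items, exactly ceil(log2(n)) + 1 survive
assert len(cache) == math.ceil(math.log2(100)) + 1  # 8 entries
# the latest item is never evicted
assert cache[-1][0] == 99
~~~

Running this for n = 100 leaves 8 cached entries whose indices cluster near the end, matching the log2-axis-ticks picture that `print-forget.rs` draws.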