├── .merlin ├── pics ├── function.gif └── derivative.gif ├── dual_number.mli ├── test.ml ├── dual_number.ml ├── auto_differentiate.mli ├── auto_differentiate.ml ├── .gitignore └── README.md /.merlin: -------------------------------------------------------------------------------- 1 | PKG core_kernel 2 | 3 | B _build -------------------------------------------------------------------------------- /pics/function.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaybosamiya/automatic-differentiation/HEAD/pics/function.gif -------------------------------------------------------------------------------- /pics/derivative.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaybosamiya/automatic-differentiation/HEAD/pics/derivative.gif -------------------------------------------------------------------------------- /dual_number.mli: -------------------------------------------------------------------------------- 1 | (** Dual numbers: values made of a real part and a non-real part, both floats. *) module Dual : sig 2 | 3 | (** Abstract dual number; build one with [from_floats]. *) type t 4 | 5 | (** [from_floats a b] is the dual number with real part [a] and non-real part [b]. *) val from_floats : float -> float -> t 6 | 7 | (** [real z] is the real part of [z]. *) val real : t -> float 8 | (** [non_real z] is the non-real part of [z]. *) val non_real : t -> float 9 | 10 | (** Operators on dual numbers; opening this module shadows the standard operators of the same names. *) module Infix : sig 11 | (** Negates the non-real part and keeps the real part. *) val ( ! ) : t -> t 12 | (** Componentwise sum. *) val ( + ) : t -> t -> t 13 | (** Componentwise difference. *) val ( - ) : t -> t -> t 14 | (** Product: real parts are multiplied; the non-real part is the sum of the two cross terms. *) val ( * ) : t -> t -> t 15 | (** Quotient: real parts are divided; the non-real part follows the quotient rule, so a divisor with real part 0 gives non-finite results. *) val ( / ) : t -> t -> t 16 | end 17 | end 18 | -------------------------------------------------------------------------------- /test.ml: -------------------------------------------------------------------------------- 1 | open Core_kernel.Std 2 | open Auto_differentiate 3 | (* One definition of a piecewise function is handed to D.val_deriv, which returns the value function f together with the derivative function f'. *) 4 | let f, f' = D.val_deriv (fun x -> 5 | let open D.Operators in (* from here on the arithmetic and comparison operators are the D.Operators versions; ~$ lifts a float constant *) 6 | if x < ~$5. 7 | then ~$2. * x ** 2. - ~$3. * x + ~$5. 8 | else ~$5. / x + ~$39.) 
9 |
10 | (* Print the value and the derivative at each integer from 0 to 9. *)
11 | let () =
12 |   List.iter (List.range 0 10) ~f:(fun i ->
13 |       let x = Float.of_int i in
14 |       printf "f(%f) = %f\n" x (f x);
15 |       printf "f'(%f) = %f\n\n" x (f' x))
16 |
--------------------------------------------------------------------------------
/dual_number.ml:
--------------------------------------------------------------------------------
1 | module Dual = struct
2 |
3 |   (* A dual number is stored as the pair (real part, non-real part). *)
4 |   type t = float * float
5 |
6 |   (* Pair the two components up. *)
7 |   let from_floats re du : t = (re, du)
8 |
9 |   (* First component of the pair. *)
10 |   let real (z : t) = fst z
11 |
12 |   (* Second component of the pair. *)
13 |   let non_real (z : t) = snd z
14 |
15 |   module Infix = struct
16 |
17 |     (* Flip the sign of the non-real part; the real part is unchanged. *)
18 |     let ( ! ) (z : t) : t =
19 |       let re, du = z in
20 |       (re, ~-. du)
21 |
22 |     (* Componentwise sum. *)
23 |     let ( + ) (x : t) (y : t) : t =
24 |       (fst x +. fst y, snd x +. snd y)
25 |
26 |     (* Componentwise difference. *)
27 |     let ( - ) (x : t) (y : t) : t =
28 |       (fst x -. fst y, snd x -. snd y)
29 |
30 |     (* Product: the real parts multiply and the non-real part is the sum
31 |        of the two cross terms (there is no term in both non-real parts). *)
32 |     let ( * ) (x : t) (y : t) : t =
33 |       let xr, xd = x and yr, yd = y in
34 |       (xr *. yr, xr *. yd +. xd *. yr)
35 |
36 |     (* Quotient: the real parts divide and the non-real part follows the
37 |        quotient rule, dividing by the square of the divisor real part. *)
38 |     let ( / ) (x : t) (y : t) : t =
39 |       let xr, xd = x and yr, yd = y in
40 |       let numer = xd *. yr -. xr *. yd in
41 |       (xr /. yr, numer /. (yr *. yr))
42 |   end
43 |
44 | end
45 |
--------------------------------------------------------------------------------
/auto_differentiate.mli:
--------------------------------------------------------------------------------
1 | (** Forward-mode differentiation of float functions written over [D.t]. *)
2 | module D : sig
3 |   (** Abstract number type that functions to be differentiated are written over. *)
4 |   type t
5 |
6 |   (** [value f] is the plain float function computed by [f]. *)
7 |   val value : (t -> t) -> (float -> float)
8 |   (** [derivative f] is the first derivative of the function computed by [f]. *)
9 |   val derivative : (t -> t) -> (float -> float)
10 |
11 |   (** [val_deriv f] is the pair [(value f, derivative f)]. *)
12 |   val val_deriv : (t -> t) -> (float -> float) * (float -> float)
13 |
14 |   (** Operators meant to be opened locally inside the function being
15 |       differentiated; they shadow the standard operators of the same names. *)
16 |   module Operators : sig
17 |     (** Lifts a float constant into [t]. *)
18 |     val ( ~$ ) : float -> t
19 |     val ( + ) : t -> t -> t
20 |     val ( - ) : t -> t -> t
21 |     val ( * ) : t -> t -> t
22 |     val ( / ) : t -> t -> t
23 |     (** The six comparisons below look at the real parts only. *)
24 |     val ( < ) : t -> t -> bool
25 |     val ( = ) : t -> t -> bool
26 |     val ( > ) : t -> t -> bool
27 |     val ( <= ) : t -> t -> bool
28 |     val ( >= ) : t -> t -> bool
29 |     val ( <> ) : t -> t -> bool
30 |     (** Square root, carrying the derivative along. *)
31 |     val sqrt : t -> t
32 |     (** [x ** p] raises [x] to the constant float power [p]. *)
33 |     val ( ** ) : t -> float -> t
34 |   end
35 | end
36 |
--------------------------------------------------------------------------------
/auto_differentiate.ml:
--------------------------------------------------------------------------------
1 | open
Dual_number
2 |
3 | module D = struct
4 |
5 |   type t = Dual.t
6 |
7 |   (* The independent variable at [x]: its non-real part is seeded with 1. *)
8 |   let variable x = Dual.from_floats x 1.
9 |
10 |   (* A constant [x]: its non-real part is 0. *)
11 |   let constant x = Dual.from_floats x 0.
12 |
13 |   (* [value f] maps a float [x] to the real part of [f (variable x)]. *)
14 |   let value f = fun x ->
15 |     variable x |> f |> Dual.real
16 |
17 |   (* [derivative f] maps a float [x] to the non-real part of [f (variable x)]. *)
18 |   let derivative f = fun x ->
19 |     variable x |> f |> Dual.non_real
20 |
21 |   (* Both of the above as a pair; each returned function runs [f] itself. *)
22 |   let val_deriv f =
23 |     value f, derivative f
24 |
25 |   (* True for +0. and -0. only. Written with [classify_float] because the
26 |      float [( = )] is shadowed inside [Operators] below. Not exported. *)
27 |   let is_zero x =
28 |     match classify_float x with
29 |     | FP_zero -> true
30 |     | _ -> false
31 |
32 |   module Operators = struct
33 |     include Dual.Infix
34 |     let ( ~$ ) = constant
35 |
36 |     (* Comparisons look at the real parts only. Each [let] is non-recursive,
37 |        so the operator on its right-hand side is still the standard one. *)
38 |     let ( < ) a b =
39 |       Dual.real a < Dual.real b
40 |
41 |     let ( = ) a b =
42 |       Dual.real a = Dual.real b
43 |
44 |     let ( > ) a b =
45 |       Dual.real a > Dual.real b
46 |
47 |     (* These two are built from the dual [( = )], [( < )] and [( > )] above. *)
48 |     let ( <= ) a b =
49 |       a = b || a < b
50 |
51 |     let ( >= ) a b =
52 |       a = b || a > b
53 |
54 |     let ( <> ) a b =
55 |       Dual.real a <> Dual.real b
56 |
57 |     (* Square root: value [sqrt re], derivative [0.5 * du / sqrt re].
58 |        A zero tangent (a constant) yields derivative 0 directly; otherwise
59 |        a constant 0 would give 0. /. 0. = nan. *)
60 |     let sqrt a =
61 |       let root = sqrt (Dual.real a) in
62 |       let da = Dual.non_real a in
63 |       let tangent = if is_zero da then 0. else 0.5 *. da /. root in
64 |       Dual.from_floats root tangent
65 |
66 |     (* Power with a constant float exponent [b]: value [re ** b],
67 |        derivative [b * du * re ** (b - 1)].
68 |        When [b] is 0 or the tangent is 0 the derivative is 0; computing it
69 |        from the formula at [re = 0] would multiply 0 by infinity and give nan. *)
70 |     let ( ** ) a b =
71 |       let re = Dual.real a in
72 |       let da = Dual.non_real a in
73 |       let tangent =
74 |         if is_zero b || is_zero da then 0.
75 |         else b *. da *. (re ** (b -. 1.))
76 |       in
77 |       Dual.from_floats (re ** b) tangent
78 |
79 |   end
80 |
81 | end
82 |
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/emacs,linux,ocaml 3 | 4 | ### Emacs ### 5 | # -*- mode: gitignore; -*- 6 | *~ 7 | \#*\# 8 | /.emacs.desktop 9 | /.emacs.desktop.lock 10 | *.elc 11 | auto-save-list 12 | tramp 13 | .\#* 14 | 15 | # Org-mode 16 | .org-id-locations 17 | *_archive 18 | 19 | # flymake-mode 20 | *_flymake.* 21 | 22 | # eshell files 23 | /eshell/history 24 | /eshell/lastdir 25 | 26 | # elpa packages 27 | /elpa/ 28 | 29 | # reftex files 30 | *.rel 31 | 32 | # AUCTeX auto folder 33 | /auto/ 34 | 35 | # cask packages 36 | .cask/ 37 | dist/ 38 | 39 | # Flycheck 40 | flycheck_*.el 41 | 42 | # server auth directory 43 | /server/ 44 | 45 | # projectiles files 46 | .projectile 47 | 48 | ### Linux ### 49 | *~ 50 | 51 | # temporary files which can be created if a process still has a handle open of a deleted file 52 | .fuse_hidden* 53 | 54 | # KDE directory preferences 55 | .directory 56 | 57 | # Linux trash folder which might appear on any partition or disk 58 | .Trash-* 59 | 60 | 61 | ### OCaml ### 62 | *.annot 63 | *.cmo 64 | *.cma 65 | *.cmi 66 | *.a 67 | *.o 68 | *.cmx 69 | *.cmxs 70 | *.cmxa 71 | 72 | # ocamlbuild working directory 73 | _build/ 74 | 75 | # ocamlbuild targets 76 | *.byte 77 | *.native 78 | 79 | # oasis generated files 80 | setup.data 81 | setup.log 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Differentiation 2 | 3 | > Find first derivatives of functions _automagically_ 4 | 5 | ## How does it work? 6 | 7 | Through the magic of [dual numbers](https://en.wikipedia.org/wiki/Dual_number), we can calculate the derivative of a function while calculating its value. 
8 | 9 | By shadowing OCaml's arithmetic and comparison operators inside a 10 | locally opened module (OCaml has no operator overloading), we can abstract away the implementation of these dual numbers and make it almost seamless in 11 | operation. 12 | 13 | ## Usage 14 | 15 | After `open`ing `Auto_differentiate`, one can easily use the 16 | `D.val_deriv` function along with `D.Operators` to construct a 17 | function and its derivative. 18 | 19 | For example, consider the following piecewise function: 20 | 21 | ![](pics/function.gif) 22 | 23 | A plain floating-point implementation of this would be: 24 | 25 | ```OCaml 26 | let f x = 27 | if x <. 5. 28 | then 2. *. x ** 2. -. 3. *. x +. 5. 29 | else 5. /. x +. 39. 30 | ``` 31 | 32 | To calculate the derivative, along with the function's value itself, 33 | we only need to update the function as follows: 34 | 35 | ```OCaml 36 | let f, f' = D.val_deriv (fun x -> 37 | let open D.Operators in 38 | if x < ~$5. 39 | then ~$2. * x ** 2. - ~$3. * x + ~$5. 40 | else ~$5. / x + ~$39.) 41 | ``` 42 | 43 | Now, `f` calculates the value, while `f'` calculates the derivative. 44 | 45 | The values produced by this derivative `f'` match [WolframAlpha's output](http://www.wolframalpha.com/input/?i=differentiate+piecewise%5B%7B%7B2*x%5E2-3x%2B5,+x+%3C+5%7D,%7B5%2Fx%2B39,+x+%3E%3D+5%7D%7D%5D) when checked numerically: 46 | 47 | ![](pics/derivative.gif) 48 | 49 | The results of the numerical computation can be checked using 50 | 51 | ``` 52 | corebuild test.native 53 | ./test.native 54 | ``` 55 | 56 | Note that the *indeterminate* case is not handled specially: at `x = 5` the two branches meet with different slopes, so the derivative is undefined there, and 57 | instead the value returned is the derivative of the `x >= 5` branch. 58 | 59 | ## License 60 | 61 | [MIT License](https://jay.mit-license.org/2017) 62 | 63 | ## Acknowledgements 64 | 65 | Thanks to Demofox for introducing me to Dual Numbers and Automatic Differentiation in their [blog post](http://blog.demofox.org/2014/12/30/dual-numbers-automatic-differentiation/). 66 | --------------------------------------------------------------------------------