├── .gitignore ├── .travis.yml ├── LICENSE ├── Manifest.toml ├── Project.toml ├── README.md ├── src ├── backandforth.jl ├── forward.jl ├── intro.jl ├── notebooks.jl ├── reverse.jl ├── tracing.jl └── utils.jl └── test └── runtests.jl /.gitignore: -------------------------------------------------------------------------------- 1 | notebooks 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | ## Documentation: http://docs.travis-ci.com/user/languages/julia/ 2 | language: julia 3 | dist: bionic 4 | os: 5 | - linux 6 | julia: 7 | - 1.3 8 | - nightly 9 | notifications: 10 | email: false 11 | git: 12 | depth: 99999999 13 | before_script: 14 | - touch src/diff-zoo.jl # so that we can use the standard script 15 | jobs: 16 | allow_failures: 17 | - julia: nightly 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mike J Innes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Manifest.toml: -------------------------------------------------------------------------------- 1 | # This file is machine-generated - editing it directly is not advised 2 | 3 | [[Base64]] 4 | uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" 5 | 6 | [[DataStructures]] 7 | deps = ["InteractiveUtils", "OrderedCollections"] 8 | git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10" 9 | uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" 10 | version = "0.17.6" 11 | 12 | [[Dates]] 13 | deps = ["Printf"] 14 | uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" 15 | 16 | [[Distributed]] 17 | deps = ["Random", "Serialization", "Sockets"] 18 | uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" 19 | 20 | [[InteractiveUtils]] 21 | deps = ["Markdown"] 22 | uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" 23 | 24 | [[JSON]] 25 | deps = ["Dates", "Mmap", "Parsers", "Unicode"] 26 | git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" 27 | uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" 28 | version = "0.21.0" 29 | 30 | [[LibGit2]] 31 | uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" 32 | 33 | [[Libdl]] 34 | uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" 35 | 36 | [[Literate]] 37 | deps = ["Base64", "JSON", "REPL"] 38 | git-tree-sha1 = "463a0fe61a863fe1098f45a80eade3ed04f6586e" 39 | uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" 40 | version = "2.2.0" 41 | 42 | [[Logging]] 43 | uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" 44 | 45 | [[MacroTools]] 46 | deps = ["DataStructures", "Markdown", "Random"] 47 | git-tree-sha1 = "e2fc7a55bb2224e203bbd8b59f72b91323233458" 48 | repo-rev = "master" 49 | repo-url = "https://github.com/MikeInnes/MacroTools.jl.git" 50 | uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" 51 | version = "0.5.3" 52 | 53 | [[Markdown]] 54 | deps = ["Base64"] 55 | uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" 56 | 57 | [[Mmap]] 58 | uuid = "a63ad114-7e13-5084-954f-fe012c677804" 59 | 60 | [[OpenSpecFun_jll]] 61 | deps = ["Libdl", "Pkg"] 62 | git-tree-sha1 = "65f672edebf3f4e613ddf37db9dcbd7a407e5e90" 63 | uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" 64 | version = "0.5.3+1" 65 | 66 | [[OrderedCollections]] 67 | deps = ["Random", "Serialization", "Test"] 68 | git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" 69 | uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" 70 | version = "1.1.0" 71 | 72 | [[Parsers]] 73 | deps = ["Dates", "Test"] 74 | git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556" 75 | uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" 76 | version = "0.3.10" 77 | 78 | [[Pkg]] 79 | deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] 80 | uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" 81 | 82 | [[Printf]] 83 | deps = ["Unicode"] 84 | uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" 85 | 86 | [[REPL]] 87 | deps = ["InteractiveUtils", "Markdown", "Sockets"] 88 | uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" 89 | 90 | [[Random]] 91 | deps = ["Serialization"] 92 | uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 93 | 94 | [[SHA]] 95 | uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" 96 | 97 | [[Serialization]] 98 | uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" 99 | 100 | [[Sockets]] 101 | uuid = "6462fe0b-24de-5631-8697-dd941f90decc" 102 | 103 | [[SpecialFunctions]] 104 | deps = ["OpenSpecFun_jll"] 105 | git-tree-sha1 = "268052ee908b2c086cc0011f528694f02f3e2408" 106 | uuid = "276daf66-3868-5448-9aa4-cd146d93841b" 107 | version = "0.9.0" 108 | 109 | [[Test]] 110 | deps = 
["Distributed", "InteractiveUtils", "Logging", "Random"] 111 | uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 112 | 113 | [[UUIDs]] 114 | deps = ["Random", "SHA"] 115 | uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" 116 | 117 | [[Unicode]] 118 | uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" 119 | -------------------------------------------------------------------------------- /Project.toml: -------------------------------------------------------------------------------- 1 | name = "diff-zoo" 2 | uuid = "0e990500-197d-40a9-9393-fcd66887cc96" 3 | authors = ["Mike J Innes "] 4 | 5 | [deps] 6 | InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" 7 | Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" 8 | MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" 9 | Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" 10 | SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Differentiation for Hackers 2 | 3 | [![Build Status](https://travis-ci.org/MikeInnes/diff-zoo.svg?branch=master)](https://travis-ci.org/MikeInnes/diff-zoo) 4 | 5 | The goal of this handbook is to demystify *algorithmic differentiation*, the 6 | tool that underlies modern machine learning. It begins with a calculus-101 style 7 | understanding and gradually extends this to build toy implementations of systems 8 | similar to PyTorch and TensorFlow. I have tried to clarify the relationships 9 | between every kind of differentiation I can think of – including forward and 10 | reverse, symbolic, numeric, tracing and source transformation. Where typical real-word ADs are mired in implementation details, these implementations are designed to be coherent enough that the real, fundamental differences – of which there are surprisingly few – become obvious. 11 | 12 | The intro notebook is recommended to start with, but otherwise notebooks do not have a fixed order. 13 | 14 | * [Intro](https://github.com/MikeInnes/diff-zoo/blob/notebooks/intro.ipynb) – explains the basics, beginning with a simple symbolic differentiation routine. 15 | * [Back & Forth](https://github.com/MikeInnes/diff-zoo/blob/notebooks/backandforth.ipynb) – discusses the difference between forward and reverse mode AD. 16 | * [Forward](https://github.com/MikeInnes/diff-zoo/blob/notebooks/forward.ipynb) – discusses forward-mode AD and its relationship to symbolic and numerical differentiation. 17 | * [Tracing](https://github.com/MikeInnes/diff-zoo/blob/notebooks/tracing.ipynb) – discusses tracing-based implementations of reverse mode, as used by TensorFlow and PyTorch. 18 | * [Reverse](https://github.com/MikeInnes/diff-zoo/blob/notebooks/reverse.ipynb) – discusses a more powerful reverse mode based on source transformation (not complete). 19 | 20 | If you want to run the notebooks locally, they can be built by running the 21 | `src/notebooks.jl` script using Julia. They should appear inside a `/notebooks` 22 | folder. Alternatively, you can run through the scripts in Juno. 23 | -------------------------------------------------------------------------------- /src/backandforth.jl: -------------------------------------------------------------------------------- 1 | # Forward- and Reverse-Mode Differentiation 2 | # ========================================= 3 | 4 | include("utils.jl"); 5 | 6 | # Differentiation tools are frequently described as implementing "forward mode" 7 | # or "reverse mode" AD. 
This distinction was briefly covered in the intro
8 | # notebook, but here we'll go into more detail. We'll start with an intuition
9 | # for what the distinction *means* in terms of the differentiation process; then
10 | # we'll discuss why it's an important consideration in practice.
11 | 
12 | # Consider a simple mathematical expression:
13 | 
14 | y = :(sin(x^2) * 5)
15 | 
16 | # Written as a Wengert list:
17 | 
18 | Wengert(y)
19 | 
20 | # The ability to take derivatives mechanically relies on two things: Firstly, we
21 | # know derivatives for each basic function in our program (e.g.
22 | # $\frac{dy_2}{dy_1}=cos(y_1)$). Secondly, we have a rule of composition called the
23 | # *chain rule* which lets us compose these basic derivatives together.
24 | #
25 | # $$
26 | # \frac{dy}{dx} = \frac{dy_1}{dx} \times \frac{dy_2}{dy_1} \times \frac{dy}{dy_2}
27 | # $$
28 | #
29 | # More specifically:
30 | #
31 | # $$
32 | # \begin{align}
33 | # \frac{dy}{dx} &= 2x \times cos(y_1) \times 5 \\
34 | # &= 2x \times cos(x^2) \times 5
35 | # \end{align}
36 | # $$
37 | #
38 | # The forward/reverse distinction basically amounts to: given that we do
39 | # multiplications one at a time, do we evaluate $\frac{dy_1}{dx} \times
40 | # \frac{dy_2}{dy_1}$ first, or $\frac{dy_2}{dy_1} \times \frac{dy}{dy_2}$? (This seems
41 | # like a pointless question right now, given that either gets us the same
42 | # results, but bear with me.)
43 | #
44 | # It's easier to see the distinction if we think algorithmically. Given some
45 | # enormous Wengert list with $n$ instructions, we have two ways to differentiate
46 | # it:
47 | #
48 | # **(1)**: Start with the known quantity $\frac{dy_0}{dx} = \frac{dx}{dx} = 1$
49 | # at the beginning of the list. Look up the derivative for the next instruction
50 | # $\frac{dy_{i+1}}{dy_i}$ and multiply out the top, getting $\frac{dy_1}{dx}$,
51 | # $\frac{dy_2}{dx}$, ... $\frac{dy_{n-1}}{dx}$, $\frac{dy}{dx}$. Because we
52 | # walked forward over the Wengert list, this is called *forward mode*. Each
53 | # intermediate derivative $\frac{dy_i}{dx}$ is known as a *perturbation*.
54 | #
55 | # **(2)**: Start with the known quantity $\frac{dy}{dy_n} = \frac{dy}{dy} = 1$
56 | # at the end of the list. Look up the derivative for the previous instruction
57 | # $\frac{dy_i}{dy_{i-1}}$ and multiply out the bottom, getting
58 | # $\frac{dy}{dy_n}$, $\frac{dy}{dy_{n-1}}$, ... $\frac{dy}{dy_1}$,
59 | # $\frac{dy}{dx}$. Because we walked in reverse over the Wengert list, this is
60 | # called *reverse mode*. Each intermediate derivative $\frac{dy}{dy_i}$ is known
61 | # as a *sensitivity*.
62 | #
63 | # This all seems very academic, so we need to explain why it might make a
64 | # difference to performance. I'll give two related explanations: dealing with
65 | # multiple variables, and working with vectors rather than scalars.
66 | 
67 | # Explanation 1: Multiple Variables
68 | # ---------------------------------
69 | #
70 | # So far we have dealt only with simple functions that take a number, and return
71 | # a number. But more generally we'll deal with functions that take, or produce,
72 | # multiple numbers of interest.
73 | #
74 | # For example, what if we have a function that returns *two* numbers, and we want
75 | # derivatives for both? Do we have to do the whole thing twice over?
76 | 
77 | y = quote
78 |   y2 = sin(x^2)
79 |   y3 = y2 * 5
80 | end
81 | 
82 | # Let's say we want both of the derivatives $\frac{dy_2}{dx}$ and
83 | # $\frac{dy_3}{dx}$. You can probably see where this is going now; the Wengert
84 | # list representation of this expression has not actually changed!
85 | 
86 | Wengert(y)
87 | 
88 | # Now, we discussed that when doing forward mode differentiation, we actually
89 | # calculate *every* intermediate derivative $\frac{dy_i}{dx}$, which means we get
90 | # $\frac{dy_2}{dx}$ for free. This property goes all the way back to our
91 | # original, recursive formulation of differentiation, which calculated the
92 | # derivatives of a complex expression by combining the derivatives of simpler
93 | # ones.
94 | 
95 | derive(Wengert(y), :x)
96 | 
97 | # In our output, $y_7 = \frac{dy_2}{dx}$ and $y_8 = \frac{dy_3}{dx}$.
98 | #
99 | # Let's consider the opposite situation, a function of two variables $a$ and
100 | # $b$, where we'd like to get $\frac{dy}{da}$ and $\frac{dy}{db}$.
101 | 
102 | y = :(sin(a) * b)
103 | #-
104 | Wengert(y)
105 | 
106 | # This one is a bit tougher. We can start the forward-mode differentiation
107 | # process with $\frac{da}{da} = 1$ or with $\frac{db}{db} = 1$, but if we want
108 | # both we'll have to go through the entire multiplying-out process twice.
109 | #
110 | # But both variables ultimately end up at the same place, $y$, and we know that
111 | # $\frac{dy}{dy} = 1$. Aha, so perhaps we can use reverse mode for this
112 | # instead!
113 | #
114 | # Exactly opposite to forward mode, reverse mode gives us every intermediate
115 | # sensitivity $\frac{dy}{dy_i}$ for free, ultimately leading back to the inputs
116 | # $\frac{dy}{da}$ and $\frac{dy}{db}$.
117 | #
118 | # It's easy to see, then, why reverse-mode differentiation – or backpropagation
119 | # – is so effective for machine learning. In general we have a large computation
120 | # with millions of parameters, yet only a single scalar loss to optimise. We can
121 | # get gradients even for these millions of inputs in a single pass, enabling ML
122 | # to scale to complex tasks like image and voice recognition.
123 | 
124 | # Explanation 2: Vector Calculus
125 | # ------------------------------
126 | #
127 | # So far we have dealt only with simple functions that take a number, and return
128 | # a number. But more generally we'll deal with functions that take, or produce,
129 | # *vectors* containing multiple numbers of interest.
130 | #
131 | # It's useful to consider how our idea of differentiation works when we have
132 | # vectors. For example, a function that takes a vector of length $2$ to another
133 | # vector of length $2$:
134 | 
135 | f(x) = [x[1] * x[2], cos(x[1])]
136 | 
137 | x = [2, 3]
138 | y = f(x)
139 | 
140 | # We now need to talk about what we mean by $\frac{d}{dx}f(x)$, given that we
141 | # can't apply the usual limit rule. What we *can* do is take the derivative of
142 | # any scalar *element* of $y$ with respect to any element of $x$. (We'll use
143 | # subscripts $x_n$ to refer to the $n^{th}$ index of $x$.) For example:
144 | #
145 | # $$
146 | # \begin{align}
147 | # \frac{dy_1}{dx_1} &= \frac{d}{dx_1} x_1 \times x_2 = x_2 \\
148 | # \frac{dy_1}{dx_2} &= \frac{d}{dx_2} x_1 \times x_2 = x_1 \\
149 | # \frac{dy_2}{dx_1} &= \frac{d}{dx_1} \cos(x_1) = -\sin(x_1) \\
150 | # \frac{dy_2}{dx_2} &= \frac{d}{dx_2} \cos(x_1) = 0 \\
151 | # \end{align}
152 | # $$
153 | #
154 | # It's a little easier if we organise all of these derivatives into a matrix.
155 | # 156 | # $$ 157 | # J_{ij} = \frac{dy_i}{dx_j} 158 | # $$ 159 | # 160 | # This $2\times2$ matrix is called the *Jacobian*, and in general it's what we mean by 161 | # $\frac{dy}{dx}$. (The Jacobian for a scalar function like $y = \sin(x)$ only 162 | # has one element, so it's consistent with our current idea of the derivative 163 | # $\frac{dy}{dx}$.) The key point here is that the Jacobian is a potentially 164 | # large object: it has a size `length(y) * length(x)`. Now, we discussed that 165 | # the distinction between forward and reverse mode is whether we propagate 166 | # $\frac{dy_i}{dx}$ or $\frac{dy}{dy_i}$, which can have a size of either 167 | # `length(y_i) * length(x)` or `length(y) * length(y_i)`. 168 | # 169 | # It should be clear, then, what mode is better if we have a gazillion inputs 170 | # and one output. In forward mode we need to carry around a gazillion 171 | # "perturbations" *for each* element of $y_i$, whereas in reverse we only need a 172 | # gradient of the same size of $x$. And vice versa. 173 | -------------------------------------------------------------------------------- /src/forward.jl: -------------------------------------------------------------------------------- 1 | # Implementing Forward Mode 2 | # ========================= 3 | # 4 | # In the [intro notebook](./intro.ipynb) we covered forward-mode differentiation 5 | # thoroughly, but we don't yet have a real implementation that can work on our 6 | # programs. Implementing AD effectively and efficiently is a field of its own, 7 | # and we'll need to learn a few more tricks to get off the ground. 8 | 9 | include("utils.jl"); 10 | 11 | # Up to now, we have differentiated things by creating a new Wengert list which 12 | # contains parts of the original expression. 13 | 14 | y = Wengert(:(5sin(log(x)))) 15 | derive(y, :x) 16 | 17 | # We're now going to explicitly split our lists into two pieces: the original 18 | # expression, and a new one which only calculates derivatives (but might refer 19 | # back to values from the first). For example: 20 | 21 | y = Wengert(:(5sin(log(x)))) 22 | #- 23 | dy = derive(y, :x, out = Wengert(variable = :dy)) 24 | #- 25 | Expr(dy) 26 | 27 | # If we want to distinguish them, we can call `y` the *primal code* and `dy` 28 | # the *tangent code*. Nothing fundamental has changed here, but it's useful 29 | # to organise things this way. 30 | # 31 | # Almost all of the subtlety in differentiating programs comes from a 32 | # mathematically trivial question: in what order do we evaluate the statements 33 | # of the Wengert list? We have discussed the 34 | # [forward/reverse](./backandforth.ipynb) distinction, but even once that choice 35 | # is made, we have plenty of flexibility, and those choices can affect efficiency. 36 | # 37 | # For example, imagine if we straightforwardly evaluate `y` followed by `dy`. If 38 | # we only cared about the final output of `y`, this would be no problem at all, 39 | # but in general `dy` also needs to re-use variables like `y1` (or possibly any 40 | # $y_i$). If our primal Wengert list has, say, a billion instructions, we end up 41 | # having to store a billion intermediate $y_i$ before we run our tangent code. 42 | # 43 | # Alternatively, one can imagine running each instruction of the tangent code as 44 | # early as possible; as soon as we run `y1 = log(x)`, for example, we know we 45 | # can run `dy2 = cos(y1)` also. 
Then our final, combined program would look
46 | # something like this:
47 | 
48 | # ```julia
49 | # y0 = x
50 | # dy = 1
51 | # y1 = log(y0)
52 | # dy = dy/y0
53 | # y2 = sin(y1)
54 | # dy = dy*cos(y1)
55 | # ...
56 | # ```
57 | 
58 | # Now we can throw out `y1` soon after creating it, and we no longer have to
59 | # store those billion intermediate results.
60 | #
61 | # The ability to do this is a very general property of forward differentiation;
62 | # once we run $a = f(b)$, we can then run $\frac{da}{dx} = \frac{da}{db}
63 | # \frac{db}{dx}$ using only `a` and `b`. It's really just a case of replacing
64 | # basic instructions like `cos` with versions that calculate both the primal and
65 | # tangent at once.
66 | 
67 | # Dual Numbers
68 | # ------------
69 | #
70 | # Finally, the trick that we've been building up to: making our programming
71 | # language do this all for us! Almost all common languages – with the notable
72 | # exception of C – provide good support for *operator overloading*, which allows
73 | # us to do exactly this replacement.
74 | #
75 | # To start with, we'll make a container that holds both a $y$ and a
76 | # $\frac{dy}{dx}$, called a *dual number*.
77 | 
78 | struct Dual{T<:Real} <: Real
79 |   x::T
80 |   ϵ::T
81 | end
82 | 
83 | Dual(1, 2)
84 | #-
85 | Dual(1.0,2.0)
86 | 
87 | # Let's print it nicely.
88 | 
89 | Base.show(io::IO, d::Dual) = print(io, d.x, " + ", d.ϵ, "ϵ")
90 | 
91 | Dual(1, 2)
92 | 
93 | # And add some of our rules for differentiation. The rules have the same basic
94 | # pattern-matching structure as the ones we originally applied to our Wengert
95 | # list, just with different notation.
96 | 
97 | import Base: +, -, *, /
98 | a::Dual + b::Dual = Dual(a.x + b.x, a.ϵ + b.ϵ)
99 | a::Dual - b::Dual = Dual(a.x - b.x, a.ϵ - b.ϵ)
100 | a::Dual * b::Dual = Dual(a.x * b.x, b.x * a.ϵ + a.x * b.ϵ)
101 | a::Dual / b::Dual = Dual(a.x / b.x, (b.x * a.ϵ - a.x * b.ϵ) / b.x^2)
102 | 
103 | Base.sin(d::Dual) = Dual(sin(d.x), d.ϵ * cos(d.x))
104 | Base.cos(d::Dual) = Dual(cos(d.x), - d.ϵ * sin(d.x))
105 | 
106 | Dual(2, 2) * Dual(3, 4)
107 | 
108 | # Finally, we'll hook into Julia's number promotion system; this isn't essential
109 | # to understand, but just makes it easier to work with Duals since we can now
110 | # freely mix them with other number types.
111 | 
112 | Base.convert(::Type{Dual{T}}, x::Dual) where T = Dual(convert(T, x.x), convert(T, x.ϵ))
113 | Base.convert(::Type{Dual{T}}, x::Real) where T = Dual(convert(T, x), zero(T))
114 | Base.promote_rule(::Type{Dual{T}}, ::Type{R}) where {T,R} = Dual{promote_type(T,R)}
115 | 
116 | Dual(1, 2) * 3
117 | 
118 | # We already have enough to start taking derivatives of some simple functions.
119 | # If we pass a dual number into a function, the $\epsilon$ component represents
120 | # the derivative.
121 | 
122 | f(x) = x / (1 + x*x)
123 | 
124 | f(Dual(5., 1.))
125 | 
126 | # We can make a utility which allows us to differentiate any function.
127 | 
128 | D(f, x) = f(Dual(x, one(x))).ϵ
129 | 
130 | D(f, 5.)
131 | 
132 | # Dual numbers may seem only loosely related to all the Wengert list stuff we
133 | # were talking about earlier. But we need to take a closer look at how this is
134 | # working. To start with, look at Julia's internal representation of `f`.
135 | 
136 | @code_typed f(1.0)
137 | 
138 | # This is just a Wengert list! Though the naming is a little different – `mul_float`
139 | # rather than the more general `*` and so on – it's still essentially the same
140 | # data structure we were working with earlier.
Moreover, you'll recognise 141 | # the code for the derivative, too! 142 | 143 | @code_typed D(f, 1.0) 144 | 145 | # This code is again the same as the Wengert list derivative we worked out at 146 | # the very beginning of this handbook. The order of operations is just a little 147 | # different, and there's the odd missing or new instruction due to the different 148 | # set of optimisations that Julia applies. Still, we have not escaped our fundamental 149 | # symbolic differentiation algorithm, just tricked the compiler into doing most 150 | # of the work for us. 151 | 152 | derive(Wengert(:(sin(cos(x)))), :x) 153 | #- 154 | @code_typed D(x -> sin(cos(x)), 0.5) 155 | 156 | # What of data structures, control flow, function calls? Although these things 157 | # are all present in Julia's internal "Wengert list", they end up being the same 158 | # in the tangent program as in the primal; so an operator overloading approach 159 | # need not deal with them explicitly to do the right thing. This won't be true 160 | # when we come to talk more about reverse mode, which demands a more complex 161 | # approach. 162 | 163 | # Perturbation Confusion 164 | # ---------------------- 165 | # 166 | # Actually, that's not quite true. Operator-overloading-based forward mode 167 | # *almost always* does the right thing, but it is not flawless. This more 168 | # advanced section will talk about nested differentiation and the nasty bug that 169 | # can come with it. 170 | # 171 | # We can differentiate any function we want, as long as we have the right 172 | # primitive definitions for it. For example, the derivative of $\sin(x)$ is 173 | # $\cos(x)$. 174 | 175 | D(sin, 0.5), cos(0.5) 176 | 177 | # We can also differentiate the differentiation operator itself. We'll find that 178 | # the second derivative of $\sin(x)$ is $-\sin(x)$. 179 | 180 | D(x -> D(sin, x), 0.5), -sin(0.5) 181 | 182 | # This worked because we ended up nesting dual numbers. If we create a dual number 183 | # whose $\epsilon$ component is another dual number, then we end up tracking the 184 | # derivative of the derivative. 185 | 186 | # The issue comes about when we close over a variable that *is itself* being 187 | # differentiated. 188 | 189 | D(x -> x*D(y -> x+y, 1), 1) # == 1 190 | 191 | # The derivative $\frac{d}{dy} (x + y) = 1$, so this is equivalent to 192 | # $\frac{d}{dx}x$, which should also be $1$. So where did this go wrong? The 193 | # problem is that when we closed over $x$, we didn't just get a numeric value 194 | # but a dual number with $\epsilon = 1$. When we then calculated $x + y$, both 195 | # epsilons were added as if $\frac{dx}{dy} = 1$ (effectively $x = y$). If we had 196 | # written this down, the answer would be correct. 197 | 198 | D(x -> x*D(y -> y+y, 1), 1) 199 | 200 | # I leave this second example as an excercise to the reader. Needless to say, 201 | # this has caught out many an AD implementor. 202 | 203 | D(x -> x*D(y -> x*y, 1), 4) # == 8 204 | 205 | # More on Dual Numbers 206 | # -------------------- 207 | # 208 | # The above discussion presented dual numbers as essentially being a trick for 209 | # applying the chain rule. I wanted to take the opportunity to present an 210 | # alternative viewpoint, which might be appealing if, like me, you have any 211 | # training in physics. 212 | # 213 | # Complex arithmetic involves a new number, $i$, which behaves like no other: 214 | # specifically, because $i^2 = -1$. 
We'll introduce a number called $\epsilon$, 215 | # which is a bit like $i$ except that $\epsilon^2 = 0$; this is effectively a 216 | # way of saying the $\epsilon$ is a very small number. The relevance of this 217 | # comes from the original definition of differentiation, which also requires 218 | # $\epsilon$ to be very small. 219 | # 220 | # $$ 221 | # \frac{d}{dx} f(x) = \lim_{\epsilon \to 0} \frac{f(x+\epsilon)-f(x)}{\epsilon} 222 | # $$ 223 | # 224 | # We can see how our definition of $\epsilon$ works out by applying it to 225 | # $f(x+\epsilon)$; let's say that $f(x) = sin(x^2)$. 226 | # 227 | # \begin{align} 228 | # f(x + \epsilon) &= \sin((x + \epsilon)^2) \\ 229 | # &= \sin(x^2 + 2x\epsilon + \epsilon^2) \\ 230 | # &= \sin(x^2 + 2x\epsilon) \\ 231 | # &= \sin(x^2)\cos(2x\epsilon) + \cos(x^2)\sin(2x\epsilon) \\ 232 | # &= \sin(x^2) + 2x\cos(x^2)\epsilon \\ 233 | # \end{align} 234 | # 235 | # A few things have happened here. Firstly, we directly expand $(x+\epsilon)^2$ 236 | # and remove the $\epsilon^2$ term. We expand $sin(a+b)$ and then apply a *small 237 | # angle approximation*: for small $\theta$, $\sin(\theta) \approx \theta$ and 238 | # $\cos(\theta) \approx 1$. (This sounds pretty hand-wavy, but does follow from 239 | # our original definition of $\epsilon$ if we look at the Taylor expansion of 240 | # both functions). Finally we can plug this into our derivative rule. 241 | # 242 | # \begin{align} 243 | # \frac{d}{dx} f(x) &= \frac{f(x+\epsilon)-f(x)}{\epsilon} \\ 244 | # &= 2x\cos(x^2) 245 | # \end{align} 246 | # 247 | # This is, in my opinion, a rather nice way to derive functions by hand. 248 | # 249 | # This also leads to another nice trick, and a third way to look at forward-mode 250 | # AD; if we replace $x + \epsilon$ with $x + \epsilon i$ then we still have 251 | # $(\epsilon i)^2 = 0$. If $\epsilon$ is a small real number (say 252 | # $1\times10^{-10}$), this is still true within floating point error, so our 253 | # derivative still works out. 254 | 255 | ϵ = 1e-10im 256 | x = 0.5 257 | 258 | f(x) = sin(x^2) 259 | 260 | (f(x+ϵ) - f(x)) / ϵ 261 | #- 262 | 2x*cos(x^2) 263 | 264 | # So complex numbers can be used to get exact derivatives! This is very efficient 265 | # and can be written using only one call to `f`. 266 | 267 | imag(f(x+ϵ)) / imag(ϵ) 268 | 269 | # Another way of looking at this is that we are doing bog-standard numerical 270 | # differentiation, but the use of complex numbers avoids the typical problem 271 | # with that technique (i.e. that a small perturbation ends up being overwhelmed 272 | # by floating point error). The dual number is then a slight variation which 273 | # makes the limit $\epsilon \rightarrow 0$ exact, rather than approximate. 274 | # Forward mode AD can be described as "just" a clever implementation of 275 | # numerical differentiation. Both numerical and forward derivatives propagate 276 | # a perturbation $\epsilon$ using the same basic rules, and they have the 277 | # same algorithmic properties. 278 | -------------------------------------------------------------------------------- /src/intro.jl: -------------------------------------------------------------------------------- 1 | # Differentiation for Hackers 2 | # =========================== 3 | # 4 | # These notebooks are an exploration of various approaches to analytical 5 | # differentiation. 
Differentiation is something you learned in school; we start 6 | # with an expression like $y = 3x^2 + 2x + 1$ and find an expression for the 7 | # derivative like $\frac{dy}{dx} = 6x + 2$. Once we have such an expression, we 8 | # can *evaluate* it by plugging in a specific value for $x$ (say 0.5) to find 9 | # the derivative at that point (in this case $\frac{dy}{dx} = 5$). 10 | # 11 | # Despite its surface simplicity, this technique lies at the core of all modern 12 | # machine learning and deep learning, alongside many other parts of statistics, 13 | # mathematical optimisation and engineering. There has recently been an 14 | # explosion in automatic differentiation (AD) tools, all with different designs 15 | # and tradeoffs, and it can be difficult to understand how they relate to each 16 | # other. 17 | # 18 | # We aim to fix this by beginning with the "calculus 101" rules that you are 19 | # familiar with and implementing simple symbolic differentiators over mathematical 20 | # expressions. Then we show how tweaks to this basic framework generalise from 21 | # expressions to programming languages, leading us to modern automatic 22 | # differentiation tools and machine learning frameworks like TensorFlow and 23 | # PyTorch, and giving us a unified view across the AD landscape. 24 | 25 | # Symbolic Differentiation 26 | # ------------------------ 27 | # 28 | # To talk about derivatives, we need to talk about *expressions*, which are 29 | # symbolic forms like $x^2 + 1$ (as opposed to numbers like $5$). Normal Julia 30 | # programs only work with numbers; we can write down $x^2 + 1$ but this only 31 | # lets us calculate its value for a specific $x$. 32 | 33 | x = 2 34 | y = x^2 + 1 35 | 36 | # However, Julia also offers a *quotation operator* which lets us talk about the 37 | # expression itself, without needing to know what $x$ is. 38 | 39 | y = :(x^2 + 1) 40 | #- 41 | typeof(y) 42 | 43 | # Expressions are a tree data structure. They have a `head` which tells us what 44 | # kind of expression they are (say, a function call or if statement). They have 45 | # `args`, their children, which may be further sub-expressions. For example, 46 | # $x^2 + 1$ is a call to $+$, and one of its children is the expression $x^2$. 47 | 48 | y.head 49 | #- 50 | y.args 51 | 52 | # We could have built this expression by hand rather than using quotation. It's 53 | # just a bog-standard tree data structure that happens to have nice printing. 54 | 55 | x2 = Expr(:call, :^, :x, 2) 56 | #- 57 | y = Expr(:call, :+, x2, 1) 58 | 59 | # We can evaluate our expression to get a number out. 60 | 61 | eval(y) 62 | 63 | # When we differentiate something, we'll start by manipulating an expression 64 | # like this, and then we can optionally evaluate it with numbers to get a 65 | # numerical derivative. I'll call these the "symbolic phase" and the "numeric 66 | # phase" of differentiation, respectively. 67 | 68 | # How might we differentiate an expression like $x^2 + 1$? We can start by 69 | # looking at the basic rules in differential calculus. 
70 | # 71 | # $$ 72 | # \begin{align} 73 | # \frac{d}{dx} x &= 1 \\ 74 | # \frac{d}{dx} (-u) &= - \frac{du}{dx} \\ 75 | # \frac{d}{dx} (u + v) &= \frac{du}{dx} + \frac{dv}{dx} \\ 76 | # \frac{d}{dx} (u * v) &= v \frac{du}{dx} + u \frac{dv}{dx} \\ 77 | # \frac{d}{dx} (u / v) &= (v \frac{du}{dx} - u \frac{dv}{dx}) / v^2 \\ 78 | # \frac{d}{dx} u^n &= n u^{n-1} \\ 79 | # \end{align} 80 | # $$ 81 | # 82 | # Seeing $\frac{d}{dx}(u)$ as a function, these rules look a lot like a 83 | # recursive algorithm. To differentiate something like `y = a + b`, we 84 | # differentiate `a` and `b` and combine them together. To differentiate `a` we 85 | # do the same thing, and so on; eventually we'll hit something like `x` or `3` 86 | # which has a trivial derivative ($1$ or $0$). 87 | 88 | # Let's start by handling the obvious cases, $y = x$ and $y = 1$. 89 | 90 | function derive(ex, x) 91 | ex == x ? 1 : 92 | ex isa Union{Number,Symbol} ? 0 : 93 | error("$ex is not differentiable") 94 | end 95 | #- 96 | y = :(x) 97 | derive(y, :x) 98 | #- 99 | y = :(1) 100 | derive(y, :x) 101 | 102 | # We can look for expressions of the form `y = a + b` using pattern matching, 103 | # with a package called 104 | # [MacroTools](https://github.com/MikeInnes/MacroTools.jl). If `@capture` 105 | # returns true, then we can work with the sub-expressions `a` and `b`. 106 | 107 | using MacroTools 108 | 109 | y = :(x + 1) 110 | #- 111 | @capture(y, a_ * b_) 112 | #- 113 | @capture(y, a_ + b_) 114 | #- 115 | a, b 116 | 117 | # Let's use this to add a rule to `derive`, following the chain rule above. 118 | 119 | function derive(ex, x) 120 | ex == x ? 1 : 121 | ex isa Union{Number,Symbol} ? 0 : 122 | @capture(ex, a_ + b_) ? :($(derive(a, x)) + $(derive(b, x))) : 123 | error("$ex is not differentiable") 124 | end 125 | #- 126 | y = :(x + 1) 127 | derive(y, :x) 128 | #- 129 | y = :(x + (1 + (x + 1))) 130 | derive(y, :x) 131 | 132 | # These are the correct derivatives, even if they could be simplified a bit. We 133 | # can go on to add the rest of the rules similarly. 134 | 135 | function derive(ex, x) 136 | ex == x ? 1 : 137 | ex isa Union{Number,Symbol} ? 0 : 138 | @capture(ex, a_ + b_) ? :($(derive(a, x)) + $(derive(b, x))) : 139 | @capture(ex, a_ * b_) ? :($a * $(derive(b, x)) + $b * $(derive(a, x))) : 140 | @capture(ex, a_^n_Number) ? :($(derive(a, x)) * ($n * $a^$(n-1))) : 141 | @capture(ex, a_ / b_) ? :($b * $(derive(a, x)) - $a * $(derive(b, x)) / $b^2) : 142 | error("$ex is not differentiable") 143 | end 144 | 145 | # This is enough to get us a slightly more difficult derivative. 146 | 147 | y = :(3x^2 + (2x + 1)) 148 | dy = derive(y, :x) 149 | 150 | # This is correct – it's equivalent to $6x + 2$ – but it's also a bit noisy, with a 151 | # lot of redundant terms like $x + 0$. We can clean this up by creating some 152 | # smarter functions to do our symbolic addition and multiplication. They'll just 153 | # avoid actually doing anything if the input is redundant. 154 | 155 | addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b) 156 | mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b) 157 | mulm(a, b, c...) = mulm(mulm(a, b), c...) 158 | #- 159 | addm(:a, :b) 160 | #- 161 | addm(:a, 0) 162 | #- 163 | mulm(:b, 1) 164 | 165 | # Our tweaked `derive` function: 166 | 167 | function derive(ex, x) 168 | ex == x ? 1 : 169 | ex isa Union{Number,Symbol} ? 0 : 170 | @capture(ex, a_ + b_) ? addm(derive(a, x), derive(b, x)) : 171 | @capture(ex, a_ * b_) ? 
addm(mulm(a, derive(b, x)), mulm(b, derive(a, x))) : 172 | @capture(ex, a_^n_Number) ? mulm(derive(a, x),n,:($a^$(n-1))) : 173 | @capture(ex, a_ / b_) ? :($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x))) / $b^2) : 174 | error("$ex is not differentiable") 175 | end 176 | 177 | # And the output is much cleaner. 178 | 179 | y = :(3x^2 + (2x + 1)) 180 | dy = derive(y, :x) 181 | 182 | # Having done this, we can also calculate a nested derivative 183 | # $\frac{d^2y}{dx^2}$, and so on. 184 | 185 | ddy = derive(dy, :x) 186 | #- 187 | derive(ddy, :x) 188 | 189 | # There is a deeper problem with our differentiation algorithm, though. Look at 190 | # how big this derivative is. 191 | 192 | derive(:(x / (1 + x^2)), :x) 193 | 194 | # Adding an extra `* x` makes it even bigger! There's a bunch of redundant work 195 | # here, repeating the expression $1 + x^2$ three times over. 196 | 197 | derive(:(x / (1 + x^2) * x), :x) 198 | 199 | # This happens because our rules look like, 200 | # $\frac{d(u*v)}{dx} = u*\frac{dv}{dx} + v*\frac{du}{dx}$. 201 | # Every multiplication repeats the whole sub-expression and its derivative, 202 | # making the output exponentially large in the size of its input. 203 | # 204 | # This seems to be an achilles heel for our little differentiator, since it will 205 | # make it impractical to run on any realistically-sized program. But wait! 206 | # Things are not quite as simple as they seem, because this expression is not 207 | # *actually* as big as it looks. 208 | # 209 | # Imagine we write down: 210 | 211 | y1 = :(1 * 2) 212 | y2 = :($y1 + $y1 + $y1 + $y1) 213 | 214 | # This looks like a large expression, but in actual fact it does not contain 215 | # $1*2$ four times over, just four pointers to $y1$; it is not really a tree but 216 | # a graph that gets printed as a tree. We can show this by explicitly printing 217 | # the expression in a way that preserves structure. 218 | # 219 | # (The definition of `printstructure` is not important to understand, but is 220 | # here for reference.) 221 | 222 | printstructure(x, _, _) = x 223 | 224 | function printstructure(ex::Expr, cache = IdDict(), n = Ref(0)) 225 | haskey(cache, ex) && return cache[ex] 226 | args = map(x -> printstructure(x, cache, n), ex.args) 227 | cache[ex] = sym = Symbol(:y, n[] += 1) 228 | println(:($sym = $(Expr(ex.head, args...)))) 229 | return sym 230 | end 231 | 232 | printstructure(y2); 233 | 234 | # Note that this is *not* the same as running common subexpression elimination 235 | # to simplify the tree, which would have an $O(n^2)$ computational cost. If 236 | # there is real duplication in the expression, it'll show up. 237 | 238 | :(1*2 + 1*2) |> printstructure; 239 | 240 | # This is effectively a change in notation: we were previously using a kind of 241 | # "calculator notation" in which any computation used more than once had to be 242 | # repeated in full. Now we are allowed to use variable bindings to get the same 243 | # effect. 244 | 245 | # If we try `printstructure` on our differentiated code, we'll see that the 246 | # output is not so bad after all: 247 | 248 | :(x / (1 + x^2)) |> printstructure; 249 | #- 250 | derive(:(x / (1 + x^2)), :x) 251 | #- 252 | derive(:(x / (1 + x^2)), :x) |> printstructure; 253 | 254 | # The expression $x^2 + 1$ is now defined once and reused rather than being 255 | # repeated, and adding the extra `* x` now adds a couple of instructions to our 256 | # derivative, rather than doubling its size. 
It turns out that our "naive" 257 | # symbolic differentiator actually preserves structure in a very sensible way, 258 | # and we just needed the right program representation to exploit that. 259 | 260 | derive(:(x / (1 + x^2) * x), :x) 261 | #- 262 | derive(:(x / (1 + x^2) * x), :x) |> printstructure; 263 | 264 | # Calculator notation – expressions without variable bindings – is a terrible 265 | # format for anything, and will tend to blow up in size whether you 266 | # differentiate it or not. Symbolic differentiation is commonly criticised for 267 | # its susceptability to "expression swell", but in fact has nothing to do with 268 | # the differentiation algorithm itself, and we need not change it to get better 269 | # results. 270 | # 271 | # Conversely, the way we have used `Expr` objects to represent variable bindings 272 | # is perfectly sound, if a little unusual. This format could happily be used to 273 | # illustrate all of the concepts in this handbook, and the recursive algorithms 274 | # used to do so are elegant. However, it will clarify some things if they are 275 | # written a little more explicitly; for this we'll introduce a new, equivalent 276 | # representation for expressions. 277 | 278 | # The Wengert List 279 | # ---------------- 280 | # 281 | # The output of `printstructure` above is known as a "Wengert List", an explicit 282 | # list of instructions that's a bit like writing assembly code. Really, Wengert 283 | # lists are nothing more or less than mathematical expressions written out 284 | # verbosely, and we can easily convert to and from equivalent `Expr` objects. 285 | 286 | include("utils.jl"); 287 | #- 288 | y = :(3x^2 + (2x + 1)) 289 | #- 290 | wy = Wengert(y) 291 | #- 292 | Expr(wy) 293 | 294 | # Inside, we can see that it really is just a list of function calls, where 295 | # $y_n$ refers to the result of the $n^{th}$. 296 | 297 | wy.instructions 298 | 299 | # Like `Expr`s, we can also build Wengert lists by hand. 300 | 301 | w = Wengert() 302 | tmp = push!(w, :(x^2)) 303 | w 304 | #- 305 | push!(w, :($tmp + 1)) 306 | w 307 | 308 | # Armed with this, we can quite straightforwardly port our recursive symbolic 309 | # differentiation algorithm to the Wengert list. 310 | 311 | function derive(ex, x, w) 312 | ex isa Variable && (ex = w[ex]) 313 | ex == x ? 1 : 314 | ex isa Union{Number,Symbol} ? 0 : 315 | @capture(ex, a_ + b_) ? push!(w, addm(derive(a, x, w), derive(b, x, w))) : 316 | @capture(ex, a_ * b_) ? push!(w, addm(mulm(a, derive(b, x, w)), mulm(b, derive(a, x, w)))) : 317 | @capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w),n,:($a^$(n-1)))) : 318 | @capture(ex, a_ / b_) ? push!(w, :($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w))) / $b^2)) : 319 | error("$ex is not differentiable") 320 | end 321 | 322 | derive(w::Wengert, x) = (derive(w[end], x, w); w) 323 | 324 | # It behaves identically to what we wrote before; we have only changed the 325 | # underlying representation. 326 | 327 | derive(Wengert(:(3x^2 + (2x + 1))), :x) |> Expr 328 | 329 | # In fact, we can compare them directly using the `printstructure` function we 330 | # wrote earlier. 331 | 332 | derive(:(x / (1 + x^2)), :x) |> printstructure 333 | #- 334 | derive(Wengert(:(x / (1 + x^2))), :x) 335 | 336 | # They are *almost* identical; the only difference is the unused variable `y3` 337 | # in the Wengert version. This happens because our `Expr` format effectively 338 | # removes dead code for us automatically. 
We'll see the same thing happen if 339 | # we convert the Wengert list back into an `Expr`. 340 | 341 | derive(Wengert(:(x / (1 + x^2))), :x) |> Expr 342 | 343 | function derive(w::Wengert, x) 344 | ds = Dict() 345 | ds[x] = 1 346 | d(x) = get(ds, x, 0) 347 | for v in keys(w) 348 | ex = w[v] 349 | Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : 350 | @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : 351 | @capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) : 352 | @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) : 353 | error("$ex is not differentiable") 354 | ds[v] = push!(w, Δ) 355 | end 356 | return w 357 | end 358 | 359 | derive(Wengert(:(x / (1 + x^2))), :x) |> Expr 360 | 361 | # One more thing. The astute reader may notice that our differentiation 362 | # algorithm begins with $\frac{dx}{dx}=1$ and propagates this forward to the 363 | # output; in other words it does [forward-mode 364 | # differentiation](./backandforth.ipynb). We can tweak our code a little to do 365 | # reverse mode instead. 366 | 367 | function derive_r(w::Wengert, x) 368 | ds = Dict() 369 | d(x) = get(ds, x, 0) 370 | d(x, Δ) = ds[x] = haskey(ds, x) ? addm(ds[x],Δ) : Δ 371 | d(lastindex(w), 1) 372 | for v in reverse(collect(keys(w))) 373 | ex = w[v] 374 | Δ = d(v) 375 | if @capture(ex, a_ + b_) 376 | d(a, Δ) 377 | d(b, Δ) 378 | elseif @capture(ex, a_ * b_) 379 | d(a, push!(w, mulm(Δ, b))) 380 | d(b, push!(w, mulm(Δ, a))) 381 | elseif @capture(ex, a_^n_Number) 382 | d(a, mulm(Δ, n, :($a^$(n-1)))) 383 | elseif @capture(ex, a_ / b_) 384 | d(a, push!(w, mulm(Δ, b))) 385 | d(b, push!(w, :(-$(mulm(Δ, a))/$b^2))) 386 | else 387 | error("$ex is not differentiable") 388 | end 389 | end 390 | push!(w, d(x)) 391 | return w 392 | end 393 | 394 | # There are only two distinct algorithms in this handbook, and this is the 395 | # second! It's quite similar to forward mode, with the difference that we 396 | # walk backwards over the list, and each time we see a usage of a variable 397 | # $y_i$ we accumulate a gradient for that variable. 398 | 399 | derive_r(Wengert(:(x / (1 + x^2))), :x) |> Expr 400 | 401 | # For now, the output looks pretty similar to that of forward mode; we'll 402 | # explain why the [distinction makes a difference](./backandforth.ipynb) in future 403 | # notebooks. 
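# As a quick sanity check, we can splice the reverse-mode result into an
# ordinary function and evaluate it numerically (the function name `d3` below is
# just an illustrative choice, not something defined elsewhere in this handbook).
# Taking the earlier example $3x^2 + 2x + 1$, whose derivative we know is
# $6x + 2$, we should get $5$ at $x = 0.5$.

dw = derive_r(Wengert(:(3x^2 + (2x + 1))), :x)
@eval d3(x) = $(Expr(dw))
d3(0.5) # == 5.0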
404 | 
--------------------------------------------------------------------------------
/src/notebooks.jl:
--------------------------------------------------------------------------------
1 | root = joinpath(@__DIR__, "..")
2 | using Pkg; Pkg.activate(root)
3 | 
4 | src = joinpath(root, "src")
5 | out = joinpath(root, "notebooks")
6 | 
7 | using Literate
8 | 
9 | mkpath(out)
10 | 
11 | for f in ["Project.toml", "Manifest.toml"]
12 |   cp(joinpath(root, f), joinpath(out, f), force = true)
13 | end
14 | 
15 | function preprocess(s)
16 |   s = "using Pkg; Pkg.activate(\".\"); Pkg.instantiate()\n#-\n" * s
17 | end
18 | 
19 | for f in ["utils.jl"]
20 |   cp(joinpath(src, f), joinpath(out, f), force = true)
21 | end
22 | 
23 | for f in ["intro.jl", "backandforth.jl", "forward.jl", "tracing.jl", "reverse.jl"]
24 |   Literate.notebook(joinpath(src, f), out,
25 |     preprocess = preprocess,
26 |     credit = false)
27 | end
28 | 
--------------------------------------------------------------------------------
/src/reverse.jl:
--------------------------------------------------------------------------------
1 | include("utils.jl");
2 | 
3 | # Source to Source Reverse Mode
4 | # =============================
5 | #
6 | # [Forward mode](./forward.ipynb) works well because all of the symbolic
7 | # operations happen at Julia's compile time; Julia can then optimise the
8 | # resulting program (say, by applying SIMD instructions) and we get very fast
9 | # derivative code. Although we can differentiate Julia code by [compiling it to
10 | # a Wengert list](./tracing.ipynb), we'd be much better off if we could handle
11 | # Julia code directly; then reverse mode can benefit from these optimisations
12 | # too.
13 | #
14 | # However, Julia code is much more complex than a Wengert list, with constructs
15 | # like control flow, data structures and function calls. To do this we'll have
16 | # to handle each of these things in turn.
17 | #
18 | # The first thing to realise is that Julia code is much closer to a Wengert list
19 | # than it looks. Despite its rich syntax, the compiler works with a Wengert-like
20 | # format. The analyses and optimisations that compilers already carry out also
21 | # benefit from this easy-to-work-with structure.
22 | 
23 | f(x) = x / (1 + x^2)
24 | 
25 | @code_typed f(1.0)
26 | 
27 | # Code with control flow is only a little different. We add `goto` statements
28 | # and a construct called the "phi function"; the result is called [SSA
29 | # form](https://en.wikipedia.org/wiki/Static_single_assignment_form).
30 | 
31 | function pow(x, n)
32 |   r = 1
33 |   while n > 0
34 |     n -= 1
35 |     r *= x
36 |   end
37 |   return r
38 | end
39 | 
40 | pow(2, 3)
41 | 
42 | @code_typed pow(2, 3)
43 | 
44 | # The details of this format are not too important. SSA form is powerful but
45 | # somewhat fiddly to work with in practice, so the aim of this notebook is
46 | # to give a broad intuition for how we handle this.
47 | 
--------------------------------------------------------------------------------
/src/tracing.jl:
--------------------------------------------------------------------------------
1 | # Tracing-based Automatic Differentiation
2 | # =======================================
3 | #
4 | # Machine learning primarily needs [reverse-mode AD](./backandforth.ipynb), and
5 | # tracing / operator overloading approaches are by far the most popular way to
6 | # do it; this is the technique used by ML frameworks from Theano to PyTorch. This
7 | # notebook will cover the techniques used by those frameworks, as well as
8 | # clarifying the distinction between the "static declaration"
9 | # (Theano/TensorFlow) and "eager execution" (Chainer/PyTorch/Flux) approaches to
10 | # AD.
11 | 
12 | include("utils.jl")
13 | 
14 | # Partial Evaluation
15 | # ------------------
16 | #
17 | # Say we have a simple implementation of $x^n$ which we want to differentiate.
18 | 
19 | function pow(x, n)
20 |   r = 1
21 |   for i = 1:n
22 |     r *= x
23 |   end
24 |   return r
25 | end
26 | 
27 | pow(2, 3)
28 | 
29 | # We already know how to [differentiate Wengert lists](./intro.ipynb), but this
30 | # doesn't look much like one of those. In fact, we can't write this program as a
31 | # Wengert list at all, given that it contains control flow; and more generally
32 | # our programs might have things like data structures or function calls that we
33 | # don't know how to differentiate either.
34 | #
35 | # Though it's possible to generalise the Wengert list to handle these things,
36 | # there's actually a simple and surprisingly effective alternative, called
37 | # "partial evaluation". This means running some part of a program without
38 | # running all of it. For example, given an expression like $x + 5 * n$ where we
39 | # know $n = 3$, we can simplify to $x + 15$ even though we don't know what $x$
40 | # is. This is a common trick in compilers, and Julia will often do it for you:
41 | 
42 | f(x, n) = x + 5 * n
43 | g(x) = f(x, 3)
44 | 
45 | code_typed(g, Tuple{Int})[1]
46 | 
47 | # This suggests a solution to our dilemma above. If we know what $n$ is (say,
48 | # $3$), we can write `pow(x, 3)` as $((1*x)*x)*x$, which _is_ a Wengert
49 | # expression that we can differentiate. In effect, this is a kind of compilation
50 | # from a complex language (Julia, Python) to a much simpler one.
51 | 
52 | # Static Declaration
53 | # ------------------
54 | #
55 | # We want to trace all of the basic mathematical operations in the program,
56 | # stripping away everything else. We'll do this using Julia's operator
57 | # overloading; the idea is to create a new type which, rather than actually executing
58 | # operations like $a + b$, records them into a Wengert list.
59 | 
60 | import Base: +, -
61 | 
62 | struct Staged
63 |   w::Wengert
64 |   var
65 | end
66 | 
67 | a::Staged + b::Staged = Staged(a.w, push!(a.w, :($(a.var) + $(b.var))))
68 | 
69 | a::Staged - b::Staged = Staged(a.w, push!(a.w, :($(a.var) - $(b.var))))
70 | 
71 | # Actually, all of our staged definitions follow the same pattern, so we can
72 | # just do them in a loop. We also add an extra method so that we can multiply
73 | # staged values by constants.
74 | 
75 | for f in [:+, :*, :-, :^, :/]
76 |   @eval Base.$f(a::Staged, b::Staged) = Staged(a.w, push!(a.w, Expr(:call, $(Expr(:quote, f)), a.var, b.var)))
77 |   @eval Base.$f(a, b::Staged) = Staged(b.w, push!(b.w, Expr(:call, $(Expr(:quote, f)), a, b.var)))
78 | end
79 | 
80 | # The idea here is to begin by creating a Wengert list (the "graph" in ML
81 | # framework parlance), and create some symbolic variables which do not yet
82 | # have numerical values.
83 | 
84 | w = Wengert()
85 | x = Staged(w, :x)
86 | y = Staged(w, :y)
87 | 
88 | # When we manipulate these variables, we'll get Wengert lists.
89 | 
90 | z = 2x + y
91 | z.w |> Expr
92 | 
93 | # Crucially, this works with our original `pow` function!
94 | 
95 | w = Wengert()
96 | x = Staged(w, :x)
97 | 
98 | y = pow(x, 3)
99 | y.w |> Expr
100 | 
101 | # The rest is almost too easy! We already know how to derive this.
102 | 103 | dy = derive_r(y.w, :x) 104 | Expr(dy) 105 | 106 | # If we dump the derived code into a function, we get code for the derivative 107 | # of $x^3$ at any point (i.e. $3x^2$). 108 | 109 | @eval dcube(x) = $(Expr(dy)) 110 | 111 | dcube(5) 112 | 113 | # Congratulations, you just implemented TensorFlow. 114 | 115 | # Eager Execution 116 | # --------------- 117 | 118 | # This approach has a crucial problem; because it works by stripping out control 119 | # flow and parameters like $n$, it effectively freezes all of these things. We 120 | # can get a specific derivative for $x^3$, $x^4$ and so on, but we can't get the 121 | # general derivative of $x^n$ with a single Wengert list. This puts a severe 122 | # limitation on the kinds of models we can express.$^1$ 123 | # 124 | # The solution? Well, just re-build the Wengert list from scratch every time! 125 | 126 | function D(f, x) 127 | x_ = Staged(w, :x) 128 | dy = derive(f(x_).w, :x) 129 | eval(:(let x = $x; $(Expr(dy)) end)) 130 | end 131 | 132 | D(x -> pow(x, 3), 5) 133 | #- 134 | D(x -> pow(x, 5), 5) 135 | 136 | # This gets us our gradients, but it's not going to be fast – there's a lot of overhead 137 | # to building and evaluating the list/graph every time. There are two things we can 138 | # do to alleviate this: 139 | # 140 | # 1. Interpret, rather compile, the Wengert list. 141 | # 2. Fuse interpretation of the list (the numeric phase) with the building 142 | # and manipulation of the Wengert list (the symbolic phase). 143 | # 144 | # Implementing this looks a lot like the `Staged` object above. The key difference 145 | # is that alongside the Wengert list, we store the numerical values of each variable 146 | # and instruction as we go along. Also, rather than explicitly naming variables 147 | # `x`, `y` etc, we generate names using `gensym()`. 148 | 149 | gensym() 150 | #- 151 | struct Tape 152 | instructions::Wengert 153 | values 154 | end 155 | 156 | Tape() = Tape(Wengert(), Dict()) 157 | 158 | struct Tracked 159 | w::Tape 160 | var 161 | end 162 | 163 | function track(t::Tape, x) 164 | var = gensym() 165 | t.values[var] = x 166 | Tracked(t, var) 167 | end 168 | 169 | Base.getindex(x::Tracked) = x.w.values[x.var] 170 | 171 | for f in [:+, :*, :-, :^, :/] 172 | @eval function Base.$f(a::Tracked, b::Tracked) 173 | var = push!(a.w.instructions, Expr(:call, $(Expr(:quote, f)), a.var, b.var)) 174 | a.w.values[var] = $f(a[], b[]) 175 | Tracked(a.w, var) 176 | end 177 | @eval function Base.$f(a, b::Tracked) 178 | var = push!(b.w.instructions, Expr(:call, $(Expr(:quote, f)), a, b.var)) 179 | b.w.values[var] = $f(a, b[]) 180 | Tracked(b.w, var) 181 | end 182 | @eval function Base.$f(a::Tracked, b) 183 | var = push!(a.w.instructions, Expr(:call, $(Expr(:quote, f)), a.var, b)) 184 | a.w.values[var] = $f(a[], b) 185 | Tracked(a.w, var) 186 | end 187 | end 188 | 189 | # Now, when we call `pow` it looks a lot more like we are dealing with normal 190 | # numeric values; but there is still a Wengert list inside. 191 | 192 | t = Tape() 193 | x = track(t, 5) 194 | 195 | y = pow(x, 3) 196 | y[] 197 | 198 | y.w.instructions |> Expr 199 | 200 | # Finally, we need to alter how we derive this list. The key insight is that 201 | # since we already have values available, we don't need to explicitly build 202 | # and evaluate the derivative code; instead, we can just evaluate each instruction 203 | # numerically as we go along. We more-or-less just need to replace our symbolic 204 | # functions like (`addm`) with the regular ones (`+`). 
205 | #
206 | # This is, of course, not a particularly optimised implementation, and faster
207 | # versions have many more tricks up their sleeves. But this gets at all the key
208 | # ideas.
209 | 
210 | function derive(w::Tape, xs...)
211 |   ds = Dict()
212 |   val(x) = get(w.values, x, x)
213 |   d(x) = get(ds, x, 0)
214 |   d(x, Δ) = ds[x] = d(x) + Δ
215 |   d(lastindex(w.instructions), 1)
216 |   for v in reverse(collect(keys(w.instructions)))
217 |     ex = w.instructions[v]
218 |     Δ = d(v)
219 |     if @capture(ex, a_ + b_)
220 |       d(a, Δ)
221 |       d(b, Δ)
222 |     elseif @capture(ex, a_ * b_)
223 |       d(a, Δ * val(b))
224 |       d(b, Δ * val(a))
225 |     elseif @capture(ex, a_^n_Number)
226 |       d(a, Δ * n * val(a) ^ (n-1))
227 |     elseif @capture(ex, a_ / b_)
228 |       d(a, Δ * val(b))
229 |       d(b, -Δ*val(a)/val(b)^2)
230 |     else
231 |       error("$ex is not differentiable")
232 |     end
233 |   end
234 |   return map(x -> d(x.var), xs)
235 | end
236 | 
237 | derive(y.w, x)
238 | 
239 | # With this we can implement a more general gradient function.
240 | 
241 | function gradient(f, xs...)
242 |   t = Tape()
243 |   xs = map(x -> track(t, x), xs)
244 |   f(xs...)
245 |   derive(t, xs...)
246 | end
247 | 
248 | # Even with the limited set of gradients that we have, we're well on our way to
249 | # differentiating more complex programs, like a custom `sin` function.
250 | 
251 | gradient((a, b) -> a*b, 2, 3)
252 | #-
253 | mysin(x) = sum((-1)^k/factorial(1.0+2k) * x^(1+2k) for k = 0:5)
254 | #-
255 | gradient(mysin, 0.5)
256 | #-
257 | cos(0.5)
258 | 
259 | # We can even take nested derivatives!
260 | 
261 | gradient(x -> gradient(mysin, x)[1], 0.5)
262 | #-
263 | -sin(0.5)
264 | 
265 | # Though the tracing approach has significant limitations, its power is in how
266 | # easy it is to implement: one can build a fairly full-featured implementation,
267 | # in almost any language, in a weekend. Almost all languages have the
268 | # operator-overloading features required, and no matter how complex the host
269 | # language, one ends up with a simple Wengert list.
270 | 
271 | # Note that we have not removed the need to apply our basic symbolic
272 | # differentiation algorithm here. We are still looking up gradient definitions,
273 | # reversing data flow and applying the chain rule – it's just interleaved with
274 | # our numerical operations, and we avoid building the output into an explicit
275 | # Wengert list.
276 | #
277 | # It's somewhat unusual to emphasise the symbolic side of AD, but I think it
278 | # gives us an incisive way to understand the tradeoffs that different systems
279 | # make. For example: TensorFlow-style AD has its numeric phase separate from
280 | # Python's runtime, which makes it awkward to use. Conversely, PyTorch does run
281 | # its numerical phase at runtime, but also its symbolic phase, making it
282 | # impossible to optimise the backwards pass.
283 | #
284 | # We [observed](./forward.ipynb) that OO-based forward mode is particularly
285 | # successful because it carries out its symbolic and numeric operations at
286 | # Julia's compile and run time, respectively. In the [source to source reverse
287 | # mode](./reverse.ipynb) notebook, we'll explore doing this for reverse mode as
288 | # well.
289 | 
290 | # ### Footnotes
291 | 
292 | # $^1$ Systems like TensorFlow can also just provide ways to inject control flow
293 | # into the graph. This brings us closer to a [source-to-source
294 | # approach](./reverse.ipynb) where Python is used to build an expression in
295 | # TensorFlow's internal graph language.
296 | 297 | # Fun fact: PyTorch and Flux's tapes are actually closer to the `Expr` format 298 | # that we originally used, in which "tracked" tensors just have pointers to 299 | # their parents (implicitly forming a graph/Wengert list/expression tree). A 300 | # naive algorithm for backpropagation suffers from exponential runtime for the 301 | # *exact* same reason that naive symbolic diff does; "flattening" this graph 302 | # into a tree causes it to blow up in size. 303 | -------------------------------------------------------------------------------- /src/utils.jl: -------------------------------------------------------------------------------- 1 | using MacroTools, InteractiveUtils, SpecialFunctions 2 | 3 | struct Variable 4 | name::Symbol 5 | number::Int 6 | end 7 | 8 | Symbol(x::Variable) = Symbol(x.name, x.number) 9 | 10 | Base.show(io::IO, x::Variable) = print(io, ":(", x.name, x.number, ")") 11 | 12 | Base.print(io::IO, x::Variable) = Base.show_unquoted(io, x, 0, -1) 13 | Base.show_unquoted(io::IO, x::Variable, ::Int, ::Int) = 14 | print(io, x.name, x.number) 15 | 16 | struct Wengert 17 | variable::Symbol 18 | instructions::Vector{Any} 19 | end 20 | 21 | Wengert(; variable = :y) = Wengert(variable, []) 22 | 23 | Base.keys(w::Wengert) = (Variable(w.variable, i) for i = 1:length(w.instructions)) 24 | Base.lastindex(w::Wengert) = Variable(w.variable, length(w.instructions)) 25 | 26 | Base.getindex(w::Wengert, v::Variable) = w.instructions[v.number] 27 | 28 | function Base.show(io::IO, w::Wengert) 29 | println(io, "Wengert List") 30 | for (i, x) in enumerate(w.instructions) 31 | print(io, Variable(w.variable, i), " = ") 32 | Base.println(io, x) 33 | end 34 | end 35 | 36 | Base.push!(w::Wengert, x) = x 37 | 38 | function Base.push!(w::Wengert, x::Expr) 39 | isexpr(x, :block) && return pushblock!(w, x) 40 | x = Expr(x.head, map(x -> x isa Expr ? push!(w, x) : x, x.args)...) 41 | push!(w.instructions, x) 42 | return lastindex(w) 43 | end 44 | 45 | function pushblock!(w::Wengert, x) 46 | bs = Dict() 47 | rename(ex) = Expr(ex.head, map(x -> get(bs, x, x), ex.args)...) 48 | for arg in MacroTools.striplines(x).args 49 | if @capture(arg, x_ = y_) 50 | bs[x] = push!(w, rename(y)) 51 | else 52 | push!(w, rename(arg)) 53 | end 54 | end 55 | return Variable(w.variable, length(w.instructions)) 56 | end 57 | 58 | function Wengert(ex; variable = :y) 59 | w = Wengert(variable = variable) 60 | push!(w, ex) 61 | return w 62 | end 63 | 64 | function Expr(w::Wengert) 65 | cs = Dict() 66 | for x in w.instructions 67 | x isa Expr || continue 68 | for v in x.args 69 | v isa Variable || continue 70 | cs[v] = get(cs, v, 0) + 1 71 | end 72 | end 73 | bs = Dict() 74 | rename(ex::Expr) = Expr(ex.head, map(x -> get(bs, x, x), ex.args)...) 75 | rename(x) = x 76 | ex = :(;) 77 | for v in keys(w) 78 | if get(cs, v, 0) > 1 79 | push!(ex.args, :($(Symbol(v)) = $(rename(w[v])))) 80 | bs[v] = Symbol(v) 81 | else 82 | bs[v] = rename(w[v]) 83 | end 84 | end 85 | push!(ex.args, rename(bs[lastindex(w)])) 86 | return unblock(ex) 87 | end 88 | 89 | addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b) 90 | mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b) 91 | mulm(a, b, c...) = mulm(mulm(a, b), c...) 92 | 93 | function derive(w::Wengert, x; out = w) 94 | ds = Dict() 95 | ds[x] = 1 96 | d(x) = get(ds, x, 0) 97 | for v in keys(w) 98 | ex = w[v] 99 | Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : 100 | @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : 101 | @capture(ex, a_^n_Number) ? 
mulm(d(a),n,:($a^$(n-1))) :
102 |         @capture(ex, a_ / b_)     ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $b^2) :
103 |         @capture(ex, sin(a_))     ? mulm(:(cos($a)), d(a)) :
104 |         @capture(ex, cos(a_))     ? mulm(:(-sin($a)), d(a)) :
105 |         @capture(ex, exp(a_))     ? mulm(v, d(a)) :
106 |         @capture(ex, log(a_))     ? mulm(:(1/$a), d(a)) :
107 |         error("$ex is not differentiable")
108 |     ds[v] = push!(out, Δ)
109 |   end
110 |   return out
111 | end
112 | 
113 | function derive_r(w::Wengert, x)
114 |   ds = Dict()
115 |   d(x) = get(ds, x, 0)
116 |   d(x, Δ) = ds[x] = haskey(ds, x) ? addm(ds[x],Δ) : Δ
117 |   d(lastindex(w), 1)
118 |   for v in reverse(collect(keys(w)))
119 |     ex = w[v]
120 |     Δ = d(v)
121 |     if @capture(ex, a_ + b_)
122 |       d(a, Δ)
123 |       d(b, Δ)
124 |     elseif @capture(ex, a_ * b_)
125 |       d(a, push!(w, mulm(Δ, b)))
126 |       d(b, push!(w, mulm(Δ, a)))
127 |     elseif @capture(ex, a_^n_Number)
128 |       d(a, mulm(Δ, n, :($a^$(n-1))))
129 |     elseif @capture(ex, a_ / b_)
130 |       d(a, push!(w, mulm(Δ, :(1/$b))))
131 |       d(b, push!(w, :(-$(mulm(Δ, a))/$b^2)))
132 |     else
133 |       error("$ex is not differentiable")
134 |     end
135 |   end
136 |   push!(w, d(x))
137 |   return w
138 | end
139 | 
--------------------------------------------------------------------------------
/test/runtests.jl:
--------------------------------------------------------------------------------
1 | # just make sure notebooks are generated properly
2 | include("../src/notebooks.jl")
3 | 
--------------------------------------------------------------------------------
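A quick usage sketch for the symbolic reverse mode in src/utils.jl (illustrative only; the expression and REPL calls below are assumed, not taken from the repository): `derive_r` appends the adjoint instructions to the same Wengert list and pushes the overall derivative as its final variable, so calling `Expr` on the result yields the gradient expression.

    include("src/utils.jl")
    w = Wengert(:(x^2 + 2x))   # y1 = x^2, y2 = 2x, y3 = y1 + y2
    dw = derive_r(w, :x)
    Expr(dw)                   # roughly :(2 + 2 * x ^ 1), correct though unsimplified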