├── README.org ├── astgrad.nim ├── astgrad.nimble └── changelog.org /README.org: -------------------------------------------------------------------------------- 1 | * astGrad - Symbolic differentiation based on the Nim AST 2 | 3 | This library performs symbolic differentiation based on the nodes of 4 | the Nim AST. This allows for compile time generation of derivatives to 5 | avoid approximations due to numerical methods. 6 | 7 | Note: as this is dealing with symbolic differentiation and the code 8 | isn't extremely smart about doing simplification yet, the resulting 9 | function can become relatively large at higher orders. 10 | For example, the 8th derivative of =tanh(x)= produces about 7000 lines 11 | of code... 12 | 13 | For lower orders it works perfectly fine, though. Since the growth of the 14 | generated code is exponential in the order, even something like the 5th 15 | order is still reasonable in the case of =tanh= (O(70) lines). 16 | 17 | Some simple simplification (combining identical addition / subtraction 18 | terms and multiplication / division terms) should help a lot. 19 | 20 | ** Usage 21 | 22 | Using this library is pretty straightforward. There is essentially 23 | only a single public macro: 24 | #+begin_src nim 25 | macro derivative(arg, wrt: typed): untyped 26 | #+end_src 27 | 28 | This macro takes a Nim expression and a symbol to differentiate by. 29 | 30 | For example: 31 | #+begin_src nim 32 | echo derivative(x * x, x) == 2 * x 33 | echo derivative(x * y, y) == x 34 | echo derivative(exp(x), x) == exp(x) 35 | echo derivative(sin(x) * cos(μ)^2 + exp(-((x - μ)^2) / (2 * σ^2)), μ) 36 | #... 37 | #+end_src 38 | You get the idea. 39 | Of course every symbol used in the expression for which the derivative 40 | is to be computed must exist in the Nim code (otherwise the compiler 41 | reports "undeclared identifier" errors when it type-checks the arguments).
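The rules that `derivative` applies can be illustrated without macros. The following is a minimal runtime sketch, not astGrad's implementation: the `Expr` type and the `lit`, `sym`, `diff`, `evaluate` and `$` procs are hypothetical names chosen for illustration. astGrad applies the same sum, product and quotient rules, the chain rule for function calls, and the same kind of identity simplification (dropping `+ 0.0` and `* 1.0`, collapsing `* 0.0`) to Nim AST nodes at compile time rather than to a runtime tree.

```nim
import std/[math, tables]

# Hypothetical runtime expression tree (NOT part of astGrad).
type
  ExprKind = enum
    ekConst, ekVar, ekAdd, ekMul, ekDiv, ekSin, ekCos, ekExp
  Expr = ref object
    case kind: ExprKind
    of ekConst: val: float
    of ekVar: name: string
    of ekAdd, ekMul, ekDiv:
      lhs, rhs: Expr
    of ekSin, ekCos, ekExp:
      arg: Expr

proc lit(v: float): Expr = Expr(kind: ekConst, val: v)
proc sym(n: string): Expr = Expr(kind: ekVar, name: n)
proc isConst(e: Expr, c: float): bool = e.kind == ekConst and e.val == c

# Constructors with identity simplification, similar in spirit to the
# `+`, `*` and `/` overloads on `SymbolicVariable` in astgrad.nim.
proc `+`(a, b: Expr): Expr =
  if a.isConst(0.0): result = b
  elif b.isConst(0.0): result = a
  else: result = Expr(kind: ekAdd, lhs: a, rhs: b)

proc `*`(a, b: Expr): Expr =
  if a.isConst(0.0) or b.isConst(0.0): result = lit(0.0)
  elif a.isConst(1.0): result = b
  elif b.isConst(1.0): result = a
  else: result = Expr(kind: ekMul, lhs: a, rhs: b)

proc `/`(a, b: Expr): Expr =
  if a.isConst(0.0): result = lit(0.0)
  elif b.isConst(1.0): result = a
  else: result = Expr(kind: ekDiv, lhs: a, rhs: b)

proc diff(e: Expr, wrt: string): Expr =
  ## Symbolic derivative of `e` with respect to the variable named `wrt`.
  case e.kind
  of ekConst: result = lit(0.0)
  of ekVar: result = (if e.name == wrt: lit(1.0) else: lit(0.0))
  of ekAdd: # sum rule
    result = diff(e.lhs, wrt) + diff(e.rhs, wrt)
  of ekMul: # product rule
    result = diff(e.lhs, wrt) * e.rhs + e.lhs * diff(e.rhs, wrt)
  of ekDiv: # quotient rule written as a'/b + (-1 * a * b') / (b * b)
    result = diff(e.lhs, wrt) / e.rhs +
      lit(-1.0) * e.lhs * diff(e.rhs, wrt) / (e.rhs * e.rhs)
  of ekSin: # chain rule: inner' * cos(inner)
    result = diff(e.arg, wrt) * Expr(kind: ekCos, arg: e.arg)
  of ekCos: # chain rule: inner' * (-sin(inner))
    result = diff(e.arg, wrt) * (lit(-1.0) * Expr(kind: ekSin, arg: e.arg))
  of ekExp: # chain rule: inner' * exp(inner)
    result = diff(e.arg, wrt) * e

proc evaluate(e: Expr, env: Table[string, float]): float =
  ## Numerically evaluates `e`, looking variables up in `env`.
  case e.kind
  of ekConst: result = e.val
  of ekVar: result = env[e.name]
  of ekAdd: result = evaluate(e.lhs, env) + evaluate(e.rhs, env)
  of ekMul: result = evaluate(e.lhs, env) * evaluate(e.rhs, env)
  of ekDiv: result = evaluate(e.lhs, env) / evaluate(e.rhs, env)
  of ekSin: result = sin(evaluate(e.arg, env))
  of ekCos: result = cos(evaluate(e.arg, env))
  of ekExp: result = exp(evaluate(e.arg, env))

proc `$`(e: Expr): string =
  case e.kind
  of ekConst: result = $e.val
  of ekVar: result = e.name
  of ekAdd: result = "(" & $e.lhs & " + " & $e.rhs & ")"
  of ekMul: result = "(" & $e.lhs & " * " & $e.rhs & ")"
  of ekDiv: result = "(" & $e.lhs & " / " & $e.rhs & ")"
  of ekSin: result = "sin(" & $e.arg & ")"
  of ekCos: result = "cos(" & $e.arg & ")"
  of ekExp: result = "exp(" & $e.arg & ")"

when isMainModule:
  let x = sym("x")
  echo diff(x * x, "x")  # prints (x + x)
  let g = Expr(kind: ekSin, arg: x) / Expr(kind: ekExp, arg: x)
  let env = {"x": 0.7}.toTable
  # compare against the hand-derived result (cos(x) - sin(x)) * exp(-x)
  let expected = (cos(0.7) - sin(0.7)) * exp(-0.7)
  doAssert abs(evaluate(diff(g, "x"), env) - expected) < 1e-12
```

With these identity rules the derivative of `x * x` comes out as `x + x` instead of `1.0 * x + x * 1.0`. The `*` and `+` overloads on `SymbolicVariable` in =astgrad.nim= follow the same pattern, so `derivative(x * x, x)` should likewise expand to `x + x`.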
42 | 43 | In addition there is a helper template: 44 | #+begin_src nim 45 | template ∂(arg, wrt: untyped): untyped 46 | #+end_src 47 | to make the code a bit prettier. In addition to this template, 48 | higher-order versions are defined using Unicode superscripts, 49 | from =∂²=, =∂³= up to =∂⁹=. 50 | 51 | Feel free to wrap the call in a procedure to generate the full 52 | gradient procedure: 53 | #+begin_src nim 54 | proc gradSin(x: float): float = 55 | result = ∂(sin(x), x) 56 | 57 | doAssert gradSin(-Pi) == cos(-Pi) 58 | doAssert gradSin(0.0) == cos(0.0) 59 | doAssert gradSin(Pi/2.0) == cos(Pi/2.0) 60 | #+end_src 61 | 62 | Feel free to go crazy on your derivatives. 63 | 64 | Note: the library currently has no introspection functionality to 65 | compute derivatives of user-defined functions. For a purely 66 | mathematical procedure, adding this should be rather straightforward. More 67 | complex statements are not really the goal of this library. Its aim is 68 | to provide a convenient way to generate gradients of functions that a) 69 | one is too lazy to write down or b) might already be a bit 70 | annoying to compute by hand. 71 | 72 | ** Extra fun 73 | 74 | Guess what we can do 😎: 75 | 76 | #+begin_src nim 77 | import unchained 78 | import scinim/experimental/sugar # just used for the mathScope macro 79 | 80 | mathScope: 81 | f(t, a) = ∂(1.0/2.0 * a * t^2, t) 82 | echo "Speed after ", 1.s, ": ", f(1.0.s, 9.81.m•s⁻²) 83 | echo "Speed after ", 2.s, ": ", f(2.0.s, 9.81.m•s⁻²) 84 | echo "Speed after ", 3.s, ": ", f(3.0.s, 9.81.m•s⁻²) 85 | # Speed after 1 Second: 9.81 Meter•Second⁻¹ 86 | # Speed after 2 Second: 19.62 Meter•Second⁻¹ 87 | # Speed after 3 Second: 29.43 Meter•Second⁻¹ 88 | #+end_src 89 | 90 | Doing gradients with units? Pretty neat, huh? 91 | 92 | You want more? 93 | What if you have some measurement uncertainties associated with your 94 | time values? And maybe include the variation of =g= around the world?
95 | 96 | #+begin_src nim 97 | import measuremancer 98 | # And guess what if you have some measurement errors on top of your 99 | # measurement? 100 | echo "Speed after ", 1.s, ": ", f(1.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 101 | echo "Speed after ", 2.s, ": ", f(2.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 102 | echo "Speed after ", 3.s, ": ", f(3.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 103 | # Speed after 1 Second: 9.81 ± 0.491 Meter•Second⁻¹ 104 | # Speed after 2 Second: 19.6 ± 0.494 Meter•Second⁻¹ 105 | # Speed after 3 Second: 29.4 ± 0.499 Meter•Second⁻¹ 106 | #+end_src 107 | 108 | Yup. 109 | -------------------------------------------------------------------------------- /astgrad.nim: -------------------------------------------------------------------------------- 1 | import macros, tables, math 2 | 3 | type 4 | SymbolKind = enum 5 | skPlus, skMinus, skMul, skDiv, skPower, skInvalid 6 | SymbolicVariable = object 7 | n: NimNode # the corresponding nim node 8 | id: uint64 # unique identifier (mainly for debugging) 9 | processed: bool # indicates whether derivative has already been computed for this variable 10 | SymbolicParameter = object 11 | n: NimNode 12 | kind: SymbolKind 13 | SymbolicFunction = object 14 | n: NimNode 15 | processed: bool # required?
16 | Number = distinct 17 | 18 | 19 | var FunctionTab {.compileTime.} = initTable[string, NimNode]() 20 | var DerivativeTab {.compileTime.} = initTable[string, SymbolicFunction]() 21 | macro defineSupportedFunctions(body: untyped): untyped = 22 | for fn in body: 23 | doAssert fn.kind == nnkInfix and fn[0].strVal == "->" 24 | let fnName = fn[1].strVal 25 | let fnId = ident(fnName) 26 | FunctionTab[fnName] = fnId 27 | DerivativeTab[fnName] = SymbolicFunction(n: fn[2], processed: true) 28 | 29 | ## NOTE: some of the following functions are not implemented in Nim atm 30 | defineSupportedFunctions: 31 | sqrt -> 1.0 / 2.0 / sqrt(x) 32 | cbrt -> 1.0 / 3.0 / (cbrt(x)^2.0) 33 | abs2 -> 1.0 * 2.0 * x 34 | inv -> -1.0 * abs2(inv(x)) 35 | log -> 1.0 / x 36 | log10 -> 1.0 / x / log(10) 37 | log2 -> 1.0 / x / log(2.0) 38 | log1p -> 1.0 / (x + 1.0) 39 | exp -> exp(x) 40 | exp2 -> log(2.0) * exp2(x) 41 | expm1 -> exp(x) 42 | sin -> cos(x) 43 | cos -> -sin(x) 44 | tan -> (1.0 + (tan(x)^2)) 45 | sec -> sec(x) * tan(x) 46 | csc -> -csc(x) * cot(x) 47 | cot -> -(1.0 + (cot(x)^2)) 48 | sind -> Pi / 180.0 * cosd(x) 49 | cosd -> -Pi / 180.0 * sind(x) 50 | tand -> Pi / 180.0 * (1.0 + (tand(x)^2)) 51 | secd -> Pi / 180.0 * secd(x) * tand(x) 52 | cscd -> -Pi / 180.0 * cscd(x) * cotd(x) 53 | cotd -> -Pi / 180.0 * (1.0 + (cotd(x)^2)) 54 | arcsin -> 1.0 / sqrt(1.0 - (x^2)) 55 | arccos -> -1.0 / sqrt(1.0 - (x^2)) 56 | arctan -> 1.0 / (1.0 + (x^2)) 57 | arcsec -> 1.0 / abs(x) / sqrt(x^2 - 1.0) 58 | arccsc -> -1.0 / abs(x) / sqrt(x^2 - 1.0) 59 | arccot -> -1.0 / (1.0 + (x^2)) 60 | arcsind -> 180.0 / Pi / sqrt(1.0 - (x^2)) 61 | arccosd -> -180.0 / Pi / sqrt(1.0 - (x^2)) 62 | arctand -> 180.0 / Pi / (1.0 + (x^2)) 63 | arcsecd -> 180.0 / Pi / abs(x) / sqrt(x^2 - 1.0) 64 | arccscd -> -180.0 / Pi / abs(x) / sqrt(x^2 - 1.0) 65 | arccotd -> -180.0 / Pi / (1.0 + (x^2)) 66 | sinh -> cosh(x) 67 | cosh -> sinh(x) 68 | tanh -> sech(x)^2 69 | sech -> -tanh(x) * sech(x) 70 | csch -> -coth(x) * csch(x) 71 | 
coth -> -(csch(x)^2) 72 | arcsinh -> 1.0 / sqrt(x^2 + 1.0) 73 | arccosh -> 1.0 / sqrt(x^2 - 1.0) 74 | arctanh -> 1.0 / (1.0 - (x^2)) 75 | arcsech -> -1.0 / x / sqrt(1.0 - (x^2)) 76 | arccsch -> -1.0 / abs(x) / sqrt(1.0 + (x^2)) 77 | arccoth -> 1.0 / (1.0 - (x^2)) 78 | deg2rad -> Pi / 180.0 79 | rad2deg -> 180.0 / Pi 80 | erf -> 2.0 * exp(-x*x) / sqrt(Pi) 81 | erfinv -> 0.5 * sqrt(Pi) * exp(erfinv(x) * erfinv(x)) 82 | erfc -> -2.0 * exp(-x*x) / sqrt(Pi) 83 | erfcinv -> -0.5 * sqrt(Pi) * exp(erfcinv(x) * erfcinv(x)) 84 | erfi -> 2.0 * exp(x*x) / sqrt(Pi) 85 | gamma -> digamma(x) * gamma(x) 86 | lgamma -> digamma(x) 87 | digamma -> trigamma(x) 88 | invdigamma -> inv(trigamma(invdigamma(x))) 89 | trigamma -> polygamma(2.0, x) 90 | airyai -> airyaiprime(x) 91 | airybi -> airybiprime(x) 92 | airyaiprime -> x * airyai(x) 93 | airybiprime -> x * airybi(x) 94 | besselj0 -> -besselj1(x) 95 | besselj1 -> (besselj0(x) - besselj(2.0, x)) / 2.0 96 | bessely0 -> -bessely1(x) 97 | bessely1 -> (bessely0(x) - bessely(2.0, x)) / 2.0 98 | erfcx -> (2.0 * x * erfcx(x) - 2.0 / sqrt(Pi)) 99 | dawson -> (1.0 - 2.0 * x * dawson(x)) 100 | 101 | when false: 102 | import hashes 103 | proc hash(x: SymbolicVariable): Hash = 104 | result = result !& hash(x.n.repr) 105 | result = result !& hash(x.id) 106 | result = result !& hash(x.processed) 107 | result = !$ result 108 | 109 | import sets 110 | var NodeSet {.compileTime.} = initHashSet[SymbolicVariable]() 111 | 112 | var IDCounter {.compileTime.} = 0'u64 113 | template getID(): untyped = 114 | inc IDCounter 115 | IDCounter 116 | 117 | proc evaluateFunction(fn: SymbolicFunction, arg: SymbolicVariable): SymbolicVariable = 118 | ## inserts the symbolic variable into the `x` fields and returns a new variable 119 | ## with the evaluated tree as the node 120 | var tree = fn.n 121 | proc insert(n, arg: NimNode): NimNode = 122 | case n.kind 123 | of nnkIdent, nnkSym: 124 | if n.strVal == "x": # this node needs to be replaced 125 | result = arg
else: 127 | result = n 128 | else: 129 | if n.len == 0: result = n 130 | else: 131 | result = newTree(n.kind) 132 | for ch in n: 133 | result.add insert(ch, arg) 134 | let repl = tree.insert(arg.n) 135 | result = SymbolicVariable(n: repl, processed: true, id: getID()) 136 | 137 | proc isNumber(n: NimNode): bool = 138 | # maybe this: ? 139 | (n.kind != nnkSym and n.typeKind in {ntyInt .. ntyUInt64}) or 140 | n.kind in {nnkIntLit .. nnkFloat128Lit} 141 | 142 | proc isNumberLit(n: NimNode): bool = 143 | # maybe this: ? 144 | n.kind in {nnkIntLit .. nnkFloat128Lit} 145 | 146 | proc isNumber(x: SymbolicVariable): bool = x.n.isNumber 147 | proc kind(x: SymbolicVariable): NimNodeKind = x.n.kind 148 | proc `[]`(x: SymbolicVariable, idx: int): SymbolicVariable = 149 | result = SymbolicVariable(n: x.n[idx], processed: x.processed, id: x.id) 150 | 151 | iterator items(x: SymbolicVariable): SymbolicVariable = 152 | for i in 0 ..< x.n.len: 153 | yield x[i] 154 | 155 | proc add(x: var SymbolicVariable, y: SymbolicVariable) = 156 | var n = x.n 157 | n.add y.n 158 | x = SymbolicVariable(n: n, processed: x.processed, id: x.id) 159 | 160 | proc isZero(x: SymbolicVariable): bool = x.n.kind in {nnkFloatLit, nnkFloat64Lit} and x.n.floatVal == 0.0 161 | proc isOne(x: SymbolicVariable): bool = x.n.kind in {nnkFloatLit, nnkFloat64Lit} and x.n.floatVal == 1.0 162 | 163 | proc name(fn: SymbolicFunction): string = result = fn.n.strVal 164 | 165 | proc toSymbolicVariable(n: NimNode, processed = false): SymbolicVariable = 166 | #doAssert n.kind in {nnkIdent, nnkSym, nnkIntLit .. 
nnkFloat128Lit} 167 | result = SymbolicVariable(n: n, processed: processed, id: getID()) 168 | 169 | proc symbolicOne(): SymbolicVariable = 170 | SymbolicVariable(n: newLit(1.0), processed: true, id: getID()) 171 | 172 | proc symbolicZero(): SymbolicVariable = 173 | SymbolicVariable(n: newLit(0.0), processed: true, id: getID()) 174 | 175 | proc symbolicPower(): SymbolicParameter = 176 | SymbolicParameter(n: ident"^", kind: skPower) 177 | 178 | proc `==`(a, b: SymbolicVariable): bool = 179 | result = a.n == b.n and a.id == b.id 180 | 181 | proc isIndep(a, indep: SymbolicVariable): bool = 182 | ## checks whether `a` is the independent variable. 183 | result = a.n == indep.n 184 | 185 | # not required anymore, we untype the tree 186 | proc san(n: NimNode): NimNode {.inline.} = n 187 | # case n.kind 188 | # of nnkStmtListExpr: result = n[1].san 189 | # of nnkHiddenStdConv, nnkConv: result = n[1].san 190 | # else: result = n 191 | 192 | ## TODO: simplify these such that if the second arg is identity element, not included 193 | proc `-`(n: SymbolicVariable): SymbolicVariable = 194 | result = SymbolicVariable(n: nnkPrefix.newTree(ident"-", n.n.san), processed: true, id: getID()) 195 | 196 | proc setProcessed(x: SymbolicVariable): SymbolicVariable = 197 | result = x 198 | result.n = result.n.san # make sure to sanitize as well 199 | result.processed = true # most likely already true 200 | 201 | proc `+`(x, y: SymbolicVariable): SymbolicVariable = 202 | if x.isZero: result = y.setProcessed 203 | elif y.isZero: result = x.setProcessed 204 | else: result = SymbolicVariable(n: nnkInfix.newTree(ident"+", x.n.san, y.n.san), processed: true, id: getID()) 205 | 206 | proc litDiff(x, y: NimNode): NimNode = 207 | if x.kind == y.kind: 208 | if x.kind == nnkIntLit: 209 | result = newLit(x.intVal - y.intVal) 210 | else: 211 | result = newLit(x.floatVal - y.floatVal) 212 | else: 213 | # use float 214 | template getVal(a: untyped): untyped = 215 | if a.kind == nnkIntLit: a.intVal.float 
216 | else: a.floatVal 217 | result = newLit(x.getVal - y.getVal) 218 | 219 | proc `-`(x, y: SymbolicVariable): SymbolicVariable = 220 | if x.isZero: result = -y.setProcessed 221 | elif y.isZero: result = x.setProcessed 222 | elif x == y: result = symbolicZero() 223 | elif x.n.isNumberLit and y.n.isNumberLit: # compute result in place 224 | result = SymbolicVariable(n: litDiff(x.n, y.n), processed: true, id: getID()) 225 | else: result = SymbolicVariable(n: nnkInfix.newTree(ident"-", x.n.san, y.n.san), processed: true, id: getID()) 226 | 227 | proc `-`(x: SymbolicVariable, y: SomeNumber): SymbolicVariable = 228 | result = x - toSymbolicVariable(newLit(y), true) 229 | 230 | proc `*`(x, y: SymbolicVariable): SymbolicVariable = 231 | if x.isOne: result = y.setProcessed 232 | elif y.isOne: result = x.setProcessed 233 | elif x.isZero: result = symbolicZero() 234 | elif y.isZero: result = symbolicZero() 235 | else: 236 | result = SymbolicVariable(n: nnkInfix.newTree(ident"*", x.n.san, y.n.san), processed: true, id: getID()) 237 | 238 | proc `/`(x, y: SymbolicVariable): SymbolicVariable = 239 | # if x is one, default is shortest already 240 | if y.isZero: error("Computation contains division by 0!") 241 | elif x.isZero: result = symbolicZero() 242 | elif y.isOne: result = x.setProcessed 243 | elif x == y: result = symbolicOne() 244 | else: result = SymbolicVariable(n: nnkInfix.newTree(ident"/", x.n.san, y.n.san), processed: true, id: getID()) 245 | 246 | proc `^`(x, y: SymbolicVariable): SymbolicVariable = 247 | # if x is one, default is shortest already 248 | ## XXX: add int literals for powers so that we don't have to force `pow` here! 
249 | if y.isOne: result = x.setProcessed 250 | elif y.isZero: result = symbolicOne() 251 | elif x.isZero: result = symbolicZero() 252 | else: result = SymbolicVariable(n: nnkCall.newTree(ident"pow", x.n.san, y.n.san), processed: true, id: getID()) 253 | 254 | proc log(x: SymbolicVariable): SymbolicVariable = 255 | if x.isZero: error("Computation yields log(0) and thus -Inf!") 256 | else: result = SymbolicVariable(n: nnkCall.newTree(ident"log", x.n.san), processed: true, id: getID()) 257 | 258 | proc processExpr(arg, wrt: SymbolicVariable): SymbolicVariable 259 | 260 | proc differentiate(x, wrt: SymbolicVariable): SymbolicVariable = 261 | if x.processed: 262 | result = x 263 | else: 264 | result = processExpr(x, wrt) 265 | doAssert result.processed 266 | result = result.setProcessed 267 | 268 | proc diffPlus(x, y, wrt: SymbolicVariable): SymbolicVariable = 269 | # compute gradient of `x + y` w.r.t. `wrt` 270 | result = differentiate(x, wrt) + differentiate(y, wrt) 271 | 272 | proc diffMinus(x, y, wrt: SymbolicVariable): SymbolicVariable = 273 | # compute gradient of `x - y` w.r.t. `wrt` 274 | result = differentiate(x, wrt) - differentiate(y, wrt) 275 | 276 | proc diffMul(x, y, wrt: SymbolicVariable): SymbolicVariable = 277 | # compute gradient of `x * y` w.r.t. `wrt` 278 | result = differentiate(x, wrt) * y + x * differentiate(y, wrt) 279 | 280 | proc diffDiv(x, y, wrt: SymbolicVariable): SymbolicVariable = 281 | # compute gradient of `x / y` w.r.t. `wrt` 282 | result = differentiate(x, wrt) / y + (-x * differentiate(y, wrt) / (y * y)) 283 | 284 | proc diffPower(x, y, wrt: SymbolicVariable): SymbolicVariable = 285 | # compute gradient of `x ^ y` w.r.t. 
`wrt` 286 | let xp = differentiate(x, wrt) 287 | let yp = differentiate(y, wrt) 288 | if xp.isZero and yp.isZero: 289 | result = symbolicZero() 290 | elif yp.isZero: 291 | result = y * xp * (x ^ (y - 1.0)) 292 | else: 293 | result = x ^ y * (xp * y / x + yp * log(x)) 294 | 295 | proc differentiate(op: SymbolicParameter, 296 | x, y: SymbolicVariable, 297 | wrt: SymbolicVariable): SymbolicVariable = 298 | case op.kind 299 | of skPlus: result = diffPlus(x, y, wrt) 300 | of skMinus: result = diffMinus(x, y, wrt) 301 | of skMul: result = diffMul(x, y, wrt) 302 | of skDiv: result = diffDiv(x, y, wrt) 303 | of skPower: result = diffPower(x, y, wrt) 304 | of skInvalid: error("Differentiation of `skInvalid` not possible. This is a bug.") 305 | 306 | proc differentiate(fn: SymbolicFunction, arg: SymbolicVariable): SymbolicVariable = 307 | result = evaluateFunction(DerivativeTab[fn.name], arg) 308 | 309 | proc parseSymbolicParameter(x: SymbolicVariable): SymbolicParameter = 310 | doAssert x.kind in {nnkIdent, nnkSym} 311 | case x.n.strVal 312 | of "+": result = SymbolicParameter(n: x.n, kind: skPlus) 313 | of "-": result = SymbolicParameter(n: x.n, kind: skMinus) 314 | of "*": result = SymbolicParameter(n: x.n, kind: skMul) 315 | of "/": result = SymbolicParameter(n: x.n, kind: skDiv) 316 | of "^", "**": result = SymbolicParameter(n: x.n, kind: skPower) 317 | else: result = SymbolicParameter(n: newEmptyNode(), kind: skInvalid) 318 | 319 | proc parseSymbolicFunction(x: SymbolicVariable): SymbolicFunction = 320 | doAssert x.kind in {nnkIdent, nnkSym} 321 | result = SymbolicFunction(n: FunctionTab[x.n.strVal]) 322 | 323 | proc toNimCode(x: SymbolicVariable): NimNode = 324 | ## Converts the symbolic variable back into Nim code. This just means we return the 325 | ## NimNode it contains. However, in the future we will add some simple 326 | ## simplification to act against code explosion.
x.n 328 | 329 | proc handleInfix(arg, wrt: SymbolicVariable): SymbolicVariable = 330 | ## handle infix nodes by calling the correct differentiation function 331 | doAssert arg.kind == nnkInfix 332 | let symbol = parseSymbolicParameter(arg[0]) 333 | result = differentiate(symbol, arg[1], arg[2], 334 | wrt) 335 | 336 | proc handleCall(arg, wrt: SymbolicVariable): SymbolicVariable = 337 | ## Essentially handle the chain rule of function calls (and `pow` calls) 338 | doAssert arg.kind == nnkCall, " is : " & $arg.n.treerepr 339 | # check if call might be an `infix` symbol. If so, patch up and call infix instead 340 | if arg[0].parseSymbolicParameter().kind != skInvalid: 341 | ## XXX: this can go I think. It was due to a bug 342 | error("invalid") 343 | doAssert not arg.processed 344 | var inf = SymbolicVariable(n: nnkInfix.newTree(), processed: arg.processed, id: getID()) 345 | for ch in arg: 346 | inf.add ch 347 | result = handleInfix(inf, wrt) 348 | else: 349 | # regular function call 350 | # for now assume single argument functions, i.e. we can evaluate the argument 351 | # as an expression and there is only one argument 352 | if arg[0].n.strVal == "pow": 353 | # power is a special case, as it's the only 2-argument function we support so far 354 | result = differentiate(symbolicPower(), arg[1], arg[2], wrt) 355 | else: 356 | let fn = parseSymbolicFunction(arg[0]) 357 | result = differentiate(arg[1], wrt) * differentiate(fn, arg[1]) # chain rule: inner' * outer'(inner) 358 | 359 | proc handlePrefix(arg, wrt: SymbolicVariable): SymbolicVariable = 360 | ## handle prefix, usually `-` or `+` 361 | expectKind(arg.n, nnkPrefix) 362 | # parse the prefix symbol 363 | let fn = parseSymbolicParameter(arg[0]) 364 | case fn.kind 365 | of skPlus, skMinus: 366 | # prefix is nothing to be handled via differentiation.
Merge it into the element that follows 367 | result = differentiate(fn, symbolicZero(), # just add / subtract from a zero 368 | arg[1], 369 | wrt) 370 | else: 371 | error("Invalid prefix: " & $fn.n.repr & " from argument: " & $arg.repr) 372 | 373 | proc processExpr(arg, wrt: SymbolicVariable): SymbolicVariable = 374 | ## The heart of the logic. Handles the different Nim nodes and performs 375 | ## the actual differentiation if we are looking at a `nnkSym` or literal 376 | case arg.kind 377 | of nnkSym, nnkIdent, nnkIntLit .. nnkFloat128Lit: 378 | if arg.isIndep(wrt): 379 | result = symbolicOne() 380 | else: 381 | result = symbolicZero() 382 | of nnkInfix: 383 | result = handleInfix(arg, wrt) 384 | of nnkCall: 385 | result = handleCall(arg, wrt) 386 | of nnkHiddenStdConv: 387 | # assume contains literals? 388 | if arg.isNumber or arg.n.typeKind == ntyRange: 389 | result = processExpr(arg[1], wrt) 390 | else: 391 | error("unsupported: " & $arg.kind & " and value " & $arg.n.treerepr) 392 | of nnkPrefix: 393 | result = handlePrefix(arg, wrt) 394 | of nnkStmtListExpr: 395 | doAssert false, "Not required anymore, we untype the tree" 396 | doAssert arg[0].kind == nnkEmpty 397 | result = processExpr(arg[1], wrt) 398 | of nnkConv: 399 | doAssert false, "Not required anymore, we untype the tree" 400 | result = processExpr(arg[1], wrt) 401 | else: error("unsupported: " & $arg.kind & " and value " & $arg.n.treerepr) 402 | 403 | proc sanitizeInput(n: NimNode): NimNode = 404 | # remove all `nnkConv, nnkHiddenStdConv and nnkStmtListExpr` 405 | let tree = n 406 | proc sanitize(n: NimNode): NimNode = 407 | if n.len == 0: 408 | case n.kind 409 | of nnkSym: result = ident(n.strVal) 410 | else: result = n 411 | else: 412 | case n.kind 413 | of nnkConv, nnkHiddenStdConv: result = n[1].sanitize 414 | of nnkStmtListExpr: result = n[1].sanitize 415 | else: 416 | result = newTree(n.kind) 417 | for ch in n: 418 | result.add sanitize(ch) 419 | result = tree.sanitize() 420 | 421 | macro
derivative*(arg, wrt: typed): untyped = 422 | ## computes the forward derivative of `arg` (a Nim expression) 423 | ## with respect to `wrt` using symbolic differentiation on the 424 | ## Nim AST 425 | let input = arg.sanitizeInput 426 | result = toNimCode processExpr(toSymbolicVariable(input), toSymbolicVariable(wrt.sanitizeInput)) 427 | 428 | template ∂*(arg, wrt: untyped): untyped = 429 | derivative(arg, wrt) 430 | 431 | macro genHelpers(): untyped = 432 | ## Generate higher order derivative helpers. 433 | ## 434 | ## NOTE: 435 | ## It is really unwise to use the higher orders on functions that 436 | ## get larger after each derivative... :) 437 | let idx = ["²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"] 438 | result = newStmtList() 439 | let arg = ident"arg" 440 | let wrt = ident"wrt" 441 | for i, el in idx: 442 | let name = ident("∂" & $el) 443 | var body = newStmtList() 444 | for j in 0 ..< i + 2: 445 | if j == 0: 446 | body = quote do: 447 | ∂(`arg`, `wrt`) 448 | else: 449 | body = quote do: 450 | ∂(`body`, `wrt`) 451 | result.add quote do: 452 | template `name`*(`arg`, `wrt`: untyped): untyped = 453 | `body` 454 | genHelpers() 455 | 456 | when false:# isMainModule: 457 | 458 | let x = 1.0 459 | echo ∂(exp(-3.0 * x) * x * x * x * sin(x), x) 460 | 461 | let x1 = 1.0 462 | echo ∂(1/x^2, x) 463 | 464 | let x = 2.5 465 | echo ∂(x, x) 466 | template printAndCheck(arg, eq: untyped): untyped = 467 | echo "is ", derivative(arg, x), " should be ", eq 468 | echo derivative(arg, x), " is ", abs(derivative(arg, x) - eq) < 1e-4 469 | 470 | printAndCheck(exp(x), exp(x)) 471 | printAndCheck(sin(x), cos(x)) 472 | printAndCheck(cos(x), -sin(x)) 473 | printAndCheck(tanh(x), sech(x)*sech(x)) 474 | 475 | 476 | import ggplotnim, sequtils 477 | # 478 | #proc grad(x, y: float): float = 479 | # #result = derivative(x*y + y*y*y, y) 480 | # result = ∂(-2 * (sech(x) ^ 2) * (sech(x) ^ 2) + -2 * tanh(x) * (2 * (-tanh(x) * sech(x)) * pow(sech(x), 2 - 1.0'f64)), x) 481 | 482 | let xs = 
linspace(-5.0,5.0,1000) 483 | 484 | #echo ∂(∂(tanh(x), x), x) 485 | #let ys = xs.mapIt(grad(it, it)) 486 | #ggplot(seqsToDf(xs, ys), aes("xs", "ys")) + 487 | # geom_line() + ggsave("/tmp/deriv.pdf") 488 | 489 | 490 | #echo ∂(tanh(x), x) 491 | #echo ∂(sech(x)*sech(x), x) 492 | #echo ∂(-2 * sech(x) ^ 2 * sech(x) ^ 2 - 2 * tanh(x) * (2 * (-tanh(x) * sech(x)) * pow(sech(x), 2 - 1.0'f64)), x) 493 | 494 | 495 | #echo ∂(-2*tanh(x) * sech(x)^2, x) 496 | #echo ∂(sin(x) * cos(x) + pow(tanh(x), 2.0 - 1.0'f64), x) 497 | 498 | 499 | #echo ∂(4*tanh(x)^2 * sech(x)^2 - 2*sech(x)^4, x) 500 | #echo ∂(tanh(x), x) 501 | #echo ∂(∂(tanh(x), x), x) 502 | #echo ∂(∂(∂(tanh(x), x), x), x) 503 | #echo ∂(∂(∂(∂(tanh(x), x), x), x), x) 504 | var df = newDataFrame() 505 | block MultiGrad: 506 | 507 | block NoGrad: 508 | let ys = xs.mapIt(tanh(it)) 509 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 0}) 510 | echo dfLoc 511 | df.add dfLoc 512 | block Grad1: 513 | let ys = xs.mapIt(∂(tanh(it), it)) 514 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 1}) 515 | df.add dfLoc 516 | block Grad2: 517 | let ys = xs.mapIt(∂(∂(tanh(it), it), it)) 518 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 2}) 519 | df.add dfLoc 520 | block Grad3: 521 | let ys = xs.mapIt(∂(∂(∂(tanh(it), it), it), it)) 522 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 3}) 523 | df.add dfLoc 524 | block Grad4: 525 | let ys = xs.mapIt(∂(∂(∂(∂(tanh(it), it), it), it), it)) 526 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 4}) 527 | df.add dfLoc 528 | block Grad5: 529 | let ys = xs.mapIt(∂(∂(∂(∂(∂(tanh(it), it), it), it), it), it)) 530 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 5}) 531 | df.add dfLoc 532 | block Grad6: 533 | let ys = xs.mapIt(∂(∂(∂(∂(∂(∂(tanh(it), it), it), it), it), it), it)) 534 | let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 6}) 535 | df.add dfLoc 536 | #block Grad7: 537 | # let ys = xs.mapIt(∂⁷(tanh(it), it)) 538 | # let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 7}) 
539 | # df.add dfLoc 540 | #block Grad8: 541 | # let ys = xs.mapIt(∂⁸(tanh(it), it)) 542 | # let dfLoc = seqsToDf({"x" : xs, "y" : ys, "grad" : 8}) 543 | # df.add dfLoc 544 | 545 | ggplot(df, aes("x", "y", color = "grad")) + 546 | geom_line() + 547 | ggsave("/tmp/tanh_derivs.pdf") 548 | 549 | import unchained 550 | import scinim/experimental/sugar 551 | 552 | # guess what we can do 😎 553 | mathScope: 554 | f(t, a) = ∂(1.0/2.0 * a * t^2, t) 555 | echo "Speed after ", 1.s, ": ", f(1.0.s, 9.81.m•s⁻²) 556 | echo "Speed after ", 2.s, ": ", f(2.0.s, 9.81.m•s⁻²) 557 | echo "Speed after ", 3.s, ": ", f(3.0.s, 9.81.m•s⁻²) 558 | 559 | import measuremancer 560 | # And guess what if you have some measurement errors on top of your 561 | # measurement? 562 | echo "Speed after ", 1.s, ": ", f(1.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 563 | echo "Speed after ", 2.s, ": ", f(2.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 564 | echo "Speed after ", 3.s, ": ", f(3.0.s ± 0.05.s, 9.81.m•s⁻² ± 0.03.m•s⁻²) 565 | -------------------------------------------------------------------------------- /astgrad.nimble: -------------------------------------------------------------------------------- 1 | # Package 2 | 3 | version = "0.1.0" 4 | author = "Vindaar" 5 | description = "Symbolic differentiation at compile time based on the Nim AST" 6 | license = "MIT" 7 | 8 | 9 | # Dependencies 10 | 11 | requires "nim >= 1.6.0" 12 | -------------------------------------------------------------------------------- /changelog.org: -------------------------------------------------------------------------------- 1 | * v0.1.0 2 | - initial version supporting symbolic differentiation at compile time using 3 | the Nim AST. Basic simplifications are supported. 4 | --------------------------------------------------------------------------------