├── .gitignore ├── project ├── build.properties └── build │ ├── ad.scala │ └── target │ └── scala_2.7.7 │ ├── analysis │ ├── dependencies │ ├── external │ ├── generated_files │ ├── hashes │ └── projects │ └── classes │ └── ad.class ├── sbt ├── sbt-launch-0.7.4.jar ├── sbt.boot.properties ├── sbt.cmd └── src └── main └── scala ├── ad.scala └── ad ├── AD.scala ├── Floating.scala ├── Forward.scala ├── Jacobian.scala ├── Mode.scala ├── ModeCompanion.scala └── Reverse.scala /.gitignore: -------------------------------------------------------------------------------- 1 | lib_managed 2 | project/boot 3 | target 4 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | #Project properties 2 | #Fri Apr 29 16:16:06 EDT 2011 3 | project.organization=com.comonad 4 | project.name=ad 5 | sbt.version=0.7.4 6 | project.version=0.1 7 | build.scala.versions=2.8.1 8 | project.initialize=false 9 | -------------------------------------------------------------------------------- /project/build/ad.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | // sbt 0.7.4 project definition (sbt.version in project/build.properties); the library itself is built for Scala 2.8.1 3 | class ad(info: ProjectInfo) extends DefaultProject(info) { 4 | val scalaToolsSnapshots = "Scala Tools Snapshots" at "http://scala-tools.org/repo-snapshots/" // presumably needed to resolve the 6.0-SNAPSHOT dependency below 5 | val scalazCore = "org.scalaz" %% "scalaz-core" % "6.0-SNAPSHOT" // scalaz-core, cross-versioned for the project's Scala version via %% 6 | } 7 | 8 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/analysis/dependencies: -------------------------------------------------------------------------------- 1 | #Source Dependencies 2 | #Fri Apr 29 22:27:34 EDT 2011 3 | ad.scala= 4 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/analysis/external: -------------------------------------------------------------------------------- 1 |
#External Dependencies 2 | #Fri Apr 29 22:27:34 EDT 2011 3 | /Users/ekmett/scala-ad/project/boot/scala-2.7.7/org.scala-tools.sbt/sbt/0.7.4/sbt_2.7.7-0.7.4.jar=ad.scala 4 | /Users/ekmett/scala-ad/project/boot/scala-2.7.7/lib/scala-library.jar=ad.scala 5 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/analysis/generated_files: -------------------------------------------------------------------------------- 1 | #Generated Classes 2 | #Fri Apr 29 22:27:34 EDT 2011 3 | ad.scala=target/scala_2.7.7/classes/ad.class 4 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/analysis/hashes: -------------------------------------------------------------------------------- 1 | #Source Hashes 2 | #Fri Apr 29 22:27:34 EDT 2011 3 | ad.scala=d9d72491c11d5e343ecbd1a30aaf4c2c3a47465e 4 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/analysis/projects: -------------------------------------------------------------------------------- 1 | #Project Definitions 2 | #Fri Apr 29 22:27:34 EDT 2011 3 | ad.scala=ad 4 | -------------------------------------------------------------------------------- /project/build/target/scala_2.7.7/classes/ad.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekmett/scala-ad/c91142ed9b42902f8660111c6719701415d03cd8/project/build/target/scala_2.7.7/classes/ad.class -------------------------------------------------------------------------------- /sbt: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Runs the sbt launcher jar stored next to this script. The jar path is quoted so a checkout directory containing spaces still works; $SBT_OPTS stays unquoted so it can carry several JVM flags. 3 | java $SBT_OPTS -Dfile.encoding=UTF-8 -Xss4M -Xmx1024M -XX:MaxPermSize=256M -XX:NewSize=128M -XX:NewRatio=3 -jar "$(dirname "$0")/sbt-launch-0.7.4.jar" @sbt.boot.properties "$@" 4 |
-------------------------------------------------------------------------------- /sbt-launch-0.7.4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekmett/scala-ad/c91142ed9b42902f8660111c6719701415d03cd8/sbt-launch-0.7.4.jar -------------------------------------------------------------------------------- /sbt.boot.properties: -------------------------------------------------------------------------------- 1 | [scala] 2 | version: 2.7.7 3 | classifiers: sources 4 | 5 | [app] 6 | org: org.scala-tools.sbt 7 | name: sbt 8 | version: read(sbt.version) 9 | class: sbt.xMain 10 | components: xsbti 11 | cross-versioned: true 12 | 13 | [repositories] 14 | local 15 | maven-local 16 | sbt-db: http://databinder.net/repo/, [organization]/[module]/[revision]/[type]s/[artifact](-[classifier]).[ext] 17 | maven-central 18 | scala-tools-releases 19 | scala-tools-snapshots 20 | 21 | [boot] 22 | directory: project/boot 23 | properties: project/build.properties 24 | prompt-create: Project does not exist, create new project? 
25 | prompt-fill: true 26 | quick-option: true 27 | 28 | [log] 29 | level: info 30 | 31 | [app-properties] 32 | project.name: quick=set(test), new=prompt(Name), fill=prompt(Name) 33 | project.organization: new=prompt(Organization) 34 | project.version: quick=set(1.0), new=prompt(Version)[1.0], fill=prompt(Version)[1.0] 35 | build.scala.versions: quick=set(2.7.7), new=prompt(Scala version)[2.7.7], fill=prompt(Scala version)[2.7.7] 36 | sbt.version: quick=set(0.7.3), new=prompt(sbt version)[0.7.3], fill=prompt(sbt version)[0.7.3] 37 | project.scratch: quick=set(true) 38 | project.initialize: quick=set(true), new=set(true) 39 | -------------------------------------------------------------------------------- /sbt.cmd: -------------------------------------------------------------------------------- 1 | set SCRIPT_DIR=%~dp0 2 | java %SBT_OPTS% -Dfile.encoding=UTF-8 -Xss4M -Xmx1024M -XX:MaxPermSize=256M -XX:NewSize=128M -XX:NewRatio=3 -jar "%SCRIPT_DIR%sbt-launch-0.7.4.jar" @sbt.boot.properties %* 3 | -------------------------------------------------------------------------------- /src/main/scala/ad.scala: -------------------------------------------------------------------------------- 1 | import scala.collection.mutable.Buffer 2 | 3 | import scalaz._ 4 | import scalaz.Scalaz._ 5 | /** Package-wide API: mode-polymorphic function types, differentiation entry points and a generalized scala.math facade. */ 6 | package object ad { 7 | trait FF[F[_],G[_],A] { // a function usable in every AD mode S, mapping F-shaped inputs to G-shaped outputs 8 | def apply[S[_]](f : F[AD[S,A]])(implicit mode: Mode[S,A]): G[AD[S,A]] 9 | } 10 | 11 | type UU[A] = FF[Id, Id, A] // scalar to scalar 12 | type FU[F[_],A] = FF[F, Id, A] // F-shaped input to scalar 13 | type UF[F[_],A] = FF[Id, F, A] // scalar to F-shaped output 14 | 15 | /** Returns x => (f(x), f'(x)), computed in forward mode. 16 | * Delegates to Forward.diffa so the package-level and Forward entry points share one implementation. */ 17 | def diffa[A](f: UU[A])(implicit A: Numeric[A]): A => (A, A) = 18 | Forward.diffa(f) 19 | 20 | def grada[F[_]:Traverse, A:Numeric](f: FU[F, A]): F[A] => (A, F[A]) = error("TODO") // grada, grad, grads and jacobians are placeholders that throw 21 | def grad[F[_]:Traverse, A:Numeric](f: FU[F, A]): F[A] => F[A] = error("TODO") 22 | def grads[F[_]:Traverse, A:Numeric](f: FU[F, A]): F[A] => Cofree_[F,A] = error("TODO") 23 | def jacobians[F[_]:Traverse, 
G[_]:Functor, A:Numeric](f: FF[F, G, A]): F[A] => G[Cofree_[F,A]] = error("TODO") 24 | 25 | implicit def lift[S[_], A](a: A)(implicit mode: Mode[S, A], A: Numeric[A]): AD[S, A] = AD[S,A](mode.lift(a)) // lets plain A constants be mixed into AD expressions 26 | 27 | def foo[S[_]](x: AD[S,Double])(implicit mode: Mode[S, Double]): AD[S, Double] = x * x + fromInt[Double](1) // sample function x*x + 1 used by the tests below 28 | 29 | // generalized scala.math: each function dispatches to the implicit Numeric or Floating instance 30 | def Pi[A](implicit A: Floating[A]) = A.pi 31 | def fromInt[A](i: Int)(implicit A: Numeric[A]) = A.fromInt(i) 32 | def exp[A](a: A)(implicit A: Floating[A]): A = A.exp(a) 33 | def log[A](a: A)(implicit A: Floating[A]): A = A.log(a) 34 | def sqrt[A](a: A)(implicit A: Floating[A]): A = A.sqrt(a) 35 | def logBase[A](a: A, b: A)(implicit A: Floating[A]): A = A.logBase(a, b) 36 | def expBase[A](a: A, b: A)(implicit A: Floating[A]): A = A.expBase(a, b) 37 | def sin[A](a: A)(implicit A: Floating[A]): A = A.sin(a) 38 | def cos[A](a: A)(implicit A: Floating[A]): A = A.cos(a) 39 | def tan[A](a: A)(implicit A: Floating[A]): A = A.tan(a) 40 | def asin[A](a: A)(implicit A: Floating[A]): A = A.asin(a) 41 | def acos[A](a: A)(implicit A: Floating[A]): A = A.acos(a) 42 | def atan[A](a: A)(implicit A: Floating[A]): A = A.atan(a) 43 | def sinh[A](a: A)(implicit A: Floating[A]): A = A.sinh(a) 44 | def cosh[A](a: A)(implicit A: Floating[A]): A = A.cosh(a) 45 | def tanh[A](a: A)(implicit A: Floating[A]): A = A.tanh(a) 46 | def signum[A](a: A)(implicit A: Numeric[A]) = A.signum(a) 47 | def abs[A](a: A)(implicit A: Numeric[A]) = A.abs(a) 48 | 49 | // tests 50 | def test = diffa(new FF[Id,Id,Double] { def apply[S[_]](x: AD[S, Double])(implicit mode: Mode[S, Double]): AD[S, Double] = foo(x) }) 51 | def test2 = diffa(new FF[Id,Id,Double] { def apply[S[_]](x: AD[S, Double])(implicit mode: Mode[S, Double]): AD[S, Double] = cos(foo(x)) }) 52 | 53 | } 54 | 55 | -------------------------------------------------------------------------------- /src/main/scala/ad/AD.scala:
-------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scalaz._ 4 | import scalaz.Scalaz._ 5 | /** A differentiable value: `guts` is the mode-specific representation (e.g. Forward or Reverse) and `mode` supplies its arithmetic. */ 6 | class AD[S[_], A](val guts: S[A])(implicit val mode: Mode[S, A]) { 7 | def apply(f: S[A] => S[A]): AD[S, A] = AD(f(guts)) // apply a raw S[A] step and re-wrap the result 8 | def +(that: AD[S, A])(implicit A: Numeric[A]) = AD(mode.plus(this.guts,that.guts)) 9 | def -(that: AD[S, A])(implicit A: Numeric[A]) = AD(mode.minus(this.guts,that.guts)) 10 | def *(that: AD[S, A])(implicit A: Numeric[A]) = AD(mode.times(this.guts,that.guts)) 11 | def /(that: AD[S, A])(implicit A: Fractional[A]) = AD(mode.div(this.guts,that.guts)) 12 | } 13 | /** Smart constructor plus the Numeric, Fractional and Floating instances for AD values. */ 14 | object AD { 15 | def apply[S[_], A](value: S[A])(implicit mode: Mode[S, A]) = new AD[S,A](value) 16 | /** Numeric instance for AD values: each operation unwraps `guts`, delegates to `mode` and re-wraps. */ 17 | class ADNumeric[S[_], A](implicit mode: Mode[S, A], A: Numeric[A]) extends Numeric[AD[S,A]] { 18 | def compare(a: AD[S, A], b: AD[S, A]) = mode.compare(a.guts, b.guts) 19 | def plus(a: AD[S, A], b: AD[S, A]): AD[S, A] = AD(mode.plus(a.guts, b.guts)) 20 | def minus(a: AD[S, A], b: AD[S, A]): AD[S, A] = AD(mode.minus(a.guts, b.guts)) 21 | def times(a: AD[S, A], b: AD[S, A]): AD[S, A] = AD(mode.times(a.guts, b.guts)) 22 | def negate(a: AD[S, A]): AD[S, A] = a (mode negate _) 23 | def fromInt(a: Int): AD[S, A] = AD(mode.lift(A.fromInt(a))) 24 | def toInt(a: AD[S, A]): Int = mode.toInt(a.guts) // derivative is 0 wherever defined, so this is grudgingly ok 25 | def toLong(a: AD[S, A]): Long = mode.toLong(a.guts) // derivative is 0 wherever defined, so this is grudgingly ok 26 | def toFloat(a: AD[S, A]): Float = mode.toFloat(a.guts) 27 | def toDouble(a: AD[S, A]): Double = mode.toDouble(a.guts) 28 | override def abs(a: AD[S, A]): AD[S, A] = a (mode abs _) 29 | override def signum(a: AD[S, A]): Int = mode.signum(a.guts) 30 | } 31 | 32 | implicit def ADIsNumeric[S[_],A](implicit mode: Mode[S, A], A: Numeric[A]) : Numeric[AD[S,A]] = new ADNumeric[S,A]() 33 | /** Adds division on top of ADNumeric when A is Fractional. */ 34 | class ADFractional[S[_], A](implicit mode: Mode[S, A], A: Fractional[A]) 
extends ADNumeric[S,A] with Fractional[AD[S,A]] { 35 | def div(a: AD[S, A], b: AD[S, A]): AD[S, A] = AD(mode.div(a.guts,b.guts)) 36 | } 37 | 38 | implicit def ADIsFractional[S[_],A](implicit mode: Mode[S, A], A: Fractional[A]) : Fractional[AD[S,A]] = new ADFractional[S,A]() 39 | /** Adds the Floating functions; sqrt, logBase and expBase use the defaults inherited from Floating. */ 40 | class ADFloating[S[_], A](implicit mode: Mode[S, A], A: Floating[A]) extends ADFractional[S,A] with Floating[AD[S,A]] { 41 | lazy val pi: AD[S,A] = AD(mode.lift(A.pi)) 42 | def log(a: AD[S,A]): AD[S,A] = AD(mode.log(a.guts)) 43 | def exp(a: AD[S,A]): AD[S,A] = AD(mode.exp(a.guts)) 44 | def sin(a: AD[S,A]): AD[S,A] = AD(mode.sin(a.guts)) 45 | def cos(a: AD[S,A]): AD[S,A] = AD(mode.cos(a.guts)) 46 | override def tan(a: AD[S,A]): AD[S,A] = AD(mode.tan(a.guts)) 47 | def asin(a: AD[S,A]): AD[S,A] = AD(mode.asin(a.guts)) 48 | def acos(a: AD[S,A]): AD[S,A] = AD(mode.acos(a.guts)) 49 | def atan(a: AD[S,A]): AD[S,A] = AD(mode.atan(a.guts)) 50 | def sinh(a: AD[S,A]): AD[S,A] = AD(mode.sinh(a.guts)) 51 | def cosh(a: AD[S,A]): AD[S,A] = AD(mode.cosh(a.guts)) 52 | override def tanh(a: AD[S,A]): AD[S,A] = AD(mode.tanh(a.guts)) 53 | } 54 | 55 | implicit def ADIsFloating[S[_], A](implicit mode: Mode[S, A], A: Floating[A]): Floating[AD[S, A]] = new ADFloating[S, A]() 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/ad/Floating.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scala._ 4 | /** Fractional extended with elementary functions; sqrt, logBase, expBase, tan and tanh have defaults built from the primitives. */ 5 | trait Floating[A] extends Fractional[A] { 6 | def one_half: A = div(one, fromInt(2)) 7 | def pi: A 8 | def exp(a: A): A 9 | def log(a: A): A 10 | def sqrt(a: A): A = expBase(a, one_half) 11 | def logBase(a: A, b: A): A = div(log(b), log(a)) // logarithm of b in base a 12 | def expBase(a: A, b: A): A = exp(times(log(a), b)) // a raised to the power b 13 | def sin(a: A): A 14 | def cos(a: A): A 15 | def tan(a: A): A = sin(a) / cos(a) 16 | def asin(a: A): A 17 | def acos(a: A): A 18 | def atan(a: A): A 19 | def sinh(a: A): A 20 | def cosh(a: 
A): A 21 | def tanh(a: A): A = sinh(a) / cosh(a) 22 | } 23 | /** Floating instances; currently only Double, which delegates to scala.math. */ 24 | object Floating { 25 | implicit object DoubleFloating extends Floating[Double] { 26 | def div(a: Double, b: Double): Double = a / b 27 | def toDouble(a: Double): Double = a 28 | def toFloat(a: Double): Float = a.toFloat 29 | def toLong(a: Double): Long = a.toLong 30 | def toInt(a: Double): Int = a.toInt 31 | def fromInt(a: Int): Double = a 32 | def negate(a: Double): Double = -a 33 | def times(a: Double, b: Double) = a * b 34 | def minus(a: Double, b: Double) = a - b 35 | def plus(a: Double, b: Double) = a + b 36 | def compare(a: Double, b: Double) = // NOTE(review): NaN compares equal to every value here 37 | if (a < b) -1 38 | else if (a > b) 1 39 | else 0 40 | val pi = math.Pi 41 | def exp(a: Double) = math.exp(a) 42 | def log(a: Double) = math.log(a) 43 | override def sqrt(a: Double) = math.sqrt(a) 44 | def sin(a: Double) = math.sin(a) 45 | def cos(a: Double) = math.cos(a) 46 | override def tan(a: Double) = math.tan(a) 47 | def asin(a: Double) = math.asin(a) 48 | def acos(a: Double) = math.acos(a) 49 | def atan(a: Double) = math.atan(a) 50 | def sinh(a: Double) = math.sinh(a) 51 | def cosh(a: Double) = math.cosh(a) 52 | override def tanh(a: Double) = math.tanh(a) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/scala/ad/Forward.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scalaz._ 4 | import scalaz.Scalaz._ 5 | /** Forward-mode value: `primal` is the value and `tangent` its derivative along the seeded input direction. */ 6 | case class Forward[+A](primal: A, tangent: A) 7 | /** Forward-mode AD: tangents are propagated alongside primals in a single pass. */ 8 | object Forward extends ModeCompanion { 9 | implicit def ForwardMode[A](implicit A0: Numeric[A]) : Mode[Forward, A] = new Jacobian[Forward, Id, A] { 10 | val A = A0 11 | 12 | def lift(a: A) = Forward[A](a, A.zero) 13 | 14 | /* 15 | def times( 16 | a: Forward[A], 17 | b: Forward[A] 18 | ) : Forward[A] = new Forward[A]( 19 | A.times(a.primal, b.primal), 20 | A.plus(A.times(a.primal, b.tangent), A.times(a.tangent, b.primal)) 21 | ) 22 | */ 23 | 24 | val D : 
Mode[Id, A] = Mode.IdMode[A](A) // built once per mode instance: Jacobian's plus/times/minus consult D on every operation 25 | 26 | def primal(a: Forward[A]): A = a.primal 27 | 28 | def unary(f: A => A, dadb : => A, b: Forward[A]) = Forward[A](f(b.primal), A.times(dadb, b.tangent)) 29 | 30 | def lift1(f : A => A, df: A => A, b: Forward[A]) = { 31 | val Forward(pb, db) = b 32 | val dadb = df(pb) 33 | Forward[A](f(pb), A.times(dadb, db)) 34 | } 35 | 36 | def lift1_(f: A => A, df: (A,A) => A, b: Forward[A]): Forward[A] = { 37 | val Forward(pb, db) = b 38 | val a = f(pb) 39 | Forward[A](a, A.times(df(a, pb), db)) 40 | } 41 | 42 | def binary(f: (A,A) => A, dadb: => A, dadc: => A, b: Forward[A], c: Forward[A]): Forward[A] = 43 | Forward[A](f(b.primal,c.primal), A.plus(A.times(dadb,b.tangent), A.times(dadc,c.tangent))) 44 | 45 | def lift2 (f: (A,A) => A, df: (A,A) => (A,A), b: Forward[A], c: Forward[A]): Forward[A] = { 46 | val Forward(pb, db) = b 47 | val Forward(pc, dc) = c 48 | val a = f(pb,pc) 49 | val (dadb, dadc) = df(pb, pc) 50 | Forward[A](a, A.plus(A.times(dadb,db),A.times(dadc,dc))) 51 | } 52 | 53 | def lift2_(f: (A,A) => A, df: (A,A,A) => (A,A), b: Forward[A], c: Forward[A]): Forward[A] = { 54 | val Forward(pb, db) = b 55 | val Forward(pc, dc) = c 56 | val a = f(pb,pc) 57 | val (dadb, dadc) = df(a, pb, pc) 58 | Forward[A](a, A.plus(A.times(dadb, db),A.times(dadc, dc))) 59 | } 60 | 61 | def vdiv(a: Forward[A], b: A)(implicit A: Fractional[A]): Forward[A] = 62 | Forward[A](A.div(a.primal, b), A.div(a.tangent, b)) 63 | def vtimes(a: Forward[A], b: A): Forward[A] = 64 | Forward[A](A.times(a.primal, b), A.times(a.tangent, b)) 65 | } 66 | /** Evaluates f at x with the tangent seeded to one, returning (f(x), f'(x)). */ 67 | def diffa[A](f: UU[A])(implicit A: Numeric[A]) = (x: A) => { 68 | val Forward(y, dy) = f(AD(Forward[A](x, A.one))).guts 69 | (y, dy) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/scala/ad/Jacobian.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scala.collection.mutable.Buffer 4 |
import scalaz._ 5 | import scalaz.Scalaz._ 6 | /** Derives every Mode operation from a few primitives (primal, unary, binary and the lift1/lift2 variants); D is the mode in which partial derivatives are computed. */ 7 | trait Jacobian[S[_], D[_], A] extends Mode[S, A] { 8 | // interface 9 | def D: Mode[D, A] 10 | def primal(f: S[A]): A 11 | def unary(f: A => A, dadb : => D[A], b: S[A]): S[A] 12 | def lift1(f : A => A, df: D[A] => D[A], b: S[A]): S[A] 13 | def lift1_(f: A => A, df: (D[A],D[A]) => D[A], b: S[A]): S[A] 14 | def binary(f: (A,A) => A, dadb: => D[A], dadc: => D[A], b: S[A], c: S[A]): S[A] 15 | def lift2 (f: (A,A) => A, df: (D[A],D[A]) => (D[A],D[A]), b: S[A], c: S[A]): S[A] 16 | def lift2_(f: (A,A) => A, df: (D[A],D[A],D[A]) => (D[A],D[A]), b: S[A], c: S[A]): S[A] 17 | 18 | // the automatic derivation of automatic differentiation 19 | def compare(a: S[A], b: S[A]): Int = A.compare(primal(a), primal(b)) 20 | override def abs(a: S[A]): S[A] = if (A.lt(primal(a), A.zero)) negate(a) else a // Ordering only promises a negative compare result, not exactly -1 21 | override def signum(a: S[A]): Int = A.signum(primal(a)) 22 | def toLong(a: S[A]): Long = A.toLong(primal(a)) 23 | def toInt(a: S[A]): Int = A.toInt(primal(a)) 24 | def toFloat(a: S[A]): Nothing = error("Jacobian.toFloat disallowed") 25 | def toDouble(a: S[A]): Nothing = error("Jacobian.toDouble disallowed") 26 | def plus(a: S[A], b: S[A]): S[A] = binary(A.plus(_,_),D.one,D.one,a,b) 27 | def times(a: S[A], b: S[A]): S[A] = binary(A.times(_,_),D.lift(primal(b)),D.lift(primal(a)),a,b) 28 | def minus(a: S[A], b: S[A]): S[A] = binary(A.minus(_,_),D.one,D.negate(D.one),a,b) 29 | def negate(a: S[A]) = unary(A negate _, D.negate(D.one), a) 30 | def recip(a: S[A])(implicit A: Fractional[A]) = lift1_(x => A.div(A.one, x), (y, x) => D.negate(D.times(y, y)), a) // d(1/x) = -(1/x)^2, written in terms of the result y 31 | def pi(implicit A: Floating[A]): S[A] = lift(A.pi) 32 | def div(a: S[A], b: S[A])(implicit A: Fractional[A]) = times(a, recip(b)) 33 | def exp(a: S[A])(implicit A: Floating[A]): S[A] = lift1_(A exp _, (y: D[A], x : D[A]) => y, a) // d(exp x) = exp x, i.e. the result y 34 | def log(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A log _, x => D.div(D.one, x), a) 35 | def sin(a: S[A])(implicit A: Floating[A]): 
S[A] = lift1(A sin _, D cos _, a) 36 | def cos(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A cos _, x => D negate (D sin x), a) 37 | def sinh(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A sinh _, D cosh _, a) 38 | def cosh(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A cosh _, D sinh _, a) 39 | def asin(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A asin _, x => D.div(D.one, D.sqrt(D.minus(D.one, D.times(x,x)))), a) 40 | def acos(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A acos _, x => D.negate(D.div(D.one, D.sqrt(D.minus(D.one, D.times(x,x))))), a) 41 | def atan(a: S[A])(implicit A: Floating[A]): S[A] = lift1(A atan _, x => D.div(D.one, D.plus(D.one, D.times(x, x))), a) 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/ad/Mode.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scala.math._ 4 | import scalaz._ 5 | import scalaz.Scalaz._ 6 | /** Arithmetic on S[A] for one AD mode; Fractional and Floating operations take that evidence as an implicit parameter. */ 7 | trait Mode[S[_], A] extends Numeric[S[A]] { m => 8 | def A: Numeric[A] 9 | def lift(a: A): S[A] 10 | 11 | def fromInt(i: Int): S[A] = lift(A fromInt i) 12 | 13 | def vtimes(a: S[A], b: A): S[A] 14 | // def timesv(a: A, b: S[A]): S[A] 15 | 16 | def div(a: S[A], b: S[A])(implicit A: Fractional[A]): S[A] 17 | def vdiv(a: S[A], b: A)(implicit A: Fractional[A]): S[A] 18 | 19 | def pi(implicit A: Floating[A]): S[A] 20 | def exp(a: S[A])(implicit A: Floating[A]): S[A] 21 | def log(a: S[A])(implicit A: Floating[A]): S[A] 22 | def sin(a: S[A])(implicit A: Floating[A]): S[A] 23 | def cos(a: S[A])(implicit A: Floating[A]): S[A] 24 | def tan(a: S[A])(implicit A: Floating[A]): S[A] = div(sin(a), cos(a)) 25 | def sinh(a: S[A])(implicit A: Floating[A]): S[A] 26 | def cosh(a: S[A])(implicit A: Floating[A]): S[A] 27 | def tanh(a: S[A])(implicit A: Floating[A]): S[A] = div(sinh(a), cosh(a)) 28 | def asin(a: S[A])(implicit A: Floating[A]): S[A] 29 | def acos(a: S[A])(implicit A: Floating[A]): S[A] 30 
| def atan(a: S[A])(implicit A: Floating[A]): S[A] 31 | 32 | def toFractional(implicit A: Fractional[A]): Fractional[S[A]] = new Mode.FractionalModeProxy[S,A](this,A) 33 | def toFloating(implicit A: Floating[A]): Floating[S[A]] = new Mode.FloatingModeProxy[S,A](this,A) 34 | } 35 | /** Adapters exposing a Mode as a Fractional or Floating instance, plus the identity mode on plain values. */ 36 | object Mode { 37 | class FractionalModeProxy[S[_],A](S: Mode[S,A], A: Fractional[A]) extends Fractional[S[A]] { 38 | override def abs(a: S[A]): S[A] = S.abs(a) 39 | override def signum(a: S[A]): Int = S.signum(a) 40 | def compare(a: S[A], b: S[A]): Int = S.compare(a, b) 41 | def plus(a: S[A], b: S[A]): S[A] = S.plus(a, b) 42 | def times(a: S[A], b: S[A]): S[A] = S.times(a, b) 43 | def minus(a: S[A], b: S[A]): S[A] = S.minus(a, b) 44 | def negate(a: S[A]): S[A] = S.negate(a) 45 | def fromInt(i: Int): S[A] = S.fromInt(i) 46 | def toInt(a: S[A]): Int = S.toInt(a) 47 | def toLong(a: S[A]): Long = S.toLong(a) 48 | def toFloat(a: S[A]): Float = S.toFloat(a) 49 | def toDouble(a: S[A]): Double = S.toDouble(a) 50 | def div(a: S[A], b: S[A]): S[A] = S.div(a, b)(A) 51 | } 52 | implicit def FractionalMode[S[_],A](S: Mode[S,A])(implicit A: Fractional[A]): Fractional[S[A]] = S.toFractional 53 | 54 | class FloatingModeProxy[S[_],A](S: Mode[S,A], A: Floating[A]) extends FractionalModeProxy[S,A](S,A) with Floating[S[A]] { 55 | def pi: S[A] = S.pi(A) 56 | def exp(a: S[A]): S[A] = S.exp(a)(A) 57 | def log(a: S[A]): S[A] = S.log(a)(A) 58 | def sin(a: S[A]): S[A] = S.sin(a)(A) 59 | def cos(a: S[A]): S[A] = S.cos(a)(A) 60 | override def tan(a: S[A]): S[A] = S.tan(a)(A) 61 | def sinh(a: S[A]): S[A] = S.sinh(a)(A) 62 | def cosh(a: S[A]): S[A] = S.cosh(a)(A) 63 | override def tanh(a: S[A]): S[A] = S.tanh(a)(A) 64 | def asin(a: S[A]): S[A] = S.asin(a)(A) 65 | def acos(a: S[A]): S[A] = S.acos(a)(A) 66 | def atan(a: S[A]): S[A] = S.atan(a)(A) 67 | } 68 | implicit def FloatingMode[S[_],A](S: Mode[S,A])(implicit A: Floating[A]): Floating[S[A]] = S.toFloating 69 | /** The identity mode on plain values: every operation delegates to num or to the supplied Fractional/Floating instance. */ 70 | def IdMode[A](implicit num: Numeric[A]) : 
Mode[Id, A] = new Mode[Id, A] { 71 | val A = num 72 | def lift(a: A): A = a 73 | def compare(a: A, b: A): Int = A.compare(a, b) 74 | def plus(a: A, b: A): A = A.plus(a, b) 75 | def times(a: A, b: A): A = A.times(a, b) 76 | def minus(a: A, b: A): A = A.minus(a, b) 77 | override def abs(a: A): A = A abs a 78 | override def signum(a: A): Int = A signum a 79 | def toLong(a: A): Long = A toLong a 80 | def toInt(a: A): Int = A toInt a 81 | def toFloat(a: A): Float = A toFloat a 82 | def toDouble(a: A): Double = A toDouble a 83 | def negate(a: A) = A negate a 84 | def pi(implicit A: Floating[A]): A = A.pi 85 | def div(a: A, b: A)(implicit A: Fractional[A]): A = A.div(a,b) 86 | def exp(a: A)(implicit A: Floating[A]): A = A exp a 87 | def log(a: A)(implicit A: Floating[A]): A = A log a 88 | def sin(a: A)(implicit A: Floating[A]): A = A sin a 89 | def cos(a: A)(implicit A: Floating[A]): A = A cos a 90 | override def tan(a: A)(implicit A: Floating[A]): A = A tan a 91 | def sinh(a: A)(implicit A: Floating[A]): A = A sinh a 92 | def cosh(a: A)(implicit A: Floating[A]): A = A cosh a 93 | override def tanh(a: A)(implicit A: Floating[A]): A = A tanh a 94 | def asin(a: A)(implicit A: Floating[A]): A = A asin a 95 | def acos(a: A)(implicit A: Floating[A]): A = A acos a 96 | def atan(a: A)(implicit A: Floating[A]): A = A atan a 97 | def vtimes(a: A, b: A): A = A.times(a,b) 98 | def vdiv(a: A, b: A)(implicit A: Fractional[A]): A = A.div(a,b) 99 | override def toFractional(implicit A: Fractional[A]): Fractional[A] = A 100 | override def toFloating(implicit A: Floating[A]): Floating[A] = A 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/scala/ad/ModeCompanion.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scala.collection.mutable.Buffer 4 | import scalaz._ 5 | import scalaz.Scalaz._ 6 | /** Shared entry points of the mode companions: each mode supplies diffa; diff and the smoke tests are derived from it. */ 7 | trait ModeCompanion { 8 | def diffa[A:Numeric](f: UU[A]): A => 
(A, A) 9 | def diff[A:Numeric](f: UU[A]): A => A = a => diffa(f).apply(a)._2 10 | 11 | // tests 12 | def test = diffa(new FF[Id,Id,Double] { def apply[S[_]](x: AD[S, Double])(implicit mode: Mode[S, Double]): AD[S, Double] = foo(x) }) 13 | def test2 = diffa(new FF[Id,Id,Double] { def apply[S[_]](x: AD[S, Double])(implicit mode: Mode[S, Double]): AD[S, Double] = cos(foo(x)) }) 14 | def test3 = diffa(new FF[Id,Id,Double] { def apply[S[_]](x: AD[S, Double])(implicit mode: Mode[S, Double]): AD[S, Double] = sin(x) }) 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/ad/Reverse.scala: -------------------------------------------------------------------------------- 1 | package ad 2 | 3 | import scala.collection.mutable.Buffer 4 | import scalaz._ 5 | import scalaz.Scalaz._ 6 | /** Reverse-mode value: `primal` is the value and `slot` its index on the tape; slot 0 marks a constant. */ 7 | case class Reverse[+A](primal: A, slot: Int) 8 | 9 | object Reverse extends ModeCompanion { 10 | private [ad] sealed trait Entry[+A] 11 | private [ad] case object Zero extends Entry[Nothing] 12 | private [ad] case object Var extends Entry[Nothing] 13 | private [ad] case class Unary[+A](di: A, i: Int) extends Entry[A] 14 | private [ad] case class Binary[+A](di: A, i: Int, dj: A, j: Int) extends Entry[A] 15 | 16 | // TODO: the problem here is the A argument to Tape, we need to coerce all of other arguments to make this happy 17 | private[ad] class Tape[A](implicit val A: Numeric[A]) extends Jacobian[Reverse, Id, A] { 18 | val D : Mode[Id, A] = Mode.IdMode[A] 19 | 20 | def lift(a: A) = Reverse[A](a, 0) 21 | def primal(a: Reverse[A]): A = a.primal 22 | val buffer = Buffer[Entry[A]](Zero) // slot 0 holds the Zero entry reserved for constants 23 | def pushSlot(e: Entry[A]) : Int = synchronized { 24 | val len = buffer.length 25 | buffer += e 26 | len 27 | } 28 | def push(a: A, e: Entry[A]): Reverse[A] = Reverse[A](a, pushSlot(e)) 29 | def fresh(a: A): Reverse[A] = push(a, Var) 30 | 31 | def vtimes(a: Reverse[A], b: A): Reverse[A] = unary(A.times(_, b), D.lift(b), a) 32 | def vdiv(a: Reverse[A], b: 
A)(implicit A: Fractional[A]): Reverse[A] = unary(A.div(_, b), D.lift(A.div(A.one, b)), a) 33 | 34 | def sensitivities(top : Int): Reverse[A] => A = { // reverse sweep: result(n) accumulates d(output)/d(slot n), seeded with one at top 35 | val result = Buffer.tabulate[A](top + 1)(n => if (n == top) A.one else A.zero) 36 | top to 1 by -1 foreach { // empty when the output is a constant (top == 0) 37 | n => buffer(n) match { 38 | case Zero => () 39 | case Var => () 40 | case Unary(dadb, bix) => result.update(bix, A.plus(result(bix), A.times(dadb,result(n)))) 41 | case Binary(dadb, bix, dadc, cix) => { 42 | result.update(bix, A.plus(result(bix), A.times(dadb, result(n)))) 43 | result.update(cix, A.plus(result(cix), A.times(dadc, result(n)))) 44 | } 45 | } 46 | } 47 | (x : Reverse[A]) => if (x.slot <= top) result(x.slot) else A.zero // slots recorded after top cannot influence the output 48 | } 49 | 50 | def unary(f: A => A, dadb : => A, b: Reverse[A]) = 51 | Reverse[A]( f(b.primal), 52 | if (b.slot == 0) 0 53 | else pushSlot(Unary[A](dadb, b.slot)) 54 | ) 55 | 56 | def lift1(f : A => A, df: A => A, b: Reverse[A]): Reverse[A] = unary(f, df(b.primal), b) 57 | def lift1_(f: A => A, df: (A,A) => A, b: Reverse[A]): Reverse[A] = { 58 | val pb = b.primal 59 | val a = f(pb) 60 | unary(_ => a, df(a,pb), b) 61 | } 62 | 63 | def binary(f: (A, A) => A, dadb: => A, dadc: => A, b: Reverse[A], c: Reverse[A]) = 64 | Reverse[A]( f(b.primal,c.primal), 65 | if (b.slot == 0) { 66 | if (c.slot == 0) 0 67 | else pushSlot(Unary[A](dadc, c.slot)) 68 | } else { 69 | if (c.slot == 0) pushSlot(Unary[A](dadb, b.slot)) 70 | else pushSlot(Binary[A](dadb, b.slot, dadc, c.slot)) 71 | } 72 | ) 73 | def lift2(f: (A,A) => A, df: (A,A) => (A,A), b: Reverse[A], c: Reverse[A]): Reverse[A] = { 74 | val (dadb, dadc) = df(b.primal, c.primal) 75 | binary(f, dadb, dadc, b, c) 76 | } 77 | 78 | def lift2_(f: (A,A) => A, df: (A,A,A) => (A,A), b: Reverse[A], c: Reverse[A]): Reverse[A] = { 79 | val pb = b.primal 80 | val pc = c.primal 81 | val a = f(pb, pc) 82 | val (dadb, dadc) = df(a, pb, pc) 83 | binary((_,_) => a, dadb, dadc, b, c) 84 | } 85 | } 86 | /** Records f on a fresh tape, then sweeps it backwards to return (f(a), f'(a)). */ 87 | def diffa[A:Numeric](f: UU[A]) = (a:A) => { 88 | val tape = new 
Tape[A]() 89 | implicit val mode : Mode[Reverse, A] = tape 90 | val x = tape.fresh(a) 91 | val y = f(AD(x)).guts 92 | val ybar = tape.sensitivities(y.slot) 93 | (y.primal, ybar(x)) 94 | } 95 | } 96 | 97 | --------------------------------------------------------------------------------