├── .gitignore ├── README.md ├── build.sbt └── src ├── main └── scala │ └── io │ └── estatico │ └── macros │ └── ADT.scala └── test └── scala └── io └── estatico └── macros └── ADTSpec.scala /.gitignore: -------------------------------------------------------------------------------- 1 | project/ 2 | target/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scala-adt 2 | 3 | Haskell-style Algebraic Data Types for Scala 4 | 5 | ## Usage 6 | 7 | Scala supports ADTs by using sealed traits and case classes and objects. For instance, 8 | if we wanted to implement a `Maybe` type, we could do something like the following - 9 | 10 | ```scala 11 | sealed trait Maybe[+A] 12 | final case class Just[+A](a: A) extends Maybe[A] 13 | case object Nix extends Maybe[Nothing] 14 | ``` 15 | 16 | This can get pretty verbose, especially when there are more constructors. 17 | 18 | In Haskell, we can define an ADT using the following syntax - 19 | 20 | ```haskell 21 | data Maybe a = Just a | Nix 22 | ``` 23 | 24 | Thanks to the `@ADT` macro, we can get pretty close to the Haskell syntax - 25 | 26 | ```scala 27 | @ADT trait Maybe[+A] { Just(a: A); Nix } 28 | ``` 29 | 30 | This will generate code semantically equivalent to the following - 31 | 32 | ```scala 33 | sealed trait Maybe[+A] 34 | object Maybe { 35 | final case class Just[+A](a: A) extends Maybe[A] 36 | case object Nix extends Maybe[Nothing] 37 | 38 | object ctors { 39 | val Just = Maybe.Just 40 | val Nix = Maybe.Nix 41 | } 42 | } 43 | ``` 44 | 45 | You can then use your new ADT just as you'd expect. You can add methods to the `trait` 46 | or companion `object` as needed. Constructors are inferred as the first statements defined 47 | in the `trait`.
48 | 49 | ```scala 50 | @ADT trait Maybe[+A] { 51 | Just(a: A) 52 | Nix 53 | 54 | def map[B](f: A => B): Maybe[B] = this match { 55 | case Just(x) => Just(f(x)) 56 | case Nix => Nix 57 | } 58 | } 59 | 60 | object Maybe { 61 | def apply[A](a: A): Maybe[A] = { 62 | if (a == null) Nix else Just(a) 63 | } 64 | } 65 | 66 | // Optional, used to import constructors to avoid always qualifying them. 67 | import Maybe.ctors._ 68 | 69 | object Example { 70 | def run() = { 71 | val m1 = Maybe(1) 72 | println(m1.map(_ + 2)) 73 | // prints "Just(3)" 74 | 75 | val m2: Maybe[Int] = Nix 76 | m2 match { 77 | case Just(x) => println(s"Found $x") 78 | case Nix => println("Not found") 79 | } 80 | // prints "Not found" 81 | } 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "scala-adt" 2 | 3 | version := "1.0" 4 | 5 | scalaVersion := "2.11.8" 6 | 7 | scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-Xfatal-warnings") 8 | scalacOptions in Test ++= Seq("-Yrangepos") 9 | 10 | crossScalaVersions := Seq("2.10.2", "2.10.3", "2.10.4", "2.10.5", "2.10.6", "2.11.0", "2.11.1", "2.11.2", "2.11.3", "2.11.4", "2.11.5", "2.11.6", "2.11.7", "2.11.8") 11 | 12 | resolvers += Resolver.sonatypeRepo("snapshots") 13 | resolvers += Resolver.sonatypeRepo("releases") 14 | 15 | addCompilerPlugin("org.scalamacros" % "paradise" % "2.1.0" cross CrossVersion.full) 16 | 17 | libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _) 18 | 19 | libraryDependencies += "org.scalacheck" %% "scalacheck" % "1.13.2" % "test" 20 | //libraryDependencies += "org.specs2" %% "specs2-core" % "3.6.6" % "test" 21 | -------------------------------------------------------------------------------- /src/main/scala/io/estatico/macros/ADT.scala: -------------------------------------------------------------------------------- 1 | package
io.estatico.macros

import scala.annotation.StaticAnnotation
import scala.collection.mutable
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

/**
 * Macro annotation that rewrites a trait into a Haskell-style algebraic data type.
 *
 * The leading statements of the annotated trait's body are read as constructors:
 * `Name(arg: Type, ...)` becomes a `final case class` and a bare `Name` becomes a
 * `case object`. The trait itself is re-emitted as a `sealed trait` and the
 * constructors are placed in its companion object, together with a nested
 * `ctors` object that re-exports them for `import Companion.ctors._`.
 */
class ADT extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro ADT.impl
}

object ADT {

  /** Name of the annotation as shown in macro error messages. */
  val MACRO_NAME = "@ADT"

  /**
   * Macro implementation behind [[ADT]].
   *
   * @param c         the macro context
   * @param annottees the annotated trait, optionally accompanied by its companion object
   * @return a tree holding the generated companion object followed by the sealed trait
   */
  def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.universe.Tree = {
    import c.universe._

    /** Aborts macro expansion with `msg` reported at the enclosing position. */
    def fail(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

    // Executed at the end of this macro
    // The annottees are unwrapped to plain trees and copied into a List so the
    // match below does not depend on the runtime class of the varargs sequence.
    def run(): Tree = annottees.map(_.tree).toList match {
      // @ADT trait Foo { ... }
      case List(cls: ClassDef) => runClass(cls)

      // @ADT trait Foo { ... }; object Foo { ... }
      case List(cls: ClassDef, obj: ModuleDef) => runClassWithObj(cls, obj)
      case List(obj: ModuleDef, cls: ClassDef) => runClassWithObj(cls, obj)

      case _ => fail(s"Invalid $MACRO_NAME usage")
    }

    // A constructor is its name plus its (argument name, argument type) pairs.
    type Ctor = (TypeName, List[CtorArg])
    type CtorArg = (TermName, TypeName)

    /** Expands a trait that has no user-written companion object. */
    def runClass(cls: ClassDef) = runClassWithCompanion(cls, Nil)

    /** Expands a trait whose companion object was written by the user. */
    def runClassWithObj(cls: ClassDef, obj: ModuleDef) = {
      val (clsName, objName) = (cls.name.toString, obj.name.toString)
      if (clsName != objName) fail(s"Companion name mismatch: trait $clsName, object $objName")
      val ModuleDef(_, _, objTemplate) = obj
      runClassWithCompanion(cls, objTemplate.body)
    }

    /**
     * Builds the expansion: the companion object (user body plus generated
     * constructors) followed by the sealed trait, which imports its constructors
     * so the remaining trait body can refer to them unqualified.
     */
    def runClassWithCompanion(cls: ClassDef, objBody: List[Tree]) = {
      val ClassDef(clsMods, clsName, clsParams, clsTemplate) = cls
      if (!clsMods.hasFlag(Flag.TRAIT)) fail(s"$MACRO_NAME requires trait")
      val (ctors, clsRestBody) = partitionCtorsAndBody(clsTemplate)
      q"""
        ${mkCompanion(clsName.toTermName, objBody, ctors, clsParams)}

        sealed trait $clsName[..$clsParams] {
          ${importCtors(clsName.toTermName, ctors)}
          ..$clsRestBody
        }
      """
    }

    /** `import <name>.{ Ctor1, Ctor2, ... }` for use inside the trait body. */
    def importCtors(name: TermName, ctors: List[Ctor]) = {
      val importNames = ctors.map { case (ctorName, _) => pq"$ctorName" }
      q"import $name.{ ..$importNames }"
    }

    /**
     * Generates the companion object: the user's statements first, then one
     * definition per constructor, then the `ctors` re-export object.
     */
    def mkCompanion
        (name: TermName,
         objBody: List[Tree],
         ctors: List[Ctor],
         clsParams: List[TypeDef]) = {
      val ctorDefs = ctors.map(mkCtorDef(_, name.toTypeName, getTypeParams(clsParams)))
      val ctorVals = ctors.map { case (ctorName, _) =>
        q"val ${ctorName.toTermName} = $name.${ctorName.toTermName}"
      }
      val ctorsObj = q"object ctors { ..$ctorVals }"
      q"""
        object $name {
          ..${skipInitDef(objBody)}
          ..$ctorDefs
          $ctorsObj
        }
      """
    }

    /** Names of the trait's type parameters, in declaration order. */
    def getTypeParams(params: List[TypeDef]): List[TypeName] = {
      params.collect {
        case t: TypeDef if t.mods.hasFlag(Flag.PARAM) => t.name
      }
    }

    /**
     * Splits a trait template into its constructor declarations and the rest of
     * its body (with generated initializer methods removed).
     */
    def partitionCtorsAndBody(template: Template): (List[Ctor], List[Tree]) = {
      val ctors = new mutable.ListBuffer[Ctor]
      val body = new mutable.ListBuffer[Tree]
      skipInitDef(template.body).foreach {
        // NOTE: We check `if body.isEmpty` to ensure constructors are only inferred
        // from the start of the body before any vals, defs, etc.
        case Apply(Ident(name: TermName), args) if body.isEmpty =>
          ctors += ((name.toTypeName, getCtorArgs(args)))
        case Ident(name: TermName) if body.isEmpty =>
          ctors += ((name.toTypeName, Nil))

        case other => body += other
      }
      (ctors.toList, body.toList)
    }

    // Names of the initializer methods found in parsed templates: `$init$` for
    // traits and `<init>` for classes/objects.
    // This must stay `lazy`: the local defs above refer to it before its
    // definition, and a block only permits such forward references across
    // lazy vals, not strict ones.
    lazy val INIT_DEFS = Set("$init$", "<init>")

    /** Drops initializer methods from a template body. */
    def skipInitDef(body: List[Tree]) = {
      // NOTE: We ignore the $init$ and <init> methods as they are generated by our quasiquotes.
      body.filter {
        case DefDef(_, name, _, _, _, _) => !INIT_DEFS.contains(name.toString)
        case _ => true
      }
    }

    /**
     * Reads constructor arguments of the form `name: Type`, where `Type` is a
     * simple identifier. Any other shape aborts the expansion.
     */
    def getCtorArgs(args: List[Tree]): List[CtorArg] = {
      args.map {
        case Typed(Ident(name: TermName), Ident(typ: TypeName)) => (name, typ)
        case other => fail(s"Unsupported constructor argument: $other; AST: ${showRaw(other)}")
      }
    }

    /** A constructor without arguments is a case object, otherwise a case class. */
    def mkCtorDef(ctor: Ctor, parentName: TypeName, typeParams: List[TypeName]) = {
      ctor match {
        case (name, Nil) => mkCaseObject(name.toTermName, typeParams, parentName)
        case (name, args) => mkCaseClass(name, args, typeParams, parentName)
      }
    }

    /**
     * Parent type reference for a constructor. A trait without type parameters
     * is referenced by its bare name rather than an empty type application.
     */
    def mkParent(parentName: TypeName, typeArgs: List[Tree]): Tree = {
      if (typeArgs.isEmpty) Ident(parentName)
      else AppliedTypeTree(Ident(parentName), typeArgs)
    }

    /** `case object Name extends Parent[Nothing, ...]` */
    def mkCaseObject(name: TermName, typeParams: List[TypeName], parentName: TypeName) = {
      val traitTypeArgs = typeParams.map(_ => Ident(TypeName("Nothing")))
      val parent = mkParent(parentName, traitTypeArgs)
      q"case object $name extends $parent"
    }

    /**
     * `final case class Name[P...](args...) extends Parent[...]`
     *
     * The class is parameterized only by the trait type parameters that its
     * arguments actually mention, each declared once in order of first use.
     * Argument types that are not trait type parameters (e.g. `Int`) are left
     * as ordinary type references. Trait type parameters the constructor does
     * not use are filled with `Nothing` in the parent type.
     */
    def mkCaseClass(name: TypeName, args: List[CtorArg], typeParams: List[TypeName], parentName: TypeName) = {
      val ctorTypeParams = args.map(_._2).filter(typeParams.contains).distinct
      val traitTypeArgs = typeParams.map(
        p => if (ctorTypeParams.contains(p)) Ident(p) else Ident(TypeName("Nothing"))
      )
      val parent = mkParent(parentName, traitTypeArgs)
      val clsTypeParams = ctorTypeParams.map(
        p => TypeDef(Modifiers(Flag.PARAM), p, List(), TypeBoundsTree(EmptyTree, EmptyTree))
      )
      val clsArgs = args.map { case (argName, argType) =>
        ValDef(
          Modifiers(Flag.CASEACCESSOR | Flag.PARAMACCESSOR),
          argName, Ident(argType), EmptyTree
        )
      }
      // final case class $name extends $parent
      q"final case class $name[..$clsTypeParams](..$clsArgs) extends $parent"
    }

    run()
  }
}
--------------------------------------------------------------------------------
/src/test/scala/io/estatico/macros/ADTSpec.scala:
--------------------------------------------------------------------------------
package io.estatico.macros

import org.scalacheck.Prop.forAll
import org.scalacheck.{Arbitrary, Gen, Properties}

/** Optional-value ADT declared through `@ADT`; used as the fixture for the spec below. */
@ADT trait Maybe[+A] {
  Just(a: A)
  Nix

  /** True when this value was built with `Just`. */
  def isJust: Boolean = this match {
    case Nix => false
    case Just(_) => true
  }

  /** True when this value is `Nix`. */
  def isNix: Boolean = !isJust

  /** Applies `f` to the wrapped value, if any. */
  def map[B](f: A => B): Maybe[B] = this match {
    case Nix => Nix
    case Just(value) => Just(f(value))
  }

  /** Converts to the standard library's `Option`. */
  def toOption: Option[A] = this match {
    case Nix => None
    case Just(value) => Some(value)
  }
}

object Maybe {

  /** Wraps `a` in `Just`, mapping `null` to `Nix`. */
  def apply[A](a: A): Maybe[A] = {
    if (a != null) Just(a) else Nix
  }

  /** Converts from the standard library's `Option`. */
  def fromOption[A](oa: Option[A]): Maybe[A] =
    oa.fold[Maybe[A]](Nix)(Just(_))
}

object ADTSpec extends Properties("@ADT macro") {

  import Maybe.ctors._

  property("generates Just constructor") = forAll { (n: Int) =>
    Just(n) == Just(n)
  }

  property("generates Nix singleton") = {
    Nix == Nix
  }

  property("covariant Just constructor") = {
    val widened: Maybe.Just[Any] = Just("foo")
    widened.isJust
  }

  property("retains trait methods") = forAll { (maybe: Maybe[Int]) =>
    maybe.map(identity) == identity(maybe)
  }

  property("retains companion methods") = forAll { (maybe: Maybe[Int]) =>
    Maybe.fromOption(maybe.toOption) == maybe
  }

  /** Generates `Just` of an arbitrary `A` or `Nix`, chosen by an arbitrary Boolean. */
  implicit def arbMaybe[A : Arbitrary]: Arbitrary[Maybe[A]] = {
    Arbitrary(Arbitrary.arbitrary[Boolean].flatMap { pickJust =>
      if (pickJust) Arbitrary.arbitrary[A].map(Just(_))
      else Gen.const(Nix)
    })
  }
}
--------------------------------------------------------------------------------