├── .github └── workflows │ └── ci.yml ├── .gitignore ├── README.md ├── build.sbt ├── hello ├── Hello.scala ├── counter.yml ├── methods.csv └── results.csv ├── plugin └── src │ └── main │ ├── resources │ └── plugin.properties │ └── scala │ ├── Phases.scala │ ├── Plugin.scala │ └── Setting.scala ├── project └── build.properties └── runtime └── src └── main └── scala └── Counter.scala /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | jobs: 8 | test: 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | include: 13 | - java: 17 14 | os: ubuntu-latest 15 | runs-on: ${{matrix.os}} 16 | steps: 17 | - uses: actions/checkout@v3 18 | - uses: coursier/cache-action@v6 19 | - uses: actions/setup-java@v2 20 | with: 21 | distribution: temurin 22 | java-version: ${{matrix.java}} 23 | - name: test 24 | run: | 25 | sbt counter/publishLocal hello/run 26 | cat hello/methods.csv 27 | cat hello/results.csv 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.class 3 | *.tasty 4 | *.hasTasty 5 | *.log 6 | *.swp 7 | *~ 8 | tags 9 | 10 | # sbt specific 11 | dist/* 12 | target/ 13 | lib_managed/ 14 | src_managed/ 15 | project/boot/ 16 | project/plugins/project/ 17 | project/local-plugins.sbt 18 | .history 19 | .ensime 20 | .ensime_cache/ 21 | .sbt-scripted/ 22 | local.sbt 23 | 24 | # npm 25 | node_modules 26 | 27 | # VS Code 28 | .vscode/ 29 | # Metals 30 | .bloop/ 31 | .metals/ 32 | metals.sbt 33 | 34 | # Scala-IDE specific 35 | .scala_dependencies 36 | .cache 37 | .cache-main 38 | .cache-tests 39 | .classpath 40 | .project 41 | .settings 42 | classes/ 43 | */bin/ 44 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Scala 3 Compiler Plugin Example 2 | 3 | A Scala 3 compiler plugin that counts how many times each method of a program is called, so you can find the most frequently called ones. 4 | 5 | ## Usage 6 | 7 | First, clone the repo and publish the plugin together with its runtime library locally (the `counter` project aggregates both): 8 | 9 | ``` 10 | sbt > counter/publishLocal 11 | ``` 12 | 13 | Then enable the plugin in your SBT build: 14 | 15 | ``` scala 16 | libraryDependencies += "org.mycompany" %% "scala-counter-runtime" % "0.1.0" 17 | 18 | libraryDependencies += compilerPlugin("org.mycompany" %% "scala-counter-plugin" % "0.1.0") 19 | ``` 20 | 21 | Now compile your program; a file named `methods.csv` will be generated (the first line below only names the columns and is not written to the file): 22 | 23 | ``` bash 24 | # id, method, class, top-level class, file, line 25 | 0, main, Hello$, Hello$, hello/Hello.scala, 1 26 | 1, foo, Hello$, Hello$, hello/Hello.scala, 5 27 | 2, bar, Hello$, Hello$, hello/Hello.scala, 10 28 | ``` 29 | 30 | Run your program with some sample input; a file named `results.csv` will be generated (again, the first line below only names the columns): 31 | 32 | ``` bash 33 | # id, calls 34 | 0, 1 35 | 1, 6 36 | 2, 1 37 | ``` 38 | 39 | You can use standard tools like [xsv](https://github.com/BurntSushi/xsv) to join the two files: 40 | 41 | ``` bash 42 | xsv join 1 hello/methods.csv 1 hello/results.csv | xsv table # pretty print 43 | xsv join 1 hello/methods.csv 1 hello/results.csv > joined.csv # for input to spreadsheet 44 | ``` 45 | 46 | You can also supply a config file to the plugin in the SBT build: 47 | 48 | ``` scala 49 | scalacOptions += "-P:counter:hello/counter.yml" 50 | ``` 51 | 52 | The config file contains one `key: value` pair per line, and lines starting with `#` are ignored: 53 | 54 | ``` yml 55 | methodsCSV: hello/methods.csv 56 | resultsCSV: hello/results.csv 57 | ``` 58 | 59 | Please check the configuration for the subproject `hello` in 60 | [build.sbt](build.sbt) for more detail. 
61 | 
62 | ## Development
63 | 
64 | First, publish the plugin and its runtime library locally:
65 | 
66 | ```
67 | sbt > counter/publishLocal
68 | ```
69 | 
70 | Then run the example program:
71 | 
72 | ```
73 | sbt > hello/compile; hello/run
74 | ```
75 | 
76 | Check the files under `hello`:
77 | 
78 | ```
79 | hello/
80 | ├── Hello.scala
81 | ├── counter.yml
82 | ├── methods.csv
83 | └── results.csv
84 | ```
85 | 
86 | ## License
87 | 
88 | MIT License
89 | 
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | ThisBuild / scalaVersion := "3.1.3"
2 | 
3 | val libVersion = "0.1.0"
4 | val org = "org.mycompany"
5 | 
6 | lazy val plugin = project
7 |   .settings(
8 |     name := "scala-counter-plugin",
9 |     organization := org,
10 |     version := libVersion,
11 | 
12 |     libraryDependencies += "org.scala-lang" %% "scala3-compiler" % scalaVersion.value % "provided"
13 |   )
14 | 
15 | lazy val runtime = project
16 |   .settings(
17 |     name := "scala-counter-runtime",
18 |     organization := org,
19 |     version := libVersion
20 |   )
21 | 
22 | lazy val counter = project
23 |   .aggregate(plugin, runtime)
24 |   .settings(
25 |     name := "scala-counter",
26 |     organization := org,
27 |     version := libVersion
28 |   )
29 | 
30 | 
31 | lazy val hello = project
32 |   .settings(
33 |     name := "hello",
34 |     version := "0.1.0",
35 | 
36 |     scalacOptions += "-P:counter:hello/counter.yml",
37 | 
38 |     libraryDependencies += org %% "scala-counter-runtime" % libVersion,
39 |     libraryDependencies += compilerPlugin(org %% "scala-counter-plugin" % libVersion)
40 |   )
41 | 
42 | 
43 | lazy val root = project
44 |   .aggregate(plugin, runtime)
45 | 
--------------------------------------------------------------------------------
/hello/Hello.scala:
--------------------------------------------------------------------------------
1 | object Hello {
2 |   def main(args: Array[String]): Unit = {
3 |     foo(5)
4 |   }
5 | 
6 |   def foo(x: Int): Unit =
7 |     if (x > 0)
8 |       foo(x - 1)
9 |     else bar()
10 | 
11 |   def bar(): Unit = println("hello")
12 | }
13 | 
--------------------------------------------------------------------------------
/hello/counter.yml:
--------------------------------------------------------------------------------
1 | methodsCSV: hello/methods.csv
2 | resultsCSV: hello/results.csv
3 | 
--------------------------------------------------------------------------------
/hello/methods.csv:
--------------------------------------------------------------------------------
1 | 0, main, Hello$, Hello$, /Users/fliu/Documents/scala-plugin-example/hello/Hello.scala, 1
2 | 1, foo, Hello$, Hello$, /Users/fliu/Documents/scala-plugin-example/hello/Hello.scala, 5
3 | 2, bar, Hello$, Hello$, /Users/fliu/Documents/scala-plugin-example/hello/Hello.scala, 10
4 | 
--------------------------------------------------------------------------------
/hello/results.csv:
--------------------------------------------------------------------------------
1 | 0, 1
2 | 1, 6
3 | 2, 1
4 | 
--------------------------------------------------------------------------------
/plugin/src/main/resources/plugin.properties:
--------------------------------------------------------------------------------
1 | pluginClass=counter.Plugin
2 | 
--------------------------------------------------------------------------------
/plugin/src/main/scala/Phases.scala:
--------------------------------------------------------------------------------
1 | package counter
2 | 
3 | import dotty.tools.dotc._
4 | 
5 | import plugins._
6 | 
7 | import core._
8 | import Contexts._
9 | import Symbols._
10 | import Flags._
11 | import SymDenotations._
12 | 
13 | import Decorators._
14 | import ast.Trees._
15 | import ast.tpd
16 | import StdNames.nme
17 | import Names._
18 | import Constants.Constant
19 | 
20 | import scala.language.implicitConversions
21 | 
22 | 
23 | /** Instrumentation phase: gives every counted method an id (via `Setting.add`)
24 |  *  and prepends a call to the runtime's `enter(id)` to its body.
25 |  */
26 | class PhaseA(setting: Setting) extends PluginPhase {
27 |   import tpd._
28 | 
29 |   val phaseName = "PhaseA"
30 | 
31 |   private var enterSym: Symbol = _
32 | 
33 |   override val runsAfter = Set(transform.Pickler.name)
34 |   override val runsBefore = Set("PhaseB")
35 | 
36 |   override def prepareForUnit(tree: Tree)(using Context): Context =
37 |     val runtime = requiredModule(setting.runtimeObject)
38 |     enterSym = runtime.requiredMethod("enter")
39 |     ctx
40 | 
41 |   override def transformDefDef(tree: DefDef)(using Context): Tree = {
42 |     val sym = tree.symbol
43 | 
44 |     // Not counted: methods without a body, and synthetic, deferred (abstract),
45 |     // private and accessor methods.
46 |     val skip = tree.rhs.isEmpty || sym.isOneOf(Synthetic | Deferred | Private | Accessor)
47 | 
48 |     if skip then tree
49 |     else
50 |       val methId = setting.add(tree)
51 |       val enterTree = ref(enterSym).appliedTo(Literal(Constant(methId)))
52 |       // new body: `{ enter(methId); <original body> }`
53 |       val rhs1 = tpd.Block(enterTree :: Nil, tree.rhs)
54 |       cpy.DefDef(tree)(rhs = rhs1)
55 |   }
56 | }
57 | 
58 | /** Driver phase: writes the method table once and wraps the program's main
59 |  *  method so that it calls the runtime's `init(methodCount)` first and
60 |  *  `dump(resultsFile)` last.
61 |  */
62 | class PhaseB(setting: Setting) extends PluginPhase {
63 |   import tpd._
64 | 
65 |   val phaseName: String = "PhaseB"
66 | 
67 |   override val runsAfter = Set("PhaseA")
68 |   override val runsBefore = Set(transform.Erasure.name)
69 | 
70 |   private var initSym: Symbol = _
71 |   private var dumpSym: Symbol = _
72 |   private var dumped: Boolean = false
73 | 
74 |   override def prepareForUnit(tree: Tree)(using Context): Context =
75 |     // The method table is written once, when PhaseB sees its first unit.
76 |     // NOTE(review): this assumes PhaseA has already processed every unit by
77 |     // then (i.e. the two phases are not fused into one traversal) -- confirm.
78 |     if !dumped then
79 |       dumped = true
80 |       setting.writeMethods()
81 | 
82 |     val runtime = requiredModule(setting.runtimeObject)
83 |     initSym = runtime.requiredMethod("init")
84 |     dumpSym = runtime.requiredMethod("dump")
85 |     ctx
86 | 
87 |   override def transformDefDef(tree: DefDef)(using Context): Tree =
88 |     if ctx.platform.isMainMethod(tree.symbol) then
89 |       val size = setting.methodCount
90 |       val initTree = ref(initSym).appliedTo(Literal(Constant(size)))
91 |       val dumpTree = ref(dumpSym).appliedTo(Literal(Constant(setting.runtimeOutputFile)))
92 |       // new body: `{ init(size); try <original body> finally dump(resultsFile) }`,
93 |       // so the counts are still written when main throws.
94 |       val rhs1 = Block(initTree :: Nil, tpd.Try(tree.rhs, Nil, dumpTree))
95 |       cpy.DefDef(tree)(rhs = rhs1)
96 |     else tree
97 | }
98 | 
--------------------------------------------------------------------------------
/plugin/src/main/scala/Plugin.scala:
--------------------------------------------------------------------------------
1 | package counter
2 | 
3 | import dotty.tools.dotc.plugins._
4 | 
5 | /** Entry point of the `counter` compiler plugin (registered in `plugin.properties`).
6 |  *
7 |  *  The optional plugin option (`-P:counter:<file>`) is the path of the config file.
8 |  */
9 | class Plugin extends StandardPlugin {
10 |   val name: String = "counter"
11 |   override val description: String = "Count method calls"
12 | 
13 |   def init(options: List[String]): List[PluginPhase] =
14 |     // both phases share one Setting, which holds the method table
15 |     val setting = new Setting(options.headOption)
16 |     List(new PhaseA(setting), new PhaseB(setting))
17 | }
18 | 
--------------------------------------------------------------------------------
/plugin/src/main/scala/Setting.scala:
--------------------------------------------------------------------------------
1 | package counter
2 | 
3 | import scala.language.implicitConversions
4 | 
5 | import dotty.tools.dotc._
6 | import core._
7 | import Contexts._
8 | import Symbols._
9 | import Flags._
10 | import SymDenotations._
11 | 
12 | import Decorators._
13 | import ast.Trees._
14 | import ast.tpd
15 | 
16 | /** State shared by the counter plugin's phases.
17 |  *
18 |  *  Keeps the table of instrumented methods (a method's index in the table is
19 |  *  its id), writes that table to the methods CSV, and exposes the settings
20 |  *  read from the optional config file.
21 |  *
22 |  *  @param configFile optional path of a config file made of `key: value`
23 |  *                    lines; the keys `methodsCSV` and `resultsCSV` are
24 |  *                    recognized, blank lines and `#` comments are ignored.
25 |  */
26 | class Setting(configFile: Option[String]) {
27 |   private val methods = new scala.collection.mutable.ArrayBuffer[tpd.DefDef](256)
28 | 
29 |   private val config: Config = readConfig()
30 | 
31 |   /** Appends `meth` to the method table and returns its id (its index). */
32 |   def add(meth: tpd.DefDef): Int =
33 |     methods.append(meth)
34 |     methods.size - 1
35 | 
36 |   /** Number of methods registered so far. */
37 |   def methodCount: Int = methods.size
38 | 
39 |   /** Writes one line per registered method to the methods CSV:
40 |    *  `id, method, enclosing class, top-level class, source file, line`
41 |    *  (no header row; line numbers are 1-based).
42 |    */
43 |   def writeMethods()(using Context): Unit = {
44 |     val file = new java.io.File(config.methodsCSV)
45 |     val bw = new java.io.BufferedWriter(new java.io.FileWriter(file))
46 |     try {
47 |       methods.indices.foreach { id =>
48 |         val methTree = methods(id)
49 |         val meth = methTree.symbol
50 |         // `namePos.line` is 0-based, so add 1 to report the line an editor shows
51 |         bw.write(
52 |           id.toString + ", " +
53 |           meth.name + ", " +
54 |           meth.enclosingClass.name + ", " +
55 |           meth.topLevelClass.showFullName + ", " +
56 |           methTree.namePos.source + ", " +
57 |           (methTree.namePos.line + 1) + "\n"
58 |         )
59 |       }
60 |     } finally {
61 |       bw.close()
62 |     }
63 |   }
64 | 
65 |   /** Path of the CSV the instrumented program writes its call counts to. */
66 |   def runtimeOutputFile: String = config.resultsCSV
67 | 
68 |   /** Fully-qualified name of the runtime object providing `init`/`enter`/`dump`. */
69 |   def runtimeObject: String = "counter.Counter"
70 | 
71 |   /** Reads `configFile`, falling back to `methods.csv`/`results.csv` for
72 |    *  missing keys or when no config file is given.
73 |    */
74 |   private def readConfig(): Config = {
75 |     val default = Config(methodsCSV = "methods.csv", resultsCSV = "results.csv")
76 | 
77 |     configFile.map { file =>
78 |       import scala.io.Source
79 |       val bufferedSource = Source.fromFile(file)
80 |       try {
81 |         bufferedSource.getLines().foldLeft(default) { (acc, line) =>
82 |           val trimmed = line.trim
83 |           if trimmed.isEmpty || trimmed.startsWith("#") then acc
84 |           else {
85 |             // split at the first ':' only, so values may themselves contain ':'
86 |             val parts = trimmed.split(":", 2)
87 |             assert(parts.size == 2, "incorrect config file " + file + ", line = " + line)
88 |             parts(0).trim match
89 |               case "methodsCSV" => acc.copy(methodsCSV = parts(1).trim())
90 |               case "resultsCSV" => acc.copy(resultsCSV = parts(1).trim())
91 |               case key =>
92 |                 throw new IllegalArgumentException(
93 |                   "unknown key '" + key + "' in config file " + file + ", line = " + line)
94 |           }
95 |         }
96 |       } finally {
97 |         bufferedSource.close()
98 |       }
99 |     }.getOrElse(default)
100 |   }
101 | 
102 |   private case class Config(methodsCSV: String, resultsCSV: String)
103 | }
104 | 
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.6.2
2 | 
--------------------------------------------------------------------------------
/runtime/src/main/scala/Counter.scala:
--------------------------------------------------------------------------------
1 | package counter
2 | 
3 | /** Runtime support linked into instrumented programs: `init` runs at the start
4 |  *  of `main`, `enter(id)` runs on entry to every counted method, and `dump`
5 |  *  writes the counts when `main` finishes.
6 |  */
7 | object Counter {
8 |   // Calls can arrive before `init` (e.g. from object initializers that run
9 |   // before `main`'s body), so start with an empty, growable table, not null.
10 |   private var count: Array[Long] = new Array[Long](0)
11 | 
12 |   /** Makes room for `num` method counters, keeping counts recorded so far. */
13 |   def init(num: Int): Unit =
14 |     if num > count.length then
15 |       count = java.util.Arrays.copyOf(count, num)
16 | 
17 |   /** Increments the call counter of method `id`. */
18 |   def enter(id: Int): Unit =
19 |     // grow only to `id + 1` so the table never exceeds the number of methods
20 |     if id >= count.length then
21 |       count = java.util.Arrays.copyOf(count, id + 1)
22 |     count(id) += 1
23 | 
24 |   /** Writes one `id, calls` line per counter to `outputFile`. */
25 |   def dump(outputFile: String): Unit = {
26 |     val bw = new java.io.BufferedWriter(new java.io.FileWriter(new java.io.File(outputFile)))
27 |     try {
28 |       count.indices.foreach { id =>
29 |         bw.write(id.toString + ", " + count(id) + "\n")
30 |       }
31 |     } finally {
32 |       bw.close()
33 |     }
34 |   }
35 | }
36 | 
--------------------------------------------------------------------------------