├── .gitignore ├── core ├── src │ ├── main │ │ └── scala │ │ │ ├── JumpTarget.scala │ │ │ ├── Syntax.scala │ │ │ ├── Effect.scala │ │ │ ├── Collections.scala │ │ │ ├── AccessibleClassLoader.scala │ │ │ ├── DataPort.scala │ │ │ ├── Graphviz.scala │ │ │ ├── FrameItem.scala │ │ │ ├── FieldDescriptor.scala │ │ │ ├── UniqueNamer.scala │ │ │ ├── DataLabel.scala │ │ │ ├── Frame.scala │ │ │ ├── FieldRef.scala │ │ │ ├── Equality.scala │ │ │ ├── Field.scala │ │ │ ├── Algorithm.scala │ │ │ ├── MethodDescriptor.scala │ │ │ ├── FieldAttribute.scala │ │ │ ├── Flags.scala │ │ │ ├── MethodRef.scala │ │ │ ├── Parsers.scala │ │ │ ├── MethodAttribute.scala │ │ │ ├── Reflect.scala │ │ │ ├── errors.scala │ │ │ ├── Pretty.scala │ │ │ ├── AbstractLabel.scala │ │ │ ├── MethodBody.scala │ │ │ ├── Data.scala │ │ │ ├── ClassRef.scala │ │ │ ├── TypeRef.scala │ │ │ ├── DataSource.scala │ │ │ ├── EventLogger.scala │ │ │ ├── Analyze.scala │ │ │ ├── ClassCompiler.scala │ │ │ ├── FrameUpdate.scala │ │ │ ├── CodeFragment.scala │ │ │ ├── Instance.scala │ │ │ ├── Klass.scala │ │ │ ├── DataFlow.scala │ │ │ ├── Transformer.scala │ │ │ ├── Javassist.scala │ │ │ └── Bytecode.scala │ └── test │ │ └── scala │ │ └── spec.scala └── build.sbt ├── project └── plugins.sbt ├── example ├── build.sbt └── src │ └── main │ └── scala │ └── bench.scala ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | *.dot 3 | *.png 4 | *.log 5 | -------------------------------------------------------------------------------- /core/src/main/scala/JumpTarget.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Label naming a jump target; as a case class, two targets with the same `name` are equal. */ case class JumpTarget(name: String) 4 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scalariform" %
"sbt-scalariform" % "1.6.0") 2 | 3 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.2.6") 4 | -------------------------------------------------------------------------------- /example/build.sbt: -------------------------------------------------------------------------------- 1 | scalaVersion := "2.11.7" 2 | 3 | scalacOptions ++= Seq("-feature", "-deprecation", "-Xfatal-warnings") 4 | 5 | scalariformSettings 6 | 7 | enablePlugins(JmhPlugin) 8 | -------------------------------------------------------------------------------- /core/src/main/scala/Syntax.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | object Syntax { 3 | /** Adds `upcast[B]`, which returns the receiver typed as a supertype `B` of its static type `A`. */ implicit class Upcast[A](val a: A) extends AnyVal { 4 | def upcast[B >: A]: B = a.asInstanceOf[B] 5 | } 6 | 7 | } 8 | -------------------------------------------------------------------------------- /core/src/main/scala/Effect.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Opaque label; the constructor is private, so instances come from `Effect.fresh()`, each with its own `innerId` from AbstractLabel. */ final class Effect private extends AbstractLabel 4 | object Effect extends AbstractLabel.NamerProvider[Effect] { 5 | def fresh() = new Effect 6 | } 7 | -------------------------------------------------------------------------------- /core/src/main/scala/Collections.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.collection.mutable 4 | 5 | object Collections { 6 | /** Returns a new, empty mutable multimap backed by a HashMap of mutable sets. */ def newMultiMap[A, B]: mutable.MultiMap[A, B] = 7 | new mutable.HashMap[A, mutable.Set[B]] with mutable.MultiMap[A, B] 8 | } 9 | -------------------------------------------------------------------------------- /core/src/main/scala/AccessibleClassLoader.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** ClassLoader that exposes the protected `defineClass` through `registerClass`, so raw class bytes can be defined at runtime. */ class AccessibleClassLoader(parent: ClassLoader) extends ClassLoader(parent) { 4 | def
registerClass(name: String, bytes: Array[Byte]): Unit = { 5 | defineClass(name, bytes, 0, bytes.size) /* the returned Class is discarded; presumably callers load it later by name */ 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /core/src/main/scala/DataPort.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** A named port; `In` and `Out` are distinct case classes rendered as `name(in)` and `name(out)`. */ sealed abstract class DataPort { 4 | def name: String 5 | } 6 | object DataPort { 7 | case class In(override val name: String) extends DataPort { 8 | override def toString = s"$name(in)" 9 | } 10 | case class Out(override val name: String) extends DataPort { 11 | override def toString = s"$name(out)" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /core/src/main/scala/Graphviz.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** String builders for Graphviz DOT fragments: attribute lists, nodes and edges. */ object Graphviz { 4 | def drawAttr(attr: Seq[(Symbol, String)]) = s"""[${attr.map { case (k, v) => k.name + "=\"" + v + "\"" }.mkString(", ")}]""" /* NOTE(review): values are wrapped in double quotes but not escaped; a value containing a double quote would yield invalid DOT. Confirm that callers never pass one. */ 5 | def drawNode(id: String, attr: (Symbol, String)*) = s"""${id}${drawAttr(attr)}""" 6 | def drawEdge(from: String, to: String, attr: (Symbol, String)*) = 7 | s"""${from} -> ${to} ${drawAttr(attr)}""" 8 | } 9 | 10 | -------------------------------------------------------------------------------- /core/src/main/scala/FrameItem.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | // TODO: add query methods about types(isDoubleWord etc) for FrameUpdate 4 | case class FrameItem(source: DataSource, data: Data) { 5 | /** Merges both the source and the data with those of `rhs`. */ def merge(rhs: FrameItem): FrameItem = 6 | FrameItem(source.merge(rhs.source), data.merge(rhs.data)) 7 | /** Returns a copy holding `to` when this item's source equals `src`; otherwise returns this item unchanged. */ def replaceDataBySource(src: DataSource.Single, to: Data): FrameItem = 8 | if(source == src) FrameItem(source, to) else this 9 | } 10 | --------------------------------------------------------------------------------
/core/src/main/scala/FieldDescriptor.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Field type descriptor wrapping a TypeRef.Public; `str` and `toString` delegate to the type ref. */ case class FieldDescriptor(typeRef: TypeRef.Public) { 4 | def str = typeRef.str 5 | override def toString = typeRef.toString 6 | } 7 | object FieldDescriptor { 8 | def from(f: java.lang.reflect.Field): FieldDescriptor = 9 | FieldDescriptor(TypeRef.from(f.getType)) 10 | def parse(src: String, cl: ClassLoader): FieldDescriptor = 11 | Parsers.parseFieldDescriptor(src, cl) 12 | } 13 | -------------------------------------------------------------------------------- /core/src/main/scala/UniqueNamer.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /* Produces names that are unique within one instance: the sanitized base names joined by "__", followed by a counter suffix of the form $n$. */ class UniqueNamer() { 4 | private[this] var id = 0L 5 | private[this] def nextId(): Long = synchronized { 6 | id += 1 7 | /* Long counter: the former Int counter wrapped after 2^31 calls, after which math.abs alternated between Int.MinValue and Int.MaxValue and names repeated. Synchronized like AbstractLabel.nextId, because instances are shared companion-object vals (FieldRef.uniqueName, MethodRef.uniqueName). */ 8 | id 9 | } 10 | 11 | /* Strips a trailing $digits$ suffix from each base name (so re-uniquifying a generated name does not stack suffixes), replaces characters outside [A-Za-z0-9$] with '_', joins the parts with "__" and appends a fresh $n$ suffix. */ def apply(baseNames: String*): String = { 12 | val prefix = baseNames 13 | .map(_.replaceAll("""\$[0-9]+\$$""", "")) 14 | .map(_.replaceAll("[^A-Za-z0-9$]", "_")) 15 | .mkString("__") 16 | prefix + "$" + nextId() + "$" 17 | } 18 | } 19 | 20 | -------------------------------------------------------------------------------- /core/src/main/scala/DataLabel.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Label carrying a `name`; the constructor is private, so instances are the nested `In`/`Out` subclasses (see `DataLabel.in` / `DataLabel.out`). */ sealed abstract class DataLabel private (val name: String) extends AbstractLabel 4 | object DataLabel extends AbstractLabel.NamerProvider[DataLabel] { 5 | final class In(name: String) extends DataLabel(name) { 6 | override def toString = s"DataLabel.In(${name})#${innerId}" 7 | } 8 | final class Out(name: String) extends DataLabel(name) { 9 | override def toString = s"DataLabel.Out(${name})#${innerId}" 10 | } 11 | 12 | def in(name: String) = new In(name) 13 | def out(name: String) = new Out(name) 14 | } 15 |
-------------------------------------------------------------------------------- /core/src/main/scala/Frame.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Frame state: local-variable slots keyed by index, and an operand stack whose head is the top (see `stackTop`). */ case class Frame(locals: Map[Int, FrameItem], stack: List[FrameItem]) { 4 | /** Replaces the data of every local and stack item whose source equals `src` with `to`. */ def replaceDataBySource(src: DataSource.Single, to: Data): Frame = 5 | Frame( 6 | locals.map { case (k, v) => k -> v.replaceDataBySource(src, to) }, /* strict: mapValues would return a lazy view that re-runs replaceDataBySource on every lookup and is not serializable in Scala 2.11 */ 7 | stack.map(_.replaceDataBySource(src, to)) 8 | ) 9 | 10 | def local(n: Int): FrameItem = 11 | locals(n) 12 | 13 | def stackTop: FrameItem = stack.head 14 | 15 | def pretty: String = s"""Locals: 16 | ${locals.map { case (k, v) => s"${k} = ${v}" }.mkString("\n")} 17 | Stack: 18 | ${stack.mkString("\n")}""" 19 | } 20 | 21 | -------------------------------------------------------------------------------- /core/src/main/scala/FieldRef.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Field reference: a field name plus its descriptor. */ case class FieldRef(name: String, descriptor: FieldDescriptor) { 4 | override def toString: String = s"${name}: ${descriptor}" 5 | def typeRef: TypeRef.Public = descriptor.typeRef 6 | def renamed(newName: String): FieldRef = copy(name = newName) 7 | /** Renames via `FieldRef.uniqueName`, using `baseNames` as the base (or the current name when none are given). */ def anotherUniqueName(baseNames: String*): FieldRef = 8 | if (baseNames.isEmpty) anotherUniqueName(name) 9 | else copy(name = FieldRef.uniqueName(baseNames: _*)) 10 | } 11 | object FieldRef { 12 | def from(f: java.lang.reflect.Field): FieldRef = 13 | FieldRef(f.getName, FieldDescriptor.from(f)) 14 | 15 | val uniqueName = new UniqueNamer 16 | } 17 | -------------------------------------------------------------------------------- /core/src/main/scala/Equality.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | object Equality { 4 | /** Mixin that makes equality reference identity (`eq`), with a matching identity-based hashCode. */ trait Reference extends AnyRef { 5 | final override def equals(rhs: Any): Boolean = rhs match { 6 | case rhs: AnyRef => this
eq rhs 7 | case _ => false 8 | } 9 | final override def hashCode: Int = 10 | java.lang.System.identityHashCode(this) 11 | } 12 | /** Mixin that delegates equals/hashCode to `equalityObject`; equality additionally requires `rhs.canEqual(this)`. */ trait Delegate extends AnyRef with scala.Equals { 13 | def equalityObject: Any 14 | 15 | final override def equals(rhs: Any): Boolean = rhs match { 16 | case rhs: scala.Equals with Delegate => 17 | rhs.canEqual(this) && equalityObject == rhs.equalityObject 18 | case _ => false 19 | } 20 | 21 | final override def hashCode: Int = equalityObject.hashCode 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /core/build.sbt: -------------------------------------------------------------------------------- 1 | scalaVersion := "2.11.7" 2 | 3 | organization := "com.todesking" 4 | 5 | name := "unveil" 6 | 7 | resolvers += "com.todesking" at "https://todesking.github.io/mvn/" 8 | 9 | libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value 10 | 11 | libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.4" 12 | 13 | libraryDependencies += "org.javassist" % "javassist" % "3.20.0-GA" 14 | 15 | libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.6" % "test" 16 | 17 | libraryDependencies += "com.todesking" %% "scala-pp" % "0.0.4" 18 | 19 | scalacOptions ++= Seq("-feature", "-deprecation", "-Xfatal-warnings") 20 | 21 | testOptions in Test += Tests.Argument("-oFI") 22 | 23 | scalariformSettings 24 | 25 | fork in test := true 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unveil: Runtime JVM bytecode optimizer 2 | 3 | * [日本語の発表資料](http://techlog.mvrck.co.jp/entry/todesking-runtime-jvm-bytecode-optimization/) 4 | 5 | ## Motivation 6 | 7 | Although the JVM has JIT compilation, some performance overhead remains. 8 | 9 | This overhead can be critical in some applications.
10 | 11 | This project aims to reduce that overhead with runtime bytecode optimization. 12 | 13 | ## Current Status 14 | 15 | Very experimental. Stay tuned! 16 | 17 | 18 | ## Benchmark 19 | 20 | Coming soon (really). 21 | 22 | 23 | ## Related work 24 | 25 | [Soot](https://sable.github.io/soot/) is a JVM bytecode optimization framework. 26 | It targets compile-time optimization, not runtime (AFAIK). 27 | 28 | [StreamJIT](https://github.com/jbosboom/streamjit) proposes the "commensal compiler" paradigm. 29 | It uses runtime optimization techniques via `MethodHandle`. 30 | 31 | 32 | -------------------------------------------------------------------------------- /core/src/main/scala/Field.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | import scala.language.higherKinds 5 | 6 | import java.lang.reflect.{ Method => JMethod, Field => JField, Modifier } 7 | 8 | // TODO: remove whole 9 | case class Field( 10 | descriptor: FieldDescriptor, 11 | attribute: FieldAttribute, 12 | data: Data.Concrete 13 | ) { 14 | def isFinal: Boolean = attribute.isFinal 15 | } 16 | object Field { 17 | def from(f: JField, obj: AnyRef): Field = 18 | Field( 19 | FieldDescriptor.from(f), 20 | FieldAttribute.from(f), 21 | data(f, obj) 22 | ) 23 | 24 | /* Reads the field's current value from `obj` and wraps it as concrete data (null references become Data.Null). NOTE(review): f.get throws IllegalAccessException for an inaccessible field; presumably callers pass fields already made accessible (Reflect.allJFields does so). Confirm. */ private[this] def data(f: JField, obj: AnyRef): Data.Concrete = { 25 | val v = f.get(obj) 26 | TypeRef.from(f.getType) match { 27 | case t: TypeRef.Primitive => Data.ConcretePrimitive(t, v.asInstanceOf[AnyVal]) 28 | case t: TypeRef.Reference if v == null => Data.Null 29 | case t: TypeRef.Reference => Data.ConcreteReference(Instance.of(v)) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 todesking 4 | 5 | Permission is hereby granted, free of charge, to
any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /core/src/main/scala/Algorithm.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | object Algorithm { 4 | 5 | /** Splits two maps by key into (keys in both -> (a value, b value), entries only in `a`, entries only in `b`). */ def mapZip[A, B, C](a: Map[A, B], b: Map[A, C]): (Map[A, (B, C)], Map[A, B], Map[A, C]) = { 6 | val aOnly = a.keySet -- b.keySet 7 | val bOnly = b.keySet -- a.keySet 8 | val common = a.keySet -- aOnly 9 | Tuple3( 10 | common.map { k => k -> (a(k) -> b(k)) }.toMap, 11 | a.filter { case (k, _) => aOnly(k) }, /* strict filter: filterKeys would return a lazy, non-serializable view in Scala 2.11 */ 12 | b.filter { case (k, _) => bOnly(k) } 13 | ) 14 | } 15 | 16 | /** Union of the two maps, or None if they share any key. */ def sharedNothingUnion[A, B, C <: B, D <: B](m1: Map[A, C], m2: Map[A, D]): Option[Map[A, B]] = { 17 | val union = m1 ++ m2 18 | if (m1.size + m2.size > union.size) None 19 | else Some(union) 20 | } 21 | 22 | /** Topologically sorts `in` so each element comes after every element whose label it depends on. Throws IllegalArgumentException on a cycle, or on a dependency whose label is not produced by any element of `in`. */ def tsort[A, B](in: Seq[A])(labelOf: A => B)(depsOf: A => Set[B]): Seq[A] = 23 | tsort0(in.map { i => (i, labelOf(i), depsOf(i)) }, Set.empty, Seq.empty) 24 | 25 | private[this] def tsort0[A, B](in: Seq[(A, B, Set[B])], deps: Set[B], sorted: Seq[A]): Seq[A] = 26 | if (in.isEmpty) { 27 | sorted 28 | } else { 29 | val (nodep, dep) = in.partition { case (a, b, bs) => bs.forall(deps.contains) } 30 | if (nodep.isEmpty) throw new IllegalArgumentException(s"Cyclic reference found: ${dep}") 31 | tsort0(dep, deps ++ nodep.map(_._2), sorted ++ nodep.map(_._1)) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /core/src/main/scala/MethodDescriptor.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Method => JMethod, Constructor } 4 | 5 | /** Return type plus argument types; `str` concatenates the argument and return `str` forms as (args)ret. */ case class MethodDescriptor(ret: TypeRef.Public, args: Seq[TypeRef.Public]) { 6 | def argsStr: String = args.map(_.str).mkString("(", "", ")") 7 | def str: String = s"${argsStr}${ret.str}" 8 | 9 | override def toString = s"${args.mkString("(", ", ", ")")}${ret}" 10 | 11 | def isVoid:
Boolean = ret == TypeRef.Void 12 | 13 | /** Local-variable slot of argument number `arg`: offset by 1 for instance methods (the receiver slot), and each earlier argument advances the slot by its `wordSize`. */ def argToLocalIndex(arg: Int, isStatic: Boolean): Int = 14 | (if (isStatic) 0 else 1) + argToLocalTable(arg) 15 | 16 | private[this] lazy val argToLocalTable: Seq[Int] = { 17 | var n = 0 18 | val t = scala.collection.mutable.ArrayBuffer.empty[Int] 19 | args foreach { 20 | arg => 21 | t += n /* append: the former `t(i) = n` always threw IndexOutOfBoundsException because the buffer starts empty */ 22 | n += arg.wordSize 23 | } 24 | t.toVector 25 | } 26 | 27 | } 28 | object MethodDescriptor { 29 | def parse(src: String, cl: ClassLoader): MethodDescriptor = 30 | Parsers.parseMethodDescriptor(src, cl) 31 | 32 | def from(m: JMethod): MethodDescriptor = 33 | MethodDescriptor(TypeRef.from(m.getReturnType), m.getParameterTypes.map(TypeRef.from).toSeq) 34 | 35 | def from(m: Constructor[_]): MethodDescriptor = 36 | MethodDescriptor(TypeRef.Void, m.getParameterTypes.map(TypeRef.from).toSeq) 37 | } 38 | -------------------------------------------------------------------------------- /core/src/main/scala/FieldAttribute.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Field => JField, Modifier } 4 | 5 | /** Modifier flags of a field, limited to those listed in `FieldAttribute.items`. */ sealed abstract class FieldAttribute extends Flags[FieldAttribute] { 6 | def isStatic: Boolean = has(FieldAttribute.Static) 7 | def isFinal: Boolean = has(FieldAttribute.Final) 8 | def isPrivate: Boolean = has(FieldAttribute.Private) 9 | def isPrivateFinal: Boolean = isPrivate && isFinal 10 | def makePrivate: FieldAttribute 11 | } 12 | object FieldAttribute extends FlagsCompanion[FieldAttribute] { 13 | /** Builds the attribute from a reflected field's modifiers; modifiers not in `items` are ignored. */ def from(m: JField): FieldAttribute = 14 | items.filter(_.enabledIn(m.getModifiers)).reduceOption[FieldAttribute](_ | _).getOrElse(Multi(Set.empty)) /* reduce threw on a field with none of the listed modifiers (e.g. a package-private, non-final instance field); MethodAttribute.from likewise falls back to an empty Multi */ 15 | 16 | override def multi(items: Set[SingleFlag]): FieldAttribute = 17 | Multi(items) 18 | 19 | case class Multi(override val items: Set[SingleFlag]) extends FieldAttribute with MultiFlags { 20 | override def makePrivate = Multi(items.filterNot(_ == Public).filterNot(_ == Protected)) | Private 21 | } 22 |
23 | sealed abstract class Single(val toInt: Int) extends FieldAttribute with SingleFlag { 24 | override def makePrivate = Multi(Set(this)).makePrivate 25 | } 26 | 27 | case object Public extends Single(Modifier.PUBLIC) 28 | case object Private extends Single(Modifier.PRIVATE) 29 | case object Protected extends Single(Modifier.PROTECTED) 30 | case object Final extends Single(Modifier.FINAL) 31 | case object Static extends Single(Modifier.STATIC) 32 | 33 | val items = Seq(Public, Private, Protected, Final, Static) 34 | } 35 | -------------------------------------------------------------------------------- /core/src/main/scala/Flags.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Modifier flag set: either a single flag or a multi-flag set, combined with `|` and convertible to an int bit mask via `toInt`. */ trait Flags[Type <: Flags[Type]] { 4 | def |(that: Flags[Type]): Type 5 | def enabledIn(flags: Int): Boolean 6 | def has(flags: Flags[Type]): Boolean 7 | def toInt: Int 8 | } 9 | 10 | trait FlagsCompanion[Type <: Flags[Type]] { 11 | def multi(items: Set[SingleFlag]): Type 12 | 13 | trait MultiFlags extends Flags[Type] { 14 | def items: Set[SingleFlag] 15 | 16 | override def |(that: Flags[Type]): Type = that match { 17 | case that: MultiFlags => multi(items ++ that.items) 18 | case that: SingleFlag => multi(items + that) 19 | } 20 | 21 | override def enabledIn(flags: Int) = items.forall(_.enabledIn(flags)) 22 | 23 | override def has(flags: Flags[Type]) = flags match { 24 | case that: MultiFlags => that.items.subsetOf(this.items) 25 | case that: SingleFlag => items.contains(that) 26 | } 27 | 28 | override def toString = s"${items.mkString(", ")}" 29 | 30 | override def toInt = items.foldLeft[Int](0)(_ | _.toInt) 31 | } 32 | trait SingleFlag extends Flags[Type] { 33 | override def |(that: Flags[Type]): Type = that match { 34 | case that: MultiFlags => multi(that.items + this) 35 | case that: SingleFlag => multi(Set(this, that)) 36 | } 37 | 38 | override def enabledIn(flags: Int) = 39 | (flags & toInt) == toInt 40 |
41 | override def has(flags: Flags[Type]): Boolean = flags match { 42 | case that: MultiFlags => that.items.forall(has(_)) 43 | case that: SingleFlag => this == that 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /core/src/main/scala/MethodRef.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Method => JMethod, Constructor } 4 | 5 | /** Method name plus descriptor; constructors use the JVM name <init> and must have a void descriptor. */ case class MethodRef(name: String, descriptor: MethodDescriptor) { 6 | require(!isInit || isVoid) 7 | 8 | def isInit: Boolean = name == "<init>" /* JVM instance-initializer name; the empty string used before could not be produced by `parse`, whose name group needs at least one character */ 9 | def isVoid: Boolean = descriptor.isVoid 10 | 11 | def str: String = name + descriptor.str 12 | override def toString = name + descriptor.toString 13 | def args: Seq[TypeRef.Public] = descriptor.args 14 | def ret: TypeRef.Public = descriptor.ret 15 | def renamed(newName: String): MethodRef = copy(name = newName) 16 | /** Renames via `MethodRef.uniqueName`, using `base` as the base (or the current name when none is given). */ def anotherUniqueName(base: String*): MethodRef = 17 | if (base.isEmpty) anotherUniqueName(name) 18 | else renamed(MethodRef.uniqueName(base: _*)) 19 | } 20 | object MethodRef { 21 | def from(m: Constructor[_]): MethodRef = 22 | constructor(MethodDescriptor.from(m)) 23 | 24 | def from(m: JMethod): MethodRef = 25 | MethodRef(m.getName, MethodDescriptor.from(m)) 26 | 27 | def constructor(d: MethodDescriptor): MethodRef = 28 | MethodRef("<init>", d) 29 | 30 | val uniqueName = new UniqueNamer 31 | 32 | /** Parses text of the form name(args)ret; throws IllegalArgumentException when the text has no '(' after a non-empty name. */ def parse(src: String, cl: ClassLoader): MethodRef = 33 | parser(cl).parse(src) 34 | 35 | case class parser(classLoader: ClassLoader) { 36 | lazy val all = """([^(]+)(\(.+)""".r 37 | def parse(src: String): MethodRef = 38 | src match { 39 | case `all`(name, desc) => 40 | MethodRef(name, MethodDescriptor.parse(desc, classLoader)) 41 | case unk => 42 | throw new IllegalArgumentException(s"Invalid method ref: ${unk}") 43 | } 44 | } 45 | } 46 | --------------------------------------------------------------------------------
/core/src/main/scala/Parsers.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Parsers for JVM type, method and field descriptor strings; class references are loaded through `cl`, or through the system class loader when `cl` is null. */ object Parsers extends scala.util.parsing.combinator.RegexParsers { 4 | def parseTypeRef(src: String, cl: ClassLoader): TypeRef.Public = 5 | orFail(parseAll(typeRef(cl).all, src), src) 6 | 7 | def parseMethodDescriptor(src: String, cl: ClassLoader): MethodDescriptor = 8 | orFail(parseAll(methodDescriptor(cl).all, src), src) 9 | 10 | def parseFieldDescriptor(src: String, cl: ClassLoader): FieldDescriptor = 11 | orFail(parseAll(fieldDescriptor(cl).all, src), src) 12 | /* Unwraps a successful parse. On failure it throws IllegalArgumentException naming the input; ParseResult.get only reported "No result when parsing failed". */ private[this] def orFail[A](result: ParseResult[A], src: String): A = result match { case Success(value, _) => value; case failure: NoSuccess => throw new IllegalArgumentException(s"Invalid descriptor: ${src} (${failure.msg})") } 13 | /* NOTE(review): array descriptors (starting with '[') are not matched by `all`, so parsing them fails. */ case class typeRef(classLoader: ClassLoader) { 14 | val refPat = """L([^;]+);""".r 15 | lazy val all: Parser[TypeRef.Public] = "B|Z|C|S|I|F|J|D|V|L[^;]+;".r ^^ { 16 | case "B" => TypeRef.Byte 17 | case "Z" => TypeRef.Boolean 18 | case "C" => TypeRef.Char 19 | case "S" => TypeRef.Short 20 | case "I" => TypeRef.Int 21 | case "F" => TypeRef.Float 22 | case "J" => TypeRef.Long 23 | case "D" => TypeRef.Double 24 | case "V" => TypeRef.Void 25 | case `refPat`(ref) => 26 | val cName = ref.replaceAll("/", ".") 27 | val klass = (if (classLoader == null) ClassLoader.getSystemClassLoader else classLoader).loadClass(cName) 28 | TypeRef.Reference(ClassRef.of(klass)) 29 | } 30 | } 31 | 32 | case class methodDescriptor(classLoader: ClassLoader) { 33 | lazy val all = args ~ tpe ^^ { case args ~ ret => MethodDescriptor(ret, args) } 34 | lazy val args = ('(' ~> rep(tpe)) <~ ')' 35 | lazy val tpe = typeRef(classLoader).all 36 | } 37 | 38 | case class fieldDescriptor(classLoader: ClassLoader) { 39 | lazy val all = typeRef(classLoader).all.map { tr => FieldDescriptor(tr) } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /core/src/main/scala/MethodAttribute.scala: -------------------------------------------------------------------------------- 1 | package
com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Method => JMethod, Modifier } 4 | 5 | /** Modifier flags of a method; `isVirtual` means neither private nor static. */ sealed abstract class MethodAttribute extends Flags[MethodAttribute] { 6 | def isVirtual: Boolean = 7 | !this.has(MethodAttribute.Private) && !this.has(MethodAttribute.Static) 8 | def isFinal: Boolean = 9 | this.has(MethodAttribute.Final) 10 | def isStatic: Boolean = 11 | this.has(MethodAttribute.Static) 12 | def isAbstract: Boolean = 13 | this.has(MethodAttribute.Abstract) 14 | def isNative: Boolean = 15 | this.has(MethodAttribute.Native) 16 | def isPrivate: Boolean = 17 | this.has(MethodAttribute.Private) 18 | def makePrivate: MethodAttribute 19 | def makeNonFinal: MethodAttribute 20 | } 21 | object MethodAttribute extends FlagsCompanion[MethodAttribute] { 22 | def from(m: JMethod): MethodAttribute = 23 | from(m.getModifiers) 24 | 25 | /** Builds the attribute from a modifier bit mask; bits not in `items` are ignored, and a mask with none of them yields `empty`. */ def from(flags: Int): MethodAttribute = 26 | items.filter(_.enabledIn(flags)).foldLeft[MethodAttribute](empty)(_ | _) 27 | 28 | val empty = Multi(Set.empty) 29 | 30 | override def multi(items: Set[SingleFlag]): MethodAttribute = 31 | Multi(items) 32 | 33 | case class Multi(override val items: Set[SingleFlag]) extends MethodAttribute with MultiFlags { 34 | override def makePrivate = Multi(items.filterNot(_ == Public).filterNot(_ == Protected)) | Private 35 | override def makeNonFinal = Multi(items.filterNot(_ == Final)) 36 | } 37 | 38 | sealed abstract class Single(override val toInt: Int) extends MethodAttribute with SingleFlag { 39 | override def makePrivate = Multi(Set(this)).makePrivate 40 | override def makeNonFinal = Multi(Set(this)).makeNonFinal 41 | } 42 | 43 | case object Public extends Single(Modifier.PUBLIC) 44 | case object Private extends Single(Modifier.PRIVATE) 45 | case object Protected extends Single(Modifier.PROTECTED) 46 | case object Native extends Single(Modifier.NATIVE) 47 | case object Abstract extends Single(Modifier.ABSTRACT) 48 | case object Final extends Single(Modifier.FINAL) 49 | case object Synchronized extends
Single(Modifier.SYNCHRONIZED) 50 | case object Strict extends Single(Modifier.STRICT) 51 | case object Static extends Single(Modifier.STATIC) 52 | 53 | val items = Seq(Public, Private, Protected, Native, Abstract, Final, Synchronized, Strict, Static) 54 | } 55 | 56 | -------------------------------------------------------------------------------- /core/src/main/scala/Reflect.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Method => JMethod, Field => JField, Constructor => JConstructor } 4 | 5 | /** Reflection helpers that walk a class and its superclass chain (see `supers`; interfaces are not visited). */ object Reflect { 6 | // TODO: default interface method 7 | def allJMethods(jClass: Class[_]): Map[(ClassRef, MethodRef), JMethod] = 8 | supers(jClass) 9 | .flatMap(_.getDeclaredMethods) 10 | .map { m => (ClassRef.of(m.getDeclaringClass) -> MethodRef.from(m)) -> m } 11 | .toMap 12 | 13 | def allJConstructors(jClass: Class[_]): Map[(ClassRef, MethodRef), JConstructor[_]] = 14 | supers(jClass) 15 | .flatMap(_.getDeclaredConstructors) 16 | .map { m => (ClassRef.of(m.getDeclaringClass) -> MethodRef.from(m)) -> m } 17 | .toMap 18 | 19 | // TODO: default interface method 20 | def resolveVirtualMethod(jClass: Class[_], mr: MethodRef): ClassRef.Concrete = 21 | supers(jClass) 22 | .find { c => 23 | c.getDeclaredMethods 24 | .filter { m => MethodAttribute.from(m.getModifiers).isVirtual } 25 | .exists { m => MethodRef.from(m) == mr } 26 | }.map { c => ClassRef.of(c) } 27 | .getOrElse { throw new IllegalArgumentException(s"Can't find virtual method ${mr} in ${jClass}") } 28 | 29 | // TODO: default interface method 30 | /* Non-private methods keyed by MethodRef; superclasses are folded first, so subclass declarations override them. NOTE(review): unlike resolveVirtualMethod (which uses isVirtual), static methods are kept here. Confirm this is intended. */ def virtualJMethods(jClass: Class[_]): Map[MethodRef, JMethod] = 31 | supers(jClass) 32 | .reverse 33 | .flatMap(_.getDeclaredMethods) 34 | .filterNot { m => MethodAttribute.Private.enabledIn(m.getModifiers) } 35 | .foldLeft(Map.empty[MethodRef, JMethod]) { 36 | case (map, m) => 37 | map + (MethodRef.from(m) -> m) 38 | } 39 | 40 | def allJFields(jClass: Class[_]):
Map[(ClassRef, FieldRef), JField] = 41 | supers(jClass) 42 | .flatMap(_.getDeclaredFields) 43 | .map { f => f.setAccessible(true); f } // make fields readable via reflection; I believe this does not trigger any harmful side effects 44 | .map { f => (ClassRef.of(f.getDeclaringClass) -> FieldRef.from(f)) -> f } 45 | .toMap 46 | 47 | /** `klass` followed by its superclasses, nearest first. */ def supers(klass: Class[_]): Seq[Class[_]] = 48 | klass +: Option(klass.getSuperclass).toSeq.flatMap(supers) 49 | 50 | def superClassOf(cr: ClassRef): Option[ClassRef] = 51 | cr match { 52 | case cr @ ClassRef.Concrete(_, _) => 53 | Option(cr.loadClass.getSuperclass).map(ClassRef.of) 54 | case ClassRef.Extend(s, _, _, _) => 55 | Some(s) 56 | } 57 | } 58 | 59 | -------------------------------------------------------------------------------- /core/src/main/scala/errors.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Base class for the exceptions in this file (UnveilBugException extends RuntimeException directly instead). */ class UnveilException(msg: String, err: Throwable) extends RuntimeException(msg, err) { 4 | def this(msg: String) = this(msg, null) 5 | } 6 | 7 | object UnveilException { 8 | trait HasMethodBody extends UnveilException { 9 | def methodBody: MethodBody 10 | } 11 | trait HasMethodLocation extends UnveilException { 12 | def methodLocation: (ClassRef, MethodRef) 13 | } 14 | trait HasFieldLocation extends UnveilException { 15 | def fieldLocation: (ClassRef, MethodRef) /* NOTE(review): a field location presumably should be (ClassRef, FieldRef); left unchanged because implementors must match this signature */ 16 | } 17 | } 18 | 19 | class MaterializeException(msg: String, err: Throwable) extends UnveilException(msg, err) 20 | 21 | class InvalidClassException(val klass: Klass, err: LinkageError) 22 | extends MaterializeException(err.toString, err) 23 | 24 | class AnalyzeException(msg: String) extends UnveilException(msg) 25 | 26 | class MethodAnalyzeException(classRef: ClassRef, methodRef: MethodRef, msg: String) 27 | extends AnalyzeException(s"Method analyze failed(${classRef}.${methodRef}): ${msg}") 28 | 29 | class MethodBodyAnalyzeException(override val methodBody: MethodBody, msg: String) 30 | extends AnalyzeException(s"Method analyze
failed(descriptor: ${methodBody.descriptor}): ${msg}") with UnveilException.HasMethodBody 31 | 32 | class UnsupportedOpcodeException(classRef: ClassRef, methodRef: MethodRef, byte: Int) 33 | extends MethodAnalyzeException(classRef, methodRef, f"Unsupported opcode: 0x$byte%02X") 34 | 35 | class FieldAnalyzeException(classRef: ClassRef, fieldRef: FieldRef, msg: String) 36 | extends AnalyzeException(s"Field analyze failed(${classRef}.${fieldRef}): ${msg}") 37 | 38 | class TransformException(msg: String) extends UnveilException(msg) 39 | 40 | /* NOTE(review): extends AnalyzeException rather than TransformException; kept because callers may catch AnalyzeException */ class FieldTransformException(classRef: ClassRef, fieldRef: FieldRef, msg: String) 41 | extends AnalyzeException(s"Transform failed at ${classRef}.${fieldRef}: ${msg}") 42 | 43 | /* NOTE(review): extends AnalyzeException rather than TransformException; kept because callers may catch AnalyzeException */ class MethodTransformException(classRef: ClassRef, methodRef: MethodRef, msg: String) 44 | extends AnalyzeException(s"Transform failed at ${classRef}.${methodRef}: ${msg}") 45 | 46 | class BytecodeTransformException(val classRef: ClassRef, val methodRef: MethodRef, override val methodBody: MethodBody, val bytecode: Bytecode, msg: String) 47 | extends MethodTransformException(classRef, methodRef, s"$bytecode: $msg") with UnveilException.HasMethodBody 48 | 49 | class UnveilBugException(msg: String, err: Throwable) extends RuntimeException(msg, err) { 50 | def detail: String = "" 51 | } 52 | -------------------------------------------------------------------------------- /core/src/main/scala/Pretty.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | /** Human-readable renderings of analysis objects; several formatters are still unimplemented (`???`). */ object Pretty { 4 | def format_MethodBody(mb: MethodBody): String = ??? 5 | def format_DataFlow(df: DataFlow): String = ??? 6 | 7 | def format_Instance_Original[A <: AnyRef](o: Instance.Original[A]): String = ??? 8 | def format_Instance_Duplicate[A <: AnyRef](dup: Instance.Duplicate[A]): String = ??? 9 | def format_Instance_New[A <: AnyRef](n: Instance.New[A]): String = ???
10 | 11 | def format_Klass(klass: Klass): String = klass match { 12 | case k: Klass.MaterializedNative => format_Klass_MaterializedNative(k) 13 | case k: Klass.Modified => format_Klass_Modified(k) 14 | case k => k.toString 15 | } 16 | def format_Klass_MaterializedNative(klass: Klass.MaterializedNative): String = klass.toString 17 | def format_Klass_Modified(klass: Klass.Modified): String = { 18 | s"""class ${klass.ref} { 19 | // new/overriden methods: 20 | ${ 21 | klass.declaredMethods.map { 22 | case (mr, body) => 23 | try { 24 | val df = klass.dataflow(klass.ref, mr) 25 | s""" def ${mr} ${body.attribute} 26 | ${df.pretty.split("\n").map(" " + _).mkString("\n")}""" 27 | } catch { 28 | case scala.util.control.NonFatal(e) => 29 | s"""(dataflow analysis failed: $e) 30 | def ${mr} ${body.attribute} 31 | ${body.pretty.split("\n").map(" " + _).mkString("\n")}""" 32 | } 33 | }.mkString("\n") 34 | } 35 | 36 | // New fields: 37 | ${ 38 | klass.declaredFields.map { 39 | case (fr, attr) => s"$fr $attr" 40 | }.map(" " + _).mkString("\n") 41 | } 42 | 43 | // Super fields: 44 | ${ 45 | klass.instanceFieldAttributes.filterNot(_._1._1 == klass.ref).map { 46 | case ((cr, fr), attr) => s"$cr.$fr ${attr}" 47 | }.mkString("\n") 48 | } 49 | }""" 50 | } 51 | 52 | def format_CodeFragment_Complete(cf: CodeFragment.Complete): String = { 53 | s"""${ 54 | cf.bytecode.map { 55 | case (l, bc) => 56 | val format = "L%03d" 57 | l.format(format) + " " + (bc match { 58 | case bc: Bytecode.HasAJumpTarget => 59 | s"${bc.pretty} # ${cf.jumpDestination(l, bc.jumpTarget).format(format)}" 60 | case bc => 61 | bc.pretty 62 | }) 63 | }.mkString("\n") 64 | } 65 | """ 66 | } 67 | 68 | 69 | // TODO: super class information 70 | private[this] def format_Klass0( 71 | classRef: ClassRef, 72 | declaredMethods: Map[MethodRef, MethodBody], 73 | declaredFields: Map[FieldRef, FieldAttribute], 74 | constructor: Option[(MethodDescriptor, Option[Seq[Data.Concrete]])], // (ctor, args?)?
75 | fieldValues: Map[(ClassRef, FieldRef), Data] 76 | ): String = { 77 | ??? 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /core/src/main/scala/AbstractLabel.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.collection.mutable 4 | 5 | abstract class AbstractLabel() extends AnyRef { 6 | val innerId = AbstractLabel.nextId() 7 | override def equals(other: Any): Boolean = 8 | other match { 9 | case l: AbstractLabel => this.innerId == l.innerId 10 | case _ => false 11 | } 12 | override def hashCode: Int = 13 | innerId.hashCode 14 | 15 | override def toString = s"${getClass.getName}#${innerId}" 16 | } 17 | object AbstractLabel { 18 | private[this] var _nextId = 0 19 | private def nextId(): Int = synchronized { 20 | val n = _nextId 21 | _nextId = _nextId + 1 22 | n 23 | } 24 | class Namer[A <: AbstractLabel](idPrefix: String, namePrefix: String) { 25 | private[this] val ids = mutable.HashMap.empty[A, Int] 26 | private[this] var nextId = 0 27 | 28 | def num(l: A): Int = 29 | ids.get(l) getOrElse { 30 | ids(l) = nextId 31 | nextId += 1 32 | ids(l) 33 | } 34 | def id(l: A): String = s"${idPrefix}${num(l)}" 35 | def name(l: A): String = s"${namePrefix}${num(l)}" 36 | } 37 | class Assigner[A, L <: AbstractLabel](fresh: => L) { 38 | private[this] val mapping = mutable.HashMap.empty[A, L] 39 | def apply(key: A): L = 40 | mapping.get(key) getOrElse { 41 | val l = fresh 42 | mapping(key) = l 43 | l 44 | } 45 | } 46 | class Merger[L <: AbstractLabel](fresh: => L) { 47 | private[this] val merges = new mutable.HashMap[L, mutable.Set[L]] with mutable.MultiMap[L, L] 48 | private[this] val cache = new mutable.HashMap[(L, L), L] 49 | 50 | def toMap: Map[L, Set[L]] = merges.mapValues(_.toSet).toMap 51 | 52 | // TODO: Is this really enough?
53 | def merge(l1: L, l2: L): L = 54 | cache.get((l1, l2)) getOrElse { 55 | val m = merge0(l1, l2) 56 | cache((l1, l2)) = m 57 | cache((l2, l1)) = m 58 | m 59 | } 60 | 61 | private[this] def merge0(l1: L, l2: L): L = 62 | if (l1 == l2) { 63 | l1 64 | } else if (merges.contains(l1)) { 65 | if (merges.contains(l2)) throw new AssertionError 66 | merges.addBinding(l1, l2) 67 | l1 68 | } else if (merges.contains(l2)) { 69 | merges.addBinding(l2, l1) 70 | l2 71 | } else if (merges.find(_._1 == l1).map(_._2.contains(l2)) getOrElse false) { 72 | l1 73 | } else if (merges.find(_._1 == l2).map(_._2.contains(l1)) getOrElse false) { 74 | l2 75 | } else { 76 | val m = fresh 77 | merges.addBinding(m, l1) 78 | merges.addBinding(m, l2) 79 | m 80 | } 81 | } 82 | 83 | trait NamerProvider[A <: AbstractLabel] { 84 | def namer(idPrefix: String, namePrefix: String): Namer[A] = new Namer(idPrefix, namePrefix) 85 | } 86 | trait AssignerProvider[A <: AbstractLabel] { self: { def fresh(): A } => 87 | def assigner[B](): Assigner[B, A] = new Assigner(fresh()) 88 | } 89 | } 90 | 91 | -------------------------------------------------------------------------------- /example/src/main/scala/bench.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil.benchmark 2 | 3 | import org.openjdk.jmh.annotations.{ Benchmark, State } 4 | 5 | object Main { 6 | def main(args: Array[String]): Unit = { 7 | (-10 to 10) foreach { i => 8 | val baseline = Bench.baseline(i) 9 | val standard = Bench.standardF(i) 10 | val fused = Bench.fusedF(i) 11 | val fuseInlined = Bench.fuseInlinedF(i) 12 | println((i, baseline, standard, fused, fuseInlined)) 13 | assert(baseline == standard) 14 | assert(baseline == fused) 15 | assert(baseline == fuseInlined) 16 | } 17 | } 18 | } 19 | 20 | object Bench { 21 | def f1_(x: Int) = x + 1 22 | def f2_(x: Int) = x + 10.0 23 | def f3_(x: Double) = (x * 100).toInt 24 | def f4_(x: Int) = x + 1.5 25 | def f5_(x: Double) = x * 
0.01 26 | def f6_(x: Double) = x - 200.0 27 | def f7_(x: Double) = x.toInt 28 | def f8_(x: Int) = x + 10 29 | val f1 = f1_ _ 30 | val f2 = f2_ _ 31 | val f3 = f3_ _ 32 | val f4 = f4_ _ 33 | val f5 = f5_ _ 34 | val f6 = f6_ _ 35 | val f7 = f7_ _ 36 | val f8 = f8_ _ 37 | 38 | def F(x: Int) = f8_(f7_(f6_(f5_(f4_(f3_(f2_(f1_(x)))))))) 39 | 40 | val fastest: Int => Int = { x: Int => ((((((x + 1) + 10.0) * 100).toInt + 1.5) * 0.01) - 200.0).toInt + 10 } 41 | 42 | val baseline = { 43 | x: Int => F(F(F(F(x)))) 44 | } 45 | 46 | val standardF = { 47 | def F = f1 andThen f2 andThen f3 andThen f4 andThen f5 andThen f6 andThen f7 andThen f8 48 | F andThen F andThen F andThen F 49 | } 50 | 51 | val fusedF = { 52 | def F = f1 andThen f2 andThen f3 andThen f4 andThen f5 andThen f6 andThen f7 andThen f8 53 | val FF = F andThen F andThen F andThen F 54 | import com.todesking.unveil.{ Transformer, Instance } 55 | val el = Transformer.newEventLogger 56 | val i = Instance.of(FF).duplicate[Function1[Int, Int]](el) 57 | val ti = 58 | Transformer.fieldFusion(i, el) 59 | ti.get.materialize(el).value 60 | } 61 | val fuseInlinedF = { 62 | def F = f1 andThen f2 andThen f3 andThen f4 andThen f5 andThen f6 andThen f7 andThen f8 63 | val FF = F andThen F andThen F andThen F 64 | import com.todesking.unveil.{ Transformer, Instance } 65 | val el = Transformer.newEventLogger 66 | val i = Instance.of(FF).duplicate[Function1[Int, Int]](el) 67 | val ti = 68 | (Transformer.fieldFusion >>> Transformer.methodInlining)(i, el) 69 | ti.get.materialize(el).value 70 | } 71 | } 72 | 73 | class Bench { 74 | @Benchmark 75 | def fastest(): Any = { 76 | var x = 0 77 | (0 until 1000).foreach { i => x += Bench.fastest(i) } 78 | x 79 | } 80 | 81 | @Benchmark 82 | def baseline(): Any = { 83 | var x = 0 84 | (0 until 1000).foreach { i => x += Bench.baseline(i) } 85 | x 86 | } 87 | 88 | @Benchmark 89 | def standard(): Any = { 90 | var x = 0 91 | (0 until 1000).foreach { i => x += Bench.standardF(i) } 92 | x 93 | } 94 | 
95 | @Benchmark 96 | def fused(): Any = { 97 | var x = 0 98 | (0 until 1000).foreach { i => x += Bench.fusedF(i) } 99 | x 100 | } 101 | 102 | @Benchmark 103 | def fuseInlined(): Any = { 104 | var x = 0 105 | (0 until 1000).foreach { i => x += Bench.fuseInlinedF(i) } 106 | x 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /core/src/main/scala/MethodBody.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | import scala.language.higherKinds 5 | 6 | import scala.reflect.{ classTag, ClassTag } 7 | import scala.collection.mutable 8 | 9 | import java.lang.reflect.{ Method => JMethod, Constructor => JConstructor } 10 | 11 | import com.todesking.scalapp.syntax._ 12 | 13 | case class MethodBody( 14 | isInit: Boolean, 15 | descriptor: MethodDescriptor, 16 | attribute: MethodAttribute, 17 | codeFragment: CodeFragment.Complete 18 | ) { 19 | private[this] def require(cond: Boolean, detail: => String = ""): Unit = 20 | MethodBody.require(this, cond, detail) 21 | 22 | require(codeFragment.nonEmpty) 23 | 24 | def bytecodeFromLabel(l: Bytecode.Label): Bytecode = 25 | codeFragment.bytecodeFromLabel(l) 26 | 27 | def jumpTargets: Map[(Bytecode.Label, JumpTarget), Bytecode.Label] = 28 | codeFragment.jumpTargets 29 | 30 | def bytecode: Seq[(Bytecode.Label, Bytecode)] = 31 | codeFragment.bytecode 32 | 33 | def isStatic: Boolean = attribute.isStatic 34 | 35 | def makePrivate: MethodBody = 36 | copy(attribute = attribute.makePrivate) 37 | 38 | def makeNonFinal: MethodBody = 39 | copy(attribute = attribute.makeNonFinal) 40 | 41 | // TODO: Exception handler 42 | 43 | def methodReferences: Set[(ClassRef, MethodRef)] = 44 | codeFragment.methodReferences 45 | 46 | def fieldReferences: Set[(ClassRef, FieldRef)] = 47 | codeFragment.fieldReferences 48 | 49 | def labelToBytecode: Map[Bytecode.Label, Bytecode] = 50 | 
codeFragment.labelToBytecode 51 | 52 | def rewrite(f: PartialFunction[(Bytecode.Label, Bytecode), Bytecode]): MethodBody = { 53 | val lifted = f.lift 54 | rewrite_* { 55 | case x @ (label, bc) if f.isDefinedAt(x) => CodeFragment.bytecode(f(x)) 56 | } 57 | } 58 | 59 | def rewrite_*(f: PartialFunction[(Bytecode.Label, Bytecode), CodeFragment]): MethodBody = 60 | copy(codeFragment = codeFragment.rewrite_*(f).complete()) 61 | 62 | def rewrite_**( 63 | f: PartialFunction[(Bytecode.Label, Bytecode), Map[Bytecode.Label, CodeFragment]] 64 | ): MethodBody = 65 | copy(codeFragment = codeFragment.rewrite_**(f).complete()) 66 | 67 | def rewriteClassRef(from: ClassRef, to: ClassRef): MethodBody = { 68 | rewrite { case (label, bc: Bytecode.HasClassRef) if bc.classRef == from => bc.rewriteClassRef(to) } 69 | } 70 | 71 | def pretty: String = 72 | s"""$descriptor [$attribute] 73 | ${codeFragment.pretty}""" 74 | 75 | // TODO: refactor 76 | def dataflow(self: Instance[_ <: AnyRef]): DataFlow = 77 | new DataFlow(this, self) 78 | 79 | def dataflow(klass: Klass): DataFlow = 80 | new DataFlow(this, new Instance.Given(klass, Map())) 81 | } 82 | 83 | object MethodBody { 84 | def parse(m: JMethod): MethodBody = 85 | Javassist.decompile(m).getOrElse { throw new MethodAnalyzeException(ClassRef.of(m.getDeclaringClass), MethodRef.from(m), "CA not found") } 86 | 87 | def parse(m: JConstructor[_]): MethodBody = 88 | Javassist.decompile(m).getOrElse { throw new MethodAnalyzeException(ClassRef.of(m.getDeclaringClass), MethodRef.from(m), "CA not found") } 89 | 90 | def require(body: MethodBody, cond: Boolean, detailMsg: => String): Unit = 91 | if (!cond) 92 | throw new UnveilBugException("BUG", null) { 93 | override val detail = body.pretty + "\n\n" + detailMsg 94 | } 95 | } 96 | 97 | -------------------------------------------------------------------------------- /core/src/main/scala/Data.scala: -------------------------------------------------------------------------------- 1 | package 
com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | import scala.language.higherKinds 5 | 6 | sealed abstract class Data { 7 | def typeRef: TypeRef 8 | def value: Option[Any] 9 | def valueString: String 10 | def isInstance(instance: Instance[_ <: AnyRef]): Boolean = false 11 | 12 | def merge(rhs: Data): Data = Data.merge(this, rhs) 13 | 14 | def secondWordData: Data = 15 | if (!typeRef.isDoubleWord) throw new IllegalArgumentException() 16 | else Data.Unknown(TypeRef.SecondWord) 17 | 18 | override def toString: String = s"""${typeRef} = ${valueString}""" 19 | } 20 | object Data { 21 | def merge(d1: Data, d2: Data): Data = { 22 | if (d1 eq d2) d1 23 | else (d1, d2) match { 24 | case (Unknown(t), rhs) => 25 | Unknown(TypeRef.common(t, rhs.typeRef)) 26 | case (lhs, Unknown(t)) => 27 | Unknown(TypeRef.common(t, lhs.typeRef)) 28 | case (ConcretePrimitive(t1, v1), ConcretePrimitive(t2, v2)) if t1 == t2 && v1 == v2 => 29 | ConcretePrimitive(t1, v1) 30 | case (ConcreteReference(i1), ConcreteReference(i2)) if i1 == i2 => 31 | ConcreteReference(i1) 32 | case (AbstractReference(i1), AbstractReference(i2)) if i1 == i2 => 33 | AbstractReference(i1) 34 | case (Null, Null) => Null 35 | case (d1, d2) => 36 | Unknown(TypeRef.common(d1.typeRef, d2.typeRef)) 37 | } 38 | } 39 | 40 | def reference(i: Instance[_ <: AnyRef]): Reference = i match { 41 | case i: Instance.Concrete[_] => ConcreteReference(i) 42 | case i: Instance.Abstract[_] => AbstractReference(i) 43 | } 44 | 45 | sealed abstract class Known extends Data { 46 | } 47 | 48 | case class Unknown(override val typeRef: TypeRef) extends Data with Equality.Reference { 49 | override def valueString = "???" 50 | override def value = None 51 | } 52 | 53 | case class UnknownReference( 54 | klass: Klass, 55 | fieldValues: Map[(ClassRef, FieldRef), Data] 56 | ) extends Data with Equality.Reference { 57 | override def typeRef = klass.ref.toTypeRef 58 | override def valueString = "???" 
59 | override def value = None 60 | } 61 | 62 | sealed abstract class Concrete extends Known { 63 | def concreteValue: Any 64 | } 65 | 66 | case class Uninitialized(classRef: ClassRef) extends Known with Equality.Reference { 67 | override def typeRef = classRef.toTypeRef 68 | override val valueString = s"new $typeRef(uninitialized)" 69 | override val value = None 70 | } 71 | case class ConcretePrimitive(override val typeRef: TypeRef.Primitive, override val concreteValue: AnyVal) extends Concrete { 72 | // TODO: require(typeRef.isValue(concreteValue)) 73 | override def valueString = concreteValue.toString 74 | override def value = Some(concreteValue) 75 | } 76 | 77 | case object Null extends Concrete { 78 | override def valueString = "null" 79 | override val typeRef = TypeRef.Null 80 | override val concreteValue = null 81 | override def value = Some(concreteValue) 82 | } 83 | 84 | sealed trait Reference extends Known with Equality.Delegate { 85 | override def canEqual(rhs: Any) = rhs.isInstanceOf[Reference] 86 | override def equalityObject = instance 87 | 88 | val instance: Instance[_ <: AnyRef] 89 | override def typeRef = instance.thisRef.toTypeRef 90 | override def value = instance.valueOption 91 | override def valueString = instance.toString 92 | } 93 | 94 | case class ConcreteReference(override val instance: Instance.Concrete[_ <: AnyRef]) extends Concrete with Reference { 95 | override def concreteValue = instance.value 96 | } 97 | 98 | case class AbstractReference(override val instance: Instance.Abstract[_ <: AnyRef]) extends Known with Reference { 99 | } 100 | } 101 | 102 | -------------------------------------------------------------------------------- /core/src/main/scala/ClassRef.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | 5 | import java.util.Objects 6 | 7 | sealed abstract class ClassRef { 8 | def pretty: String 9 | override def 
toString = pretty 10 | def name: String 11 | def classLoader: ClassLoader 12 | def binaryName: String = name.replaceAll("\\.", "/") 13 | 14 | def <(rhs: ClassRef): Boolean = 15 | ClassRef.compare(this, rhs).map { case -1 => true; case 0 => false; case 1 => false } getOrElse false 16 | 17 | def >(rhs: ClassRef): Boolean = 18 | ClassRef.compare(this, rhs).map { case -1 => false; case 0 => false; case 1 => true } getOrElse false 19 | 20 | def >=(rhs: ClassRef): Boolean = 21 | this > rhs || this == rhs 22 | 23 | def <=(rhs: ClassRef): Boolean = 24 | this < rhs || this == rhs 25 | 26 | def toTypeRef: TypeRef.Reference = TypeRef.Reference(this) 27 | def renamed(newName: String): ClassRef 28 | 29 | override def hashCode = Objects.hashCode(name) ^ Objects.hashCode(classLoader) 30 | 31 | override def equals(obj: Any) = 32 | obj match { 33 | case that: ClassRef => 34 | this.name == that.name && this.classLoader == that.classLoader 35 | case _ => 36 | false 37 | } 38 | } 39 | object ClassRef { 40 | // Some(n): Determinable 41 | // None: Not sure 42 | def compare(lhs: ClassRef, rhs: ClassRef): Option[Int] = (lhs, rhs) match { 43 | case (l: Concrete, r: Concrete) => 44 | if (l.loadClass == r.loadClass) Some(0) 45 | else if (l.loadClass.isAssignableFrom(r.loadClass)) Some(1) 46 | else if (r.loadClass.isAssignableFrom(l.loadClass)) Some(-1) 47 | else None 48 | case (l: Concrete, r: Extend) => 49 | if (l.loadClass.isAssignableFrom(r.superClassRef.loadClass)) Some(1) 50 | else None 51 | case (l: Extend, r: Concrete) => 52 | if (r.loadClass.isAssignableFrom(l.superClassRef.loadClass)) Some(-1) 53 | else None 54 | case (l: Extend, r: Extend) => 55 | None 56 | } 57 | 58 | val Object: ClassRef.Concrete = of(classOf[java.lang.Object]) 59 | 60 | case class Concrete(override val name: String, override val classLoader: ClassLoader) extends ClassRef { 61 | override def pretty = s"${name}@${System.identityHashCode(classLoader)}" 62 | // TODO: Is this really correct? 
63 | lazy val loadClass: Class[_] = 64 | (if (classLoader == null) ClassLoader.getSystemClassLoader else classLoader).loadClass(name) 65 | def loadKlass: Klass.Native = Klass.from(loadClass) 66 | 67 | def extend(name: String, cl: AccessibleClassLoader): Extend = { 68 | if (loadClass.isInterface) Extend(ClassRef.Object, name, cl, Seq(loadClass)) 69 | else Extend(this, name, cl, Seq.empty) 70 | } 71 | 72 | // TODO: preserve package name 73 | def extend(cl: AccessibleClassLoader): Extend = 74 | extend(uniqueNamer("generated"), cl) 75 | 76 | override def renamed(newName: String): Concrete = 77 | copy(name = newName) 78 | } 79 | 80 | // TODO: interface 81 | case class Extend( 82 | superClassRef: ClassRef.Concrete, 83 | override val name: String, 84 | override val classLoader: AccessibleClassLoader, 85 | interfaces: Seq[Class[_]] 86 | ) extends ClassRef { 87 | override def pretty = s"${name}(<:${superClassRef.name} ${interfaces.map(_.getName).mkString(", ")})@${System.identityHashCode(classLoader)}" 88 | def anotherUniqueName: Extend = 89 | copy(name = uniqueNamer(name)) 90 | override def renamed(newName: String): Extend = 91 | copy(name = newName) 92 | } 93 | 94 | def of(klass: Class[_]): Concrete = 95 | ClassRef.Concrete(klass.getName, klass.getClassLoader) 96 | 97 | def of(name: String, cl: ClassLoader): Concrete = 98 | of((if (cl == null) ClassLoader.getSystemClassLoader else cl).loadClass(name)) 99 | 100 | private[this] val uniqueNamer = new UniqueNamer 101 | } 102 | -------------------------------------------------------------------------------- /core/src/main/scala/TypeRef.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | sealed abstract class TypeRef { 4 | def isDoubleWord: Boolean 5 | def wordSize: Int = if (isDoubleWord) 2 else 1 6 | def isAssignableFrom(rhs: TypeRef): Boolean = 7 | (this == rhs) || ((this, rhs) match { 8 | case (l: TypeRef.Reference, r: TypeRef.Reference) if 
l.classRef >= r.classRef => true 9 | case (l: TypeRef.Reference, TypeRef.Null) => true 10 | case _ => false 11 | }) 12 | } 13 | object TypeRef { 14 | def parse(src: String, cl: ClassLoader): TypeRef.Public = 15 | Parsers.parseTypeRef(src, cl) 16 | 17 | def from(c: Class[_]): Public = { 18 | if (c == java.lang.Integer.TYPE) Int 19 | else if (c == Long.javaClass) Long 20 | else if (c == Char.javaClass) Char 21 | else if (c == Byte.javaClass) Byte 22 | else if (c == Boolean.javaClass) Boolean 23 | else if (c == Short.javaClass) Short 24 | else if (c == Float.javaClass) Float 25 | else if (c == Double.javaClass) Double 26 | else if (c == Void.javaClass) Void 27 | else if (c.isArray) throw new UnveilException("NIMPL") // ??? 28 | else Reference(ClassRef.of(c)) 29 | } 30 | 31 | def common(t1: TypeRef, t2: TypeRef): TypeRef = 32 | (t1, t2) match { 33 | case (t1, t2) if t1 == t2 => t1 34 | case (Undefined, _) => Undefined 35 | case (_, Undefined) => Undefined 36 | case (Null, ref @ Reference(_)) => ref 37 | case (ref @ Reference(_), Null) => ref 38 | case (r1 @ Reference(_), r2 @ Reference(_)) => 39 | ??? 
40 | case (_: Primitive, _: Primitive) => Undefined 41 | case (SecondWord, _) => Undefined 42 | case (_, SecondWord) => Undefined 43 | } 44 | 45 | sealed trait SingleWord extends TypeRef { 46 | override def isDoubleWord = false 47 | } 48 | sealed trait DoubleWord extends TypeRef { 49 | override def isDoubleWord = true 50 | } 51 | 52 | case object Undefined extends TypeRef with SingleWord { 53 | override def toString = "[undefined]" 54 | } 55 | case object SecondWord extends TypeRef with SingleWord { 56 | override def toString = "[second word]" 57 | } 58 | case object Null extends TypeRef with SingleWord { 59 | override def toString = "[null]" 60 | } 61 | 62 | sealed abstract class Public extends TypeRef { 63 | def str: String 64 | def javaClass: Class[_] 65 | def defaultValue: Any 66 | } 67 | 68 | sealed abstract class Primitive( 69 | override val toString: String, 70 | override val str: String, 71 | override val javaClass: Class[_], 72 | override val defaultValue: Any 73 | ) extends Public 74 | 75 | case object Byte extends Primitive("int", "B", java.lang.Byte.TYPE, 0.toByte) with SingleWord 76 | case object Boolean extends Primitive("bool", "Z", java.lang.Boolean.TYPE, false) with SingleWord 77 | case object Char extends Primitive("char", "C", java.lang.Character.TYPE, '\u0000') with SingleWord 78 | case object Short extends Primitive("short", "S", java.lang.Short.TYPE, 0.toShort) with SingleWord 79 | case object Int extends Primitive("int", "I", java.lang.Integer.TYPE, 0) with SingleWord 80 | case object Float extends Primitive("float", "F", java.lang.Float.TYPE, 0.0f) with SingleWord 81 | case object Long extends Primitive("long", "J", java.lang.Long.TYPE, 0L) with DoubleWord 82 | case object Double extends Primitive("double", "D", java.lang.Double.TYPE, 0.0) with DoubleWord 83 | case object Void extends Primitive("void", "V", java.lang.Void.TYPE, null) { 84 | override def isDoubleWord = false 85 | override def wordSize = 0 86 | } 87 | 88 | case class 
Reference(classRef: ClassRef) extends Public with SingleWord { 89 | override def str = s"L${classRef.binaryName};" 90 | override def toString = classRef.toString 91 | override def defaultValue = null 92 | // TODO: It smells.. 93 | override def javaClass = classRef match { 94 | case c: ClassRef.Concrete => c.loadClass 95 | case c: ClassRef.Extend => throw new IllegalStateException() 96 | } 97 | } 98 | 99 | val Object: Reference = Reference(ClassRef.Object) 100 | } 101 | -------------------------------------------------------------------------------- /core/src/main/scala/DataSource.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | sealed abstract class DataSource { 4 | def merge(rhs: DataSource): DataSource 5 | def is(rhs: DataSource.Single): Option[Boolean] 6 | def may(rhs: DataSource.Single): Boolean = 7 | is(rhs) match { 8 | case Some(true) | None => true 9 | case Some(false) => false 10 | } 11 | def must(rhs: DataSource.Single): Boolean = 12 | is(rhs) match { 13 | case Some(true) => true 14 | case Some(false) => false 15 | case None => throw new RuntimeException("ambigious") 16 | } 17 | def mayProducedBy(l: Bytecode.Label, p: DataPort.Out): Boolean = 18 | producedBy(l, p) != Some(false) 19 | def producedBy(l: Bytecode.Label, p: DataPort.Out): Option[Boolean] 20 | def unambiguous: Boolean 21 | def mayFieldAccess(cr: ClassRef, fr: FieldRef): Boolean 22 | def single: Option[DataSource.Single] 23 | } 24 | object DataSource { 25 | case class Multiple(sources: Set[DataSource.Single]) extends DataSource { 26 | require(sources.nonEmpty) 27 | override def merge(rhs: DataSource): DataSource = rhs match { 28 | case Multiple(ss) => Multiple(sources ++ ss) 29 | case s: Single => Multiple(sources + s) 30 | } 31 | override def is(rhs: DataSource.Single): Option[Boolean] = 32 | if (sources.contains(rhs)) { 33 | if (sources.size == 1) Some(true) 34 | else None 35 | } else { 36 | Some(false) 37 | } 38 | 
override def unambiguous = sources.size == 1 39 | override def mayFieldAccess(cr: ClassRef, fr: FieldRef) = 40 | sources.exists(_.mayFieldAccess(cr, fr)) 41 | override def single = if(sources.size == 1) Some(sources.head) else None 42 | override def producedBy(l: Bytecode.Label, p: DataPort.Out): Option[Boolean] = { 43 | val parts = sources.map(_.producedBy(l, p)) 44 | if(parts.forall(_ == Some(true))) Some(true) 45 | else if(parts.forall(_ == Some(false))) Some(false) 46 | else None 47 | } 48 | } 49 | 50 | sealed abstract class Single extends DataSource { 51 | override def merge(rhs: DataSource): DataSource = rhs match { 52 | case Multiple(sources) => Multiple(sources + this) 53 | case s: Single => Multiple(Set(this, s)) 54 | } 55 | override def is(rhs: DataSource.Single): Option[Boolean] = 56 | Some(this == rhs) 57 | override def unambiguous = true 58 | override def mayFieldAccess(cr: ClassRef, fr: FieldRef) = false 59 | override def single = Some(this) 60 | override def producedBy(l: Bytecode.Label, p: DataPort.Out): Option[Boolean] = Some(false) 61 | } 62 | 63 | sealed trait HasLocation extends Single { 64 | def label: Bytecode.Label 65 | def port: DataPort.Out 66 | final override def producedBy(l: Bytecode.Label, p: DataPort.Out): Option[Boolean] = Some(l == label && p == port) 67 | } 68 | object HasLocation { 69 | def unapply(s: DataSource): Option[(Bytecode.Label, DataPort.Out)] = s match { 70 | case s: HasLocation => Some((s.label, s.port)) 71 | case s => None 72 | } 73 | } 74 | 75 | case object This extends Single 76 | case class Argument(n: Int) extends Single 77 | case object Constant extends Single 78 | 79 | case class InstanceField( 80 | override val label: Bytecode.Label, 81 | override val port: DataPort.Out, 82 | target: Data, 83 | classRef: ClassRef, 84 | fieldRef: FieldRef 85 | ) extends HasLocation { 86 | override def mayFieldAccess(cr: ClassRef, fr: FieldRef) = 87 | cr == classRef && fr == fieldRef 88 | } 89 | case class StaticField( 90 | 
override val label: Bytecode.Label, 91 | override val port: DataPort.Out, 92 | classRef: ClassRef, 93 | fieldRef: FieldRef 94 | ) extends HasLocation { 95 | override def mayFieldAccess(cr: ClassRef, fr: FieldRef) = 96 | cr == classRef && fr == fieldRef 97 | } 98 | case class New( 99 | override val label: Bytecode.Label, 100 | override val port: DataPort.Out 101 | ) extends HasLocation 102 | case class MethodInvocation( 103 | override val label: Bytecode.Label, 104 | override val port: DataPort.Out 105 | ) extends HasLocation 106 | case class Constant( 107 | override val label: Bytecode.Label, 108 | override val port: DataPort.Out, 109 | data: Data.Concrete 110 | ) extends HasLocation 111 | case class Generic( 112 | override val label: Bytecode.Label, 113 | override val port: DataPort.Out 114 | ) extends HasLocation 115 | } 116 | -------------------------------------------------------------------------------- /core/src/main/scala/EventLogger.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | 5 | import scala.collection.mutable 6 | 7 | import EventLogger.{ Event, Path } 8 | 9 | class EventLogger { 10 | private[this] val eventBuffer = mutable.ArrayBuffer.empty[Event] 11 | 12 | private[this] def withSubEL[A](path: Path)(f: EventLogger => A): A = { 13 | val el = new EventLogger 14 | val ret = 15 | try { 16 | f(el) 17 | } catch { 18 | case e: Throwable => 19 | el.fail(e) 20 | eventBuffer += Event.Grouped(path, el.events) 21 | throw e 22 | } 23 | eventBuffer += Event.Grouped(path, el.events) 24 | ret 25 | } 26 | 27 | def clear(): Unit = eventBuffer.clear() 28 | 29 | def events: Seq[Event] = eventBuffer.toSeq 30 | 31 | def apply[A](f: EventLogger => A): A = 32 | try { 33 | f(this) 34 | } catch { 35 | case e: Throwable => 36 | eventBuffer += Event.Fail(e) 37 | throw e 38 | } 39 | 40 | def enterField[A](cr: ClassRef, fr: FieldRef)(f: EventLogger => A): A = 41 
| withSubEL(Path.Field(cr, fr))(f) 42 | def enterMethod[A](cr: ClassRef, mr: MethodRef)(f: EventLogger => A): A = 43 | withSubEL(Path.Method(cr, mr))(f) 44 | 45 | def enterTransformer[A](t: Transformer, i: Instance[_ <: AnyRef])(f: EventLogger => A): A = 46 | withSubEL(Path.Transformer(t, i))(f) 47 | 48 | def section[A](title: String)(f: EventLogger => A): A = 49 | withSubEL(Path.Section(title))(f) 50 | 51 | def log(msg: String): Unit = 52 | eventBuffer += Event.Message(msg) 53 | 54 | def logCFields(desc: String, cfs: Iterable[(ClassRef, FieldRef)]): Unit = 55 | eventBuffer += Event.CFields(desc, cfs.toSeq) 56 | 57 | def logCMethods(desc: String, cms: Iterable[(ClassRef, MethodRef)]): Unit = 58 | eventBuffer += Event.CMethods(desc, cms.toSeq) 59 | 60 | def logMethods(desc: String, ms: Iterable[MethodRef]): Unit = 61 | eventBuffer += Event.Methods(desc, ms.toSeq) 62 | 63 | def logFields(desc: String, fs: Iterable[FieldRef]): Unit = 64 | eventBuffer += Event.Fields(desc, fs.toSeq) 65 | 66 | def logFieldValues(desc: String, fvs: Iterable[((ClassRef, FieldRef), Data)]): Unit = 67 | eventBuffer += Event.FieldValues(desc, fvs.toSeq) 68 | 69 | def fail(e: Throwable): Unit = 70 | eventBuffer += Event.Fail(e) 71 | 72 | def pretty(): String = pretty(0, events).split("\n").map { line => 73 | val indent = (line.size - line.replaceAll("^ +", "").size) / 2 74 | f"$indent%1X$line" 75 | }.mkString("\n") 76 | 77 | private[this] def pretty(indent: Int, evs: Seq[Event]): String = { 78 | evs 79 | .map(prettyEvent) 80 | .flatMap(_.split("\n")) 81 | .map { s => " " * indent + s } 82 | .mkString("\n") 83 | } 84 | 85 | // TODO[refactor]: fields/methods 86 | private[this] def prettyEvent(event: Event): String = event match { 87 | case Event.Message(msg) => 88 | msg 89 | case Event.Fail(e) => 90 | s"FAIL: $e" 91 | case Event.Grouped(path, evs) => 92 | prettyPath(path) + "\n" + pretty(1, evs) 93 | case Event.CFields(desc, cfs) => 94 | s"$desc =" + ( 95 | if (cfs.isEmpty) "" 96 | else cfs.map 
{ case (cr, fr) => s"- $cr\n .$fr" }.mkString("\n", "\n", "") 97 | ) 98 | case Event.CMethods(desc, cms) => 99 | s"$desc =" + ( 100 | if (cms.isEmpty) "" 101 | else cms.map { case (cr, mr) => s"- $cr\n .$mr" }.mkString("\n", "\n", "") 102 | ) 103 | case Event.Methods(desc, ms) => 104 | s"$desc =" + ( 105 | if (ms.isEmpty) "" 106 | else ms.map { case mr => s"- $mr" }.mkString("\n", "\n", "") 107 | ) 108 | case Event.Fields(desc, fs) => 109 | s"$desc =" + ( 110 | if (fs.isEmpty) "" 111 | else fs.map { case fr => s"- $fr" }.mkString("\n", "\n", "") 112 | ) 113 | case Event.FieldValues(desc, fvs) => 114 | s"$desc =" + ( 115 | if (fvs.isEmpty) "" 116 | else fvs.map { case ((cr, fr), v) => s"- $cr\n .$fr = $v" }.mkString("\n", "\n", "") 117 | ) 118 | } 119 | 120 | private[this] def prettyPath(path: Path) = path match { 121 | case Path.Section(title) => 122 | s"SECTION: $title" 123 | case Path.Field(cr, fr) => 124 | s"""ENTERING FIELD: ${fr.name} 125 | | class = $cr 126 | | field = $fr""".stripMargin('|') 127 | case Path.Method(cr, mr) => 128 | s"""ENTERING METHOD: ${mr.name} 129 | | class = $cr 130 | | method = $mr""".stripMargin('|') 131 | case Path.Transformer(t, i) => 132 | s"""APPLYING TRANSFORMER: ${t.name}""" + ( 133 | if (t.params.isEmpty) "" 134 | else t.params.map { case (k, v) => s" $k = $v" }.mkString("\n", "\n", "") 135 | ) + s"\n instance = ${i.thisRef}" 136 | } 137 | } 138 | 139 | object EventLogger { 140 | sealed abstract class Path 141 | object Path { 142 | case class Field(classRef: ClassRef, fieldRef: FieldRef) extends Path 143 | case class Method(classRef: ClassRef, methodRef: MethodRef) extends Path 144 | case class Transformer(transformer: com.todesking.unveil.Transformer, instance: Instance[_ <: AnyRef]) extends Path 145 | case class Section(title: String) extends Path 146 | } 147 | sealed abstract class Event 148 | object Event { 149 | case class Message(message: String) extends Event 150 | case class Fail(e: Throwable) extends Event 151 | case 
class Grouped(path: Path, events: Seq[Event]) extends Event 152 | case class CFields(desc: String, cfs: Seq[(ClassRef, FieldRef)]) extends Event 153 | case class CMethods(desc: String, cfs: Seq[(ClassRef, MethodRef)]) extends Event 154 | case class Methods(desc: String, ms: Seq[MethodRef]) extends Event 155 | case class Fields(desc: String, fs: Seq[FieldRef]) extends Event 156 | case class FieldValues(desc: String, fvs: Seq[((ClassRef, FieldRef), Data)]) extends Event 157 | } 158 | 159 | } 160 | -------------------------------------------------------------------------------- /core/src/main/scala/Analyze.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.{ Constructor, Method => JMethod } 4 | import scala.util.{ Try, Success, Failure } 5 | import scala.collection.mutable 6 | 7 | // TODO: Analyzer 8 | object Analyze { 9 | // TODO: add classRef 10 | /** Result of analyzing a constructor whose body makes at most one constructor call on `this` and otherwise only assigns fields of `this` from constants or constructor arguments (see SetterConstructor.from). Assignment keys are (declaring class, field). */ class SetterConstructor( 11 | val descriptor: MethodDescriptor, 12 | val superConstructor: Option[SetterConstructor], 13 | val constantAssigns0: Map[(ClassRef, FieldRef), Any], 14 | val argumentAssigns0: Map[(ClassRef, FieldRef), Int] 15 | ) { 16 | def methodRef: MethodRef = MethodRef.constructor(descriptor) 17 | /* Constant assignments of the delegated-to constructor chain plus this constructor's own; on a key clash this constructor's entry wins (Map ++). */ def constantAssigns: Map[(ClassRef, FieldRef), Any] = 18 | superConstructor.map(_.constantAssigns).getOrElse(Map.empty) ++ constantAssigns0 19 | /* NOTE(review): entries inherited from superConstructor carry that constructor's parameter indices, not this constructor's; presumably why this is marked WRONG. */ def argumentAssigns: Map[(ClassRef, FieldRef), Int] = 20 | superConstructor.map(_.argumentAssigns).getOrElse(Map.empty) ++ argumentAssigns0 // TODO: WRONG 21 | /* Argument list for this constructor: a parameter index found in argumentAssigns takes the concrete value of the matching field, any other index takes the parameter type's defaultValue. Fails via require unless assignable(fields). */ def toArguments(fields: Map[(ClassRef, FieldRef), Field]): Seq[Any] = { 22 | require(assignable(fields)) 23 | descriptor.args.zipWithIndex map { 24 | case (t, i) => 25 | argumentAssigns.find(_._2 == i).map { 26 | case (k, v) => 27 | fields(k).data.concreteValue 28 | } getOrElse { 29 | t.defaultValue 30 | } 31 | } 32 | } 33 | 34 | /* Fields grouped by the argument index they are assigned from. */ def sameArgumentValues: Seq[Set[(ClassRef, FieldRef)]] = 35 |
argumentAssigns.groupBy(_._2).map { case (i, vs) => vs.map(_._1).toSet }.toSeq 36 | 37 | /* True when every given field can be reproduced by this constructor: a constant-assigned field must equal its recorded constant; any other field must be argument-assigned, and fields sharing an argument index must hold equal values. NOTE(review): the inner forall re-scans every argument group (shadowing s) rather than only the group containing (cr, fr); verify intent. */ def assignable(fields: Map[(ClassRef, FieldRef), Field]): Boolean = { 38 | fields forall { 39 | case ((cr, fr), f1) => 40 | constantAssigns.get(cr -> fr).map { v2 => 41 | isSameValue(fr.typeRef, f1.data.concreteValue, v2) 42 | } getOrElse { 43 | argumentAssigns.contains(cr -> fr) && 44 | sameArgumentValues.forall { s => 45 | !s.contains(cr -> fr) || sameArgumentValues.forall { s => 46 | fields 47 | .toSeq 48 | .filter { case (k, f) => s.contains(k) } 49 | .sliding(2) 50 | .forall { 51 | case Seq(((cr1, fr1), f1), ((cr2, fr2), f2)) => 52 | isSameValue(fr1.typeRef, f1.data.concreteValue, f2.data.concreteValue) 53 | case Seq(_) => true 54 | } 55 | } 56 | } 57 | } 58 | } 59 | } 60 | /* Only primitive types are compared (with ==); any other type reaches ??? and throws NotImplementedError. */ private[this] def isSameValue(t: TypeRef, v1: Any, v2: Any): Boolean = 61 | t match { 62 | case t: TypeRef.Primitive => v1 == v2 63 | case _ => ??? 64 | } 65 | override def toString = 66 | s"""SetterConstructor(${descriptor}, ${constantAssigns}, ${argumentAssigns})""" 67 | 68 | def pretty = s"""SetterConstructor 69 | | Values from argument: 70 | |${argumentAssigns.map { case ((cr, fr), i) => f" ${i}%5d => ${cr}.${fr}" }.mkString("\n")} 71 | | Values from constant: 72 | |${constantAssigns.map { case ((cr, fr), v) => f" ${v}%5s => ${cr}.${fr}" }.mkString("\n")} 73 | """.stripMargin 74 | } 75 | object SetterConstructor { 76 | /* Analyzes a constructor body as a setter constructor; analysis errors thrown as UnveilException inside the try are returned as Failure. Note that body.dataflow runs before the try, so exceptions from it propagate. */ def from(klass: Klass.Native, body: MethodBody): Try[SetterConstructor] = { 77 | def makeError(msg: String) = 78 | new MethodAnalyzeException(klass.ref, MethodRef.constructor(body.descriptor), msg) 79 | val df = body.dataflow(klass) 80 | import Bytecode._ 81 | try { 82 | var superConstructor: Option[SetterConstructor] = None 83 | val constAssigns = mutable.HashMap.empty[(ClassRef, FieldRef), Any] 84 | val argAssigns = mutable.HashMap.empty[(ClassRef, FieldRef), Int] 85 | body.bytecode.foreach { 86 | case (label, bc) if df.possibleReturns(label).isEmpty => 87 | // ignore error path 88
| case (label, bc: Shuffle) => 89 | case (label, bc: Jump) => 90 | case (label, bc: Return) => 91 | case (label, bc: ConstX) => 92 | case (label, bc: Branch) if df.possibleReturns(body.jumpTargets(label -> bc.jumpTarget)).isEmpty || df.possibleReturns(df.fallThroughs(label)).isEmpty => 93 | // OK if one of jumps exit by throw 94 | case (label, bc @ invokespecial(classRef, methodRef)) if df.isThis(label, bc.objectref).getOrElse(false) && methodRef.isInit => 95 | // super ctor invocation 96 | if (superConstructor.nonEmpty) 97 | throw makeError(s"Another constructor called twice in ${klass.ref}.${body.descriptor}") 98 | superConstructor = 99 | /* .get rethrows the exception of a failed nested analysis; the catch below turns it into Failure when it is an UnveilException. */ SetterConstructor.from(klass, klass.methodBody(classRef, methodRef)).map(Some(_)).get 100 | case (label, bc @ putfield(classRef, fieldRef)) if df.isThis(label, bc.objectref).getOrElse(false) => 101 | df.dataValue(label, bc.value).value.map { v => 102 | // value from constant 103 | constAssigns += (classRef -> fieldRef) -> v 104 | } getOrElse { 105 | df.argNum(label, bc.value).fold { 106 | throw makeError(s"putfield non-argument/constant value(${df.dataValue(label, bc.value)}) is not acceptable: ${bc}") 107 | } { i => 108 | argAssigns += (classRef -> fieldRef) -> i 109 | } 110 | } 111 | /* Any other instruction makes the constructor unsuitable. */ case bc => 112 | throw makeError(s"Bytecode ${bc} is not acceptable in setter constructor") 113 | } 114 | Success(new SetterConstructor(body.descriptor, superConstructor, constAssigns.toMap, argAssigns.toMap)) 115 | } catch { 116 | case e: UnveilException => Failure(e) 117 | } 118 | } 119 | } 120 | 121 | /* Setter-constructor analysis of every non-private constructor declared by klass's class. */ def setterConstructorsTry(klass: Klass.Native): Seq[Try[SetterConstructor]] = { 122 | klass.ref.loadClass 123 | .getDeclaredConstructors 124 | .filterNot { c => MethodAttribute.Private.enabledIn(c.getModifiers) } 125 | .map { c => MethodRef.from(c) } 126 | .map { mr => SetterConstructor.from(klass, klass.methodBody(klass.ref, mr)) } 127 | } 128 | 129 | /* The successful entries of setterConstructorsTry; failures are dropped. */ def setterConstructors(klass: Klass.Native): Seq[SetterConstructor] = 130 |
setterConstructorsTry(klass) 131 | .collect { case Success(sc) => sc } 132 | 133 | /** First successfully analyzed setter constructor of klass whose assignable(fields) holds, if any. */ def findSetterConstructor[A]( 134 | klass: Klass.Native, 135 | fields: Map[(ClassRef, FieldRef), Field] 136 | ): Option[SetterConstructor] = { 137 | setterConstructors(klass) 138 | .find { _.assignable(fields) } 139 | 140 | } 141 | } 142 | 143 | -------------------------------------------------------------------------------- /core/src/main/scala/ClassCompiler.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import java.lang.reflect.Modifier 4 | 5 | import scala.collection.JavaConversions._ 6 | 7 | /** Compiles a Klass.Modified into a subclass of its super class (superClassRef). The generated constructor takes the values of this class's declared fields followed by the arguments of the chosen super-class setter constructor. */ class ClassCompiler(klass: Klass.Modified, fieldValues: Map[(ClassRef, FieldRef), Data.Concrete], el: EventLogger) { 8 | lazy val superClass: Class[_] = klass.ref.superClassRef.loadClass 9 | 10 | lazy val thisFieldsSeq: Seq[(FieldRef, Data.Concrete)] = 11 | klass.declaredFields.keys.map { fr => fr -> fieldValues(klass.ref, fr) }.toSeq 12 | 13 | lazy val superFields = 14 | klass.`super`.instanceFieldAttributes.map { case (k, a) => k -> Field(k._2.descriptor, a, fieldValues(k)) } 15 | 16 | /* The super-class setter constructor the generated constructor delegates to: the first successfully analyzed one that can reproduce the inherited field values. */ lazy val superConstructor: Analyze.SetterConstructor = { 17 | // TODO: refactor 18 | el.section(s"Find super ctor from ${klass.ref.superClassRef}") { el => 19 | val ctors = Analyze.setterConstructorsTry(klass.ref.superClassRef.loadKlass) 20 | import scala.util.{ Success, Failure } 21 | el.log(s"${ctors.size} ctor candidate found") 22 | ctors.foreach { 23 | case Success(ctor) => 24 | el.log(s"Setter ctor found: ${ctor.descriptor}") 25 | case Failure(e) => el.log(s"Setter ctor unmatch: ${e}") 26 | } 27 | /* Reuse the analysis results above instead of re-analyzing every constructor through Analyze.findSetterConstructor; the selection is the same (first success that is assignable). */ ctors.collect { case Success(ctor) => ctor }.find(_.assignable(superFields)) getOrElse { 28 | throw new TransformException(s"Usable constructor not found") 29 | } 30 | } 31 | } 32 | lazy val superConstructorArgs: Seq[Any] = superConstructor.toArguments(superFields) 33 | lazy val constructorArgs: Seq[(TypeRef.Public,
Any)] = 34 | thisFieldsSeq 35 | .map { case (fr, data) => (fr.descriptor.typeRef -> data.concreteValue) } ++ 36 | superConstructor.descriptor.args.zip(superConstructorArgs) 37 | 38 | lazy val constructorDescriptor = MethodDescriptor(TypeRef.Void, constructorArgs.map(_._1)) 39 | /** Body of the generated constructor: calls the chosen super constructor with the trailing constructor arguments, then stores the leading arguments into this class's declared fields in thisFieldsSeq order. Local slot 0 holds `this`; each argument starts right after the previous one, and double-word (long/double) arguments occupy two slots, so slots are accumulated by word size rather than derived from the argument index. */ lazy val constructorBody: MethodBody = { 40 | val argSlots: Seq[Int] = constructorArgs.map(_._1).scanLeft(1) { (slot, t) => slot + (if (t.isDoubleWord) 2 else 1) } 41 | val thisFieldAssigns: Seq[(FieldRef, Int)] = thisFieldsSeq.map(_._1).zip(argSlots); val superArgSlots: Seq[Int] = argSlots.drop(thisFieldsSeq.size) 42 | import Bytecode._ 43 | MethodBody( 44 | true, 45 | descriptor = constructorDescriptor, 46 | MethodAttribute.Public, 47 | codeFragment = CodeFragment.bytecode( 48 | Seq( 49 | Seq(aload(0)), 50 | superConstructor.descriptor.args.zip(superArgSlots).map { 51 | case (t, slot) => 52 | autoLoad(t, slot) 53 | }, 54 | Seq( 55 | invokespecial( 56 | ClassRef.of(superClass), 57 | superConstructor.methodRef 58 | ) 59 | ) 60 | ).flatten ++ thisFieldAssigns.flatMap { 61 | case (fr, slot) => 62 | 63 | Seq( 64 | aload(0), 65 | autoLoad(fr.descriptor.typeRef, slot), 66 | putfield(klass.ref, fr) 67 | ) 68 | }.toSeq ++ Seq(vreturn()) 69 | : _* 70 | ) 71 | ) 72 | } 73 | 74 | /** Generates the class with javassist (declared methods, declared fields and the generated constructor), defines it in klass.ref.classLoader, records its bytes via Instance.registerMaterialized, and returns it as a Klass.MaterializedNative together with constructorArgs. */ def compile(): Klass.MaterializedNative = { 75 | import javassist.{ ClassPool, ClassClassPath, CtClass, CtMethod, CtField, CtConstructor, ByteArrayClassPath } 76 | import javassist.bytecode.{ Bytecode => JABytecode, MethodInfo } 77 | 78 | import Javassist.ctClass 79 | 80 | el.section("ClassCompiler.compile") { el => 81 | validate() 82 | 83 | el.log(s"compiling ${klass.ref}") 84 | el.logFieldValues("Field values", fieldValues) 85 | 86 | val classLoader = klass.ref.classLoader 87 | 88 | val classPool = new ClassPool(null) 89 | Instance.findMaterializedClasses(classLoader).foreach { 90 | case (name, bytes) => 91 | classPool.appendClassPath(new ByteArrayClassPath(name, bytes)) 92 | } 93 | classPool.appendClassPath(new ClassClassPath(superClass)) 94 | 95 | val ctBase = classPool.get(superClass.getName) 96 | 97 | val jClass =
classPool.makeClass(klass.ref.name, ctBase) 98 | jClass.setModifiers(jClass.getModifiers() | Modifier.PUBLIC) 99 | klass.ref.interfaces.foreach { i => 100 | jClass.addInterface(classPool.get(i.getName)) 101 | } 102 | val constPool = jClass.getClassFile.getConstPool 103 | 104 | import Bytecode._ 105 | klass.declaredMethods 106 | .foreach { 107 | case (ref, body) => 108 | val codeAttribute = Javassist.compile(classPool, constPool, body.dataflow(klass)) 109 | val minfo = new MethodInfo(constPool, ref.name, ref.descriptor.str) 110 | minfo.setCodeAttribute(codeAttribute) 111 | val sm = javassist.bytecode.stackmap.MapMaker.make(classPool, minfo) 112 | codeAttribute.setAttribute(sm) 113 | minfo.setAccessFlags(body.attribute.toInt) 114 | jClass.getClassFile.addMethod(minfo) 115 | } 116 | 117 | klass.declaredFields.foreach { 118 | case (ref, attr) => 119 | val ctf = new CtField(ctClass(ref.descriptor.typeRef), ref.name, jClass) 120 | ctf.setModifiers(attr.toInt) 121 | jClass.addField(ctf) 122 | } 123 | 124 | val ctor = new CtConstructor(constructorArgs.map(_._1).map(ctClass).toArray, jClass) 125 | jClass.addConstructor(ctor) 126 | 127 | /* The MethodInfo added for `ctor`. Class-file constructors are named <init> (MethodInfo.nameInit); matching on the empty string never finds it. */ val ctorMethodInfo = 128 | jClass 129 | .getClassFile 130 | .getMethods 131 | .map(_.asInstanceOf[MethodInfo]) 132 | .find(_.getName == MethodInfo.nameInit) 133 | .get 134 | 135 | val ctorCA = Javassist.compile(classPool, constPool, constructorBody.dataflow(klass)) 136 | ctorMethodInfo.setCodeAttribute(ctorCA) 137 | val sm = javassist.bytecode.stackmap.MapMaker.make(classPool, ctorMethodInfo) 138 | ctorCA.setAttribute(sm) 139 | 140 | /* Serialize once; the same class-file bytes are defined in the class loader and registered as materialized. */ val bytes = jClass.toBytecode 141 | classLoader.registerClass(klass.ref.name, bytes) 142 | val concreteClass = classLoader.loadClass(klass.ref.name) 143 | 144 | Instance.registerMaterialized(classLoader, jClass.getName, bytes) 145 | 146 | new Klass.MaterializedNative(concreteClass, constructorArgs) 147 | } 148 | } 149 | 150 | private[this] def validate(): Unit = { 151
| def fail(msg: String) = 152 | throw new IllegalStateException(msg) 153 | 154 | /* Presumably rejects fieldValues that do not cover every instance field of klass; see Klass.requireWholeInstanceField. */ klass.requireWholeInstanceField(fieldValues.keySet) 155 | 156 | /* The generated class extends the super class, which is impossible when it is final. */ if ((klass.ref.superClassRef.loadClass.getModifiers & Modifier.FINAL) == Modifier.FINAL) 157 | fail("base is final class") 158 | // TODO: check finalizer 159 | // * for each fields `f` in `x`: 160 | // * FAIL if `f` is non-final and `x` is _escaped_ 161 | // * if `f` defined at `_ <<: X` 162 | // * FAIL if 163 | // * `f` has type `_ <<: X` 164 | // * for each ALL methods/constructors `m` in `x`: 165 | // * FAIL if 166 | // * `m` is abstract 167 | // * `m` takes parameter `_ <<: X` 168 | // * `m` returns `_ <<: X` 169 | // * `m` has non-this reference `_ <<: X` 170 | // * for each visible or self-referenced non-constructor methods `m` in `x`: 171 | // * if `m` defined at `_ <<: X` 172 | // * FAIL if 173 | // * `m` is native 174 | // * `m` leaks `this` as `_ <<: X` 175 | // * for each constructor/used super constructor `c` in `x`: 176 | // * FAIL if ANY OF 177 | // * `c` is native 178 | // * `c` may have side-effect 179 | } 180 | } 181 | 182 | /** Convenience entry point: builds a ClassCompiler from the arguments and runs compile(). */ object ClassCompiler { 183 | def compile(klass: Klass.Modified, fieldValues: Map[(ClassRef, FieldRef), Data.Concrete], el: EventLogger): Klass.MaterializedNative = 184 | new ClassCompiler(klass, fieldValues, el).compile() 185 | } 186 | 187 | -------------------------------------------------------------------------------- /core/src/main/scala/FrameUpdate.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | import scala.language.higherKinds 5 | 6 | import scala.reflect.{ classTag, ClassTag } 7 | import scala.collection.mutable 8 | import scala.util.{ Try, Success, Failure } 9 | 10 | import java.lang.reflect.{ Method => JMethod } 11 | 12 | import com.todesking.scalapp.syntax._ 13 | 14 | // TODO: posh(DataLabel.Out, Data) etc // TODO: WHAT is posh???
15 | /** Result of analyzing one instruction (`bytecode` at `label`): the frame as updated so far (`newFrame`), the frame item observed at each data port of the instruction (`frameItems`), and the references it initialized (`initializes`). The update helpers return new FrameUpdate values. */ case class FrameUpdate( 16 | label: Bytecode.Label, 17 | bytecode: Bytecode, 18 | newFrame: Frame, 19 | frameItems: Map[(Bytecode.Label, DataPort), FrameItem], 20 | initializes: Map[(Bytecode.Label, DataPort.Out), Data.AbstractReference] 21 | ) { 22 | /* Starts from `frame` with no frame items or initializations recorded. */ def this(label: Bytecode.Label, bytecode: Bytecode, frame: Frame) = 23 | this( 24 | label, 25 | bytecode, 26 | frame, 27 | Map.empty, 28 | Map.empty 29 | ) 30 | 31 | /* Note: mapValues returns a view that re-applies the function on every lookup. */ lazy val dataValues: Map[(Bytecode.Label, DataPort), Data] = 32 | frameItems.mapValues(_.data) 33 | 34 | lazy val dataSources: Map[(Bytecode.Label, DataPort), DataSource] = 35 | frameItems.mapValues(_.source) 36 | 37 | /* Builds (does not throw) an exception naming this instruction's label and bytecode. */ private[this] def fail(msg: String): RuntimeException = 38 | new RuntimeException(s"Analysis failed at ${label.format("L%d")} ${bytecode}: ${msg}") 39 | 40 | private[this] def requireSingleLocal(n: Int): Unit = { 41 | if (!newFrame.locals.contains(n)) throw fail(s"Local $n not defined") 42 | requireSingleWord(newFrame.locals(n)) 43 | } 44 | 45 | private[this] def requireSecondWord(fi: FrameItem): Unit = 46 | if (fi.data.typeRef != TypeRef.SecondWord) 47 | throw fail(s"second word value expected but ${fi}") 48 | 49 | private[this] def requireSingleWord(fi: FrameItem): Unit = 50 | if (fi.data.typeRef.isDoubleWord || fi.data.typeRef == TypeRef.SecondWord || fi.data.typeRef == TypeRef.Undefined) 51 | throw fail(s"single word value expected but ${fi}") 52 | 53 | private[this] def requireDoubleWord(fi: FrameItem): Unit = 54 | if (!fi.data.typeRef.isDoubleWord || fi.data.typeRef == TypeRef.SecondWord || fi.data.typeRef == TypeRef.Undefined) 55 | throw fail(s"double word value expected but ${fi}") 56 | 57 | /* The stack top is stack(0); a double-word value is expected at stack(0) with its second word at stack(1). */ private[this] def requireStackTopType(f: Frame, t: TypeRef): Unit = t match { 58 | case t: TypeRef.DoubleWord => 59 | if (f.stack.size < 2) throw fail("double word expected but stack too short") 60 | if (!t.isAssignableFrom(f.stack(0).data.typeRef)) throw fail(s"$t expected but ${f.stack(0).data.typeRef}") 61 | requireSecondWord(f.stack(1))
62 | case t: TypeRef.SingleWord => 63 | if (f.stack.size < 1) throw fail("single word expected but stack too short") 64 | if (!t.isAssignableFrom(f.stack(0).data.typeRef)) throw fail(s"$t expected but ${f.stack(0).data.typeRef}") 65 | } 66 | 67 | /* Filler item for the slot after a double-word value; keeps the value's source. */ private[this] def makeSecondWord(fi: FrameItem): FrameItem = 68 | FrameItem(fi.source, fi.data.secondWordData) 69 | 70 | /* Pops one value of type t without recording it; a double-word type removes two stack entries. */ def pop(t: TypeRef): FrameUpdate = 71 | if (t.isDoubleWord) pop2() 72 | else pop1() 73 | 74 | /* Like pop(t), but records the popped value as this instruction's input port `in`. */ def pop(t: TypeRef, in: DataPort.In): FrameUpdate = 75 | if (t.isDoubleWord) pop2(in) 76 | else pop1(in) 77 | 78 | def pop1(): FrameUpdate = { 79 | requireSingleWord(newFrame.stackTop) 80 | copy( 81 | newFrame = newFrame.copy(stack = newFrame.stack.drop(1)) 82 | ) 83 | } 84 | 85 | def pop1(in: DataPort.In): FrameUpdate = { 86 | requireSingleWord(newFrame.stackTop) 87 | val x = newFrame.stackTop 88 | pop0(in, x, newFrame.stack.drop(1)) 89 | } 90 | 91 | def pop2(): FrameUpdate = { 92 | // TODO[BUG]: pop2 can pop 2 single word 93 | requireDoubleWord(newFrame.stack(0)) 94 | requireSecondWord(newFrame.stack(1)) 95 | copy(newFrame = newFrame.copy(stack = newFrame.stack.drop(2))) 96 | } 97 | 98 | def pop2(in: DataPort.In): FrameUpdate = { 99 | // TODO[BUG]: pop2 can pop 2 single word 100 | requireDoubleWord(newFrame.stack(0)) 101 | requireSecondWord(newFrame.stack(1)) 102 | val x = newFrame.stack(0) 103 | pop0(in, x, newFrame.stack.drop(2)) 104 | } 105 | 106 | /* Installs `stack` and records `fi` under (label, in). */ private[this] def pop0(in: DataPort.In, fi: FrameItem, stack: List[FrameItem]): FrameUpdate = 107 | FrameUpdate( 108 | label, 109 | bytecode, 110 | newFrame.copy(stack = stack), 111 | frameItems + ((label -> in) -> fi), 112 | initializes 113 | ) 114 | 115 | /* Pushes d (with its second word beneath it when double-word); when p is defined, d is recorded as this instruction's item at port p. */ def push(p: Option[DataPort], d: FrameItem): FrameUpdate = 116 | if (d.data.typeRef.isDoubleWord) push2(p, d) 117 | else push1(p, d) 118 | 119 | def push1(p: Option[DataPort], d: FrameItem): FrameUpdate = { 120 | requireSingleWord(d) 121 | push0(p, d, d :: newFrame.stack) 122 | } 123 | 124 | def push2(p:
Option[DataPort], d: FrameItem): FrameUpdate = { 125 | requireDoubleWord(d) 126 | push0(p, d, d :: makeSecondWord(d) :: newFrame.stack) 127 | } 128 | 129 | private[this] def push0(p: Option[DataPort], fi: FrameItem, stack: List[FrameItem]): FrameUpdate = 130 | FrameUpdate( 131 | label, 132 | bytecode, 133 | newFrame.copy(stack = stack), 134 | p.fold(frameItems) { p => frameItems + ((label -> p) -> fi) }, 135 | initializes 136 | ) 137 | 138 | /** Stores `data` into local slot n. A double-word value also fills slot n + 1 with its second word, so that local2 finds the layout it checks for. */ def setLocal(n: Int, data: FrameItem): FrameUpdate = { 139 | val locals = 140 | if (data.data.typeRef.isDoubleWord) 141 | newFrame.locals.updated(n, data).updated(n + 1, makeSecondWord(data)) 142 | else 143 | newFrame.locals.updated(n, data) 144 | FrameUpdate( 145 | label, 146 | bytecode, 147 | newFrame.copy(locals = locals), 148 | frameItems, 149 | initializes 150 | ) 151 | } 152 | 153 | private[this] def local1(n: Int): FrameItem = { 154 | requireSingleLocal(n) 155 | newFrame.locals(n) 156 | } 157 | 158 | private[this] def local2(n: Int): FrameItem = { 159 | requireDoubleWord(newFrame.locals(n)) 160 | requireSecondWord(newFrame.locals(n + 1)) 161 | newFrame.locals(n) 162 | } 163 | 164 | def load1(n: Int): FrameUpdate = push1(None, local1(n)) 165 | def load2(n: Int): FrameUpdate = push2(None, local2(n)) 166 | 167 | def store1(tpe: TypeRef.SingleWord, n: Int): FrameUpdate = { 168 | requireStackTopType(newFrame, tpe) 169 | setLocal(n, newFrame.stackTop) 170 | .pop1() 171 | } 172 | 173 | /* Stores the double-word stack top into slots n and n + 1 (setLocal writes the second word) and pops it. */ def store2(tpe: TypeRef.DoubleWord, n: Int): FrameUpdate = { 174 | requireStackTopType(newFrame, tpe) 175 | setLocal(n, newFrame.stackTop) 176 | .pop2() 177 | 178 | } 179 | 180 | /* NOTE(review): push2 leaves the double-word value itself on top with its second word beneath, so the SecondWord branch below looks unreachable for values pushed that way; verify. */ def ret(retval: DataPort.In): FrameUpdate = { 181 | val fi = 182 | if (newFrame.stackTop.data.typeRef == TypeRef.SecondWord) { 183 | requireDoubleWord(newFrame.stack(1)) 184 | newFrame.stack(1) 185 | } else { 186 | newFrame.stackTop 187 | } 188 | FrameUpdate( 189 | label, 190 | bytecode, 191 |
Frame(Map.empty, List.empty), 192 | frameItems + ((label -> retval) -> fi), 193 | initializes 194 | ) 195 | } 196 | 197 | /** athrow: the thrown reference (the stack top) becomes the only stack entry and is recorded as this instruction's `objectref` input, just as ret records its `retval`. */ def athrow(objectref: DataPort.In): FrameUpdate = { 198 | requireSingleWord(newFrame.stackTop) 199 | FrameUpdate( 200 | label, 201 | bytecode, 202 | newFrame.copy(stack = newFrame.stack.take(1)), 203 | frameItems + ((label -> objectref) -> newFrame.stackTop), 204 | initializes 205 | ) 206 | } 207 | 208 | /* Replaces frame data whose source is DataSource.New(label, port) with `data` and records it in `initializes`. */ def initializeInstance( 209 | label: Bytecode.Label, 210 | port: DataPort.Out, 211 | data: Data.AbstractReference 212 | ): FrameUpdate = 213 | copy( 214 | newFrame = 215 | newFrame.replaceDataBySource(DataSource.New(label, port), data), 216 | initializes = 217 | initializes + ((label, port) -> data) 218 | ) 219 | } 220 | -------------------------------------------------------------------------------- /core/src/main/scala/CodeFragment.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | abstract class CodeFragment { 4 | def bytecodeSeq: Seq[Bytecode] 5 | 6 | def size: Int 7 | 8 | def nonEmpty: Boolean = size > 0 9 | 10 | def incompleteJumpTargets: Map[(Bytecode.Label, JumpTarget), Either[String, Bytecode.Label]] 11 | 12 | def nameToLabel: Map[String, Bytecode.Label] 13 | 14 | lazy val bytecode: Seq[(Bytecode.Label, Bytecode)] = 15 | bytecodeSeq.zipWithIndex.map { case (b, i) => Bytecode.Label(i) -> b } 16 | 17 | def bytecodeFromLabel(l: Bytecode.Label): Bytecode = 18 | labelToBytecode(l) 19 | 20 | def pretty: String 21 | 22 | /* Methods referenced by this fragment. Scans bytecodeSeq: the elements of `bytecode` are (Label, Bytecode) pairs, which never match a Bytecode type pattern. */ def methodReferences: Set[(ClassRef, MethodRef)] = 23 | bytecodeSeq.collect { case bc: Bytecode.HasMethodRef => (bc.classRef -> bc.methodRef) }.toSet 24 | 25 | /* Fields referenced by this fragment; scans bytecodeSeq because `bytecode` holds (Label, Bytecode) pairs. */ def fieldReferences: Set[(ClassRef, FieldRef)] = 26 | bytecodeSeq.collect { case bc: Bytecode.HasFieldRef => (bc.classRef -> bc.fieldRef) }.toSet 27 | 28 | lazy val labelToBytecode: Map[Bytecode.Label, Bytecode] = 29 | bytecode.toMap 30 | 31 | def prepend(cf: CodeFragment): CodeFragment = 32 | cf + this 33 | 34 | def complete():
CodeFragment.Complete 35 | 36 | /* Concatenation; rhs labels and names are shifted past this fragment (see Concat). */ def +(rhs: CodeFragment): CodeFragment.Incomplete = 37 | this.concatForm + rhs 38 | 39 | /* Binds name n to the first instruction of this fragment. */ def name(n: String): CodeFragment.Incomplete = 40 | CodeFragment.name(this, n) 41 | 42 | protected def concatForm: CodeFragment.Concat = 43 | CodeFragment.Concat(Seq(this), Map()) 44 | } 45 | object CodeFragment { 46 | /* Fragment of plain bytecodes with an empty jump-target map, so Complete's initializer rejects any bytecode that declares a jump target. */ def bytecode(bcs: Bytecode*): CodeFragment.Complete = 47 | CodeFragment.Complete(bcs, Map.empty) 48 | 49 | def name(cf: CodeFragment, n: String): Incomplete = 50 | Concat(Seq(cf), Map(n -> Bytecode.Label(0))) 51 | 52 | /* A single jump whose target is the label named `name`, resolved when the fragment is completed. */ def abstractJump(bc: Bytecode.HasAJumpTarget, name: String): CodeFragment = 53 | Partial(Seq(bc), Map((Bytecode.Label(0), bc.jumpTarget) -> Left(name)), Map.empty) 54 | 55 | def empty(): CodeFragment = 56 | new Complete(Seq.empty, Map.empty) 57 | 58 | /** Fragment whose jumps are all resolved to labels; the initializer requires exactly one jumpTargets entry for each jump target of each bytecode. */ case class Complete( 59 | bytecodeSeq: Seq[Bytecode], 60 | jumpTargets: Map[(Bytecode.Label, JumpTarget), Bytecode.Label] 61 | ) extends CodeFragment { 62 | { 63 | val jts = bytecode.collect { 64 | case (l, bc: Bytecode.HasJumpTargets) => 65 | bc.jumpTargets.map(l -> _) 66 | }.flatten.toSet 67 | require(jts.size == jumpTargets.size) 68 | require(bytecode.forall { 69 | case (l, bc: Bytecode.HasJumpTargets) => 70 | bc.jumpTargets.forall { jt => jumpTargets.contains(l -> jt) } 71 | case _ => true 72 | }) 73 | } 74 | 75 | override def size = bytecodeSeq.size 76 | 77 | override def nameToLabel = Map.empty 78 | 79 | override def incompleteJumpTargets: Map[(Bytecode.Label, JumpTarget), Either[String, Bytecode.Label]] = 80 | jumpTargets.mapValues(Right(_)) 81 | 82 | override def complete() = this 83 | 84 | override def pretty = Pretty.format_CodeFragment_Complete(this) 85 | 86 | def jumpDestination(bcl: Bytecode.Label, jt: JumpTarget): Bytecode.Label = 87 | jumpTargets(bcl -> jt) 88 | 89 | /* Replaces each (label, bytecode) matched by f with the fragment f returns. */ def rewrite_*(f: PartialFunction[(Bytecode.Label, Bytecode), CodeFragment]): CodeFragment.Incomplete = 90 | rewrite_** { 91 | case x @ (label, bc) if f.isDefinedAt(x) => Map(label -> f(x)) 92
| } 93 | 94 | /* Applies f to every (label, bytecode) and splices in the returned replacements; when two matches return replacements for the same label, Algorithm.sharedNothingUnion fails and a TransformException is thrown. */ def rewrite_**(f: PartialFunction[(Bytecode.Label, Bytecode), Map[Bytecode.Label, CodeFragment]]): CodeFragment.Incomplete = { 95 | val liftedF = f.lift 96 | val allRewrites = 97 | bytecode.foldLeft(Map.empty[Bytecode.Label, CodeFragment]) { 98 | case (m, lbc @ (l, bc)) => 99 | liftedF(lbc).fold(m) { mm => 100 | Algorithm.sharedNothingUnion(m, mm).fold { 101 | throw new TransformException(s"rewrite conflict") 102 | }(identity) 103 | } 104 | } 105 | rewrite0(allRewrites) 106 | } 107 | 108 | /* Splices each replacement in label order. `offset` is how far earlier splices have shifted later labels; jump sources and destinations after the splice point move by `shift`, and the replacement's own jumps and names are relocated to the splice start. */ private[this] def rewrite0(rewrites: Map[Bytecode.Label, CodeFragment]): CodeFragment.Incomplete = { 109 | def adjustIndex(start: Int, shift: Int, l: Bytecode.Label) = 110 | if (l.index > start) l.offset(shift) else l 111 | 112 | val (bcs, jts, n2l,_) = 113 | rewrites.toSeq.sortBy(_._1.index).foldLeft( 114 | (bytecodeSeq, incompleteJumpTargets, Map.empty[String, Bytecode.Label], 0) 115 | ) { 116 | case ((bcs, jts, n2l, offset), (label, cf)) => 117 | val start = label.index + offset 118 | require(cf.nonEmpty) // TODO: remove 119 | require(0 <= start && start < bcs.size) 120 | val newBcs = bcs.patch(start, cf.bytecodeSeq, 1) 121 | val shift = cf.bytecode.size - 1 122 | // TODO: [BUG] remove jumpTarget if replace target is jump 123 | val newJts: Map[(Bytecode.Label, JumpTarget), Either[String, Bytecode.Label]] = 124 | jts.map { 125 | case ((l, jt), dest) => 126 | val key = (adjustIndex(start, shift, l) -> jt) 127 | key -> dest.fold(Left.apply, l => Right(adjustIndex(start, shift, l))) 128 | } ++ cf.incompleteJumpTargets.map { case ((l, jt), dest) => 129 | (l.offset(start) -> jt) -> dest.fold(Left.apply, l => Right(l.offset(start))) 130 | } 131 | val newN2L = Algorithm.sharedNothingUnion(n2l, cf.nameToLabel.mapValues(_.offset(start))) getOrElse { 132 | throw new IllegalArgumentException(s"Name conflict: ${n2l.keys.filter(cf.nameToLabel.keySet).mkString(", ")}") 133 | } 134 | (newBcs, newJts, newN2L, offset + shift) 135 | } 136 | new CodeFragment.Partial(bcs, jts, n2l)
137 | } 138 | } 139 | 140 | /* A fragment whose jumps may still target names that are not yet bound to labels. */ sealed abstract class Incomplete extends CodeFragment { 141 | override def complete(): Complete 142 | override def nameToLabel: Map[String, Bytecode.Label] 143 | } 144 | 145 | case class Partial( 146 | override val bytecodeSeq: Seq[Bytecode], 147 | override val incompleteJumpTargets: Map[(Bytecode.Label, JumpTarget), Either[String, Bytecode.Label]], 148 | override val nameToLabel: Map[String, Bytecode.Label] 149 | ) extends Incomplete { 150 | override def size = bytecodeSeq.size 151 | override def pretty = toString // TODO 152 | /* Resolves named jump targets through nameToLabel. mapValues is lazy, so an unbound name presumably throws NoSuchElementException only when its entry is read (e.g. via jumpDestination). */ override def complete() = 153 | new Complete( 154 | bytecodeSeq, 155 | incompleteJumpTargets.mapValues { 156 | case Left(name) => nameToLabel(name) 157 | case Right(label) => label 158 | } 159 | ) 160 | } 161 | 162 | /** Lazy concatenation of fragments: each item's labels are offset by the total size of the items before it; additionalNameToLabel holds extra names already expressed in concatenated positions. */ case class Concat(items: Seq[CodeFragment], additionalNameToLabel: Map[String, Bytecode.Label]) extends Incomplete { 163 | override val size = items.map(_.size).sum 164 | override def pretty = toString // TODO 165 | override lazy val bytecodeSeq = items.flatMap(_.bytecodeSeq) 166 | /* Merges every item's jump targets and names at their offsets (a collision raises AssertionError), then resolves named targets that became bound. */ override lazy val (incompleteJumpTargets, nameToLabel) = { 167 | val (ijt, n2l, _) = 168 | items.foldLeft(( 169 | Map.empty[(Bytecode.Label, JumpTarget), Either[String, Bytecode.Label]], 170 | additionalNameToLabel, 171 | 0 172 | )) { case ((ijt, n2l, offset), cf) => 173 | ( 174 | Algorithm.sharedNothingUnion( 175 | ijt, 176 | cf.incompleteJumpTargets.map { 177 | case ((l, jt), dest) => (l.offset(offset) -> jt) -> dest.fold(Left.apply, l => Right(l.offset(offset))) 178 | } 179 | ).getOrElse { throw new AssertionError() }, 180 | Algorithm.sharedNothingUnion( 181 | n2l, 182 | cf.nameToLabel.mapValues(_.offset(offset)) 183 | ).getOrElse { throw new AssertionError() }, 184 | offset + cf.size 185 | ) 186 | }; 187 | ( 188 | ijt.mapValues { 189 | case Right(l) => Right(l) 190 | case Left(name) => n2l.get(name).map(Right.apply) getOrElse Left(name) 191 | }, 192 | n2l 193 | ) 194 | } 195 | override def complete(): Complete = { 196
| Partial( 197 | bytecodeSeq, 198 | incompleteJumpTargets, 199 | nameToLabel 200 | ).complete() 201 | } 202 | /** Appends rhs. A Concat rhs is flattened into this one, with its extra names shifted by this fragment's size; any other fragment becomes a new item. Item names are recomputed from `items`, so only additionalNameToLabel is carried over; passing the full nameToLabel would union the items' names twice and fail as a collision. */ override def +(rhs: CodeFragment): Concat = rhs match { 203 | case Concat(rItems, an2l) => 204 | val offset = items.map(_.size).sum 205 | Concat( 206 | items ++ rItems, 207 | Algorithm.sharedNothingUnion( 208 | additionalNameToLabel, 209 | an2l.mapValues { case l => l.offset(offset) } 210 | ).getOrElse { 211 | throw new IllegalArgumentException(s"Name conflict: ${additionalNameToLabel.keys.filter(an2l.keySet)}") 212 | } 213 | ) 214 | case rhs => Concat(items :+ rhs, additionalNameToLabel) 215 | } 216 | override def concatForm = this 217 | } 218 | } 219 | 220 | -------------------------------------------------------------------------------- /core/src/main/scala/Instance.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | import scala.language.higherKinds 5 | 6 | import scala.reflect.{ classTag, ClassTag } 7 | import scala.collection.mutable 8 | import scala.collection.JavaConversions._ 9 | 10 | import java.lang.reflect.{ Method => JMethod, Field => JField, Modifier, Constructor } 11 | 12 | import com.todesking.scalapp.syntax._ 13 | import Syntax.Upcast 14 | 15 | sealed abstract class Instance[A <: AnyRef] { 16 | def klass: Klass 17 | def escaped: Boolean = 18 | accessibleMethods 19 | .filter { case (cr, mr) => cr < ClassRef.Object } 20 | .exists { case (cr, mr) => 21 | dataflow(cr, mr).escaped(DataSource.This) 22 | } 23 | def toData: Data = Data.reference(this) 24 | 25 | def accessibleMethods: Set[(ClassRef, MethodRef)] = 26 | methods.filterNot(_._2.isPrivate).keySet 27 | 28 | // TODO: rename virtualMethodBody 29 | final def methodBody(ref: MethodRef): MethodBody = 30 | methodBody(resolveVirtualMethod(ref), ref) 31 | 32 | final def methodBody(classRef: ClassRef, methodRef: MethodRef): MethodBody = 33 | klass.methodBody(classRef, methodRef) 34 | 35 | def
fieldValues: Map[(ClassRef, FieldRef), Data] 36 | 37 | def dataflow(methodRef: MethodRef): DataFlow = 38 | methodBody(methodRef).dataflow(this) 39 | 40 | def dataflow(classRef: ClassRef, methodRef: MethodRef): DataFlow = 41 | methodBody(classRef, methodRef).dataflow(this) 42 | 43 | def thisRef: ClassRef 44 | 45 | final def methods: Map[(ClassRef, MethodRef), MethodAttribute] = 46 | klass.instanceMethods 47 | 48 | final def virtualMethods: Map[MethodRef, MethodAttribute] = 49 | methods.filter(_._2.isVirtual).map { case ((cr, mr), a) => mr -> a } 50 | 51 | final def rewritableVirtualMethods: Map[MethodRef, MethodAttribute] = 52 | virtualMethods.filterNot(_._2.isFinal).filterNot(_._2.isNative) 53 | 54 | final def fieldKeys: Set[(ClassRef, FieldRef)] = 55 | klass.instanceFieldAttributes.keySet 56 | 57 | final def resolveVirtualMethod(mr: MethodRef): ClassRef = 58 | klass.resolveVirtualMethod(mr) 59 | 60 | // TODO: interface field??? 61 | def resolveField(cr: ClassRef, fr: FieldRef): ClassRef = 62 | klass.resolveInstanceField(cr, fr) 63 | 64 | def valueOption: Option[A] 65 | 66 | def usedMethodsOf(i: Instance[_ <: AnyRef]): Set[(ClassRef, MethodRef)] = 67 | analyzeMethods(Set.empty[(ClassRef, MethodRef)]) { (agg, cr, mr, df) => agg ++ df.usedMethodsOf(i) } 68 | 69 | def usedFieldsOf(i: Instance[_ <: AnyRef]): Set[(ClassRef, FieldRef)] = 70 | analyzeMethods(Set.empty[(ClassRef, FieldRef)]) { case (agg, cr, mr, df) => agg ++ df.usedFieldsOf(i) } 71 | 72 | def analyzeMethods[B](initial: B)(analyze: (B, ClassRef, MethodRef, DataFlow) => B): B = { 73 | // TODO: Exclude overriden and unaccessed method 74 | val ms = methods.filterNot { case (k, attrs) => attrs.isAbstract }.keys.toSeq.filterNot { case (cr, mr) => cr == ClassRef.Object } 75 | ms.foldLeft(initial) { case (agg, (cr, mr)) => analyze(agg, cr, mr, dataflow(cr, mr)) } 76 | } 77 | 78 | def pretty: String 79 | } 80 | object Instance { 81 | def of[A <: AnyRef](value: A): Original[A] = Original(value) 82 | 83 | sealed 
abstract class Concrete[A <: AnyRef] extends Instance[A] { 84 | def materialize(el: EventLogger): Instance.Original[A] 85 | // TODO: really need it? 86 | def value: A 87 | def duplicate[B >: A <: AnyRef: ClassTag](el: EventLogger): Instance[B] 88 | def duplicate1(el: EventLogger): Instance.Duplicate[A] 89 | // TODO: change to instanceFields 90 | def fields: Map[(ClassRef, FieldRef), Field] 91 | 92 | override def fieldValues: Map[(ClassRef, FieldRef), Data.Concrete] 93 | } 94 | sealed abstract class Abstract[A <: AnyRef] extends Instance[A] { 95 | final override def valueOption = None 96 | } 97 | 98 | case class Original[A <: AnyRef](value: A) extends Concrete[A] { 99 | require(value != null) 100 | 101 | override def hashCode = System.identityHashCode(value) 102 | override def equals(rhs: Any) = rhs match { 103 | case Original(v) => this.value eq v 104 | case _ => false 105 | } 106 | 107 | override def valueOption = Some(value) 108 | 109 | override val klass: Klass.Native = Klass.from(value.getClass) 110 | 111 | override def pretty = s"Instance.Original(${klass.name})" 112 | 113 | override def toString = pretty 114 | 115 | override def materialize(el: EventLogger) = this 116 | 117 | override val thisRef: ClassRef.Concrete = klass.ref 118 | 119 | override def duplicate[B >: A <: AnyRef: ClassTag](el: EventLogger): Duplicate[B] = 120 | el.section("Original.duplicate") { el => 121 | duplicate1(el).duplicate[B](el) 122 | } 123 | 124 | override def duplicate1(el: EventLogger): Duplicate[A] = 125 | Instance.duplicate(this, thisRef, el) 126 | 127 | override lazy val fields: Map[(ClassRef, FieldRef), Field] = 128 | klass.instanceFieldAttributes 129 | .filterNot(_._2.isStatic) 130 | .map { case (k @ (cr, fr), a) => k -> klass.readField(value, cr, fr) } 131 | 132 | override lazy val fieldValues: Map[(ClassRef, FieldRef), Data.Concrete] = 133 | fields.mapValues(_.data) 134 | } 135 | 136 | class Duplicate[A <: AnyRef]( 137 | override val klass: Klass.Modified, 138 | override 
val fieldValues: Map[(ClassRef, FieldRef), Data.Concrete] 139 | ) extends Concrete[A] with Equality.Reference { 140 | klass.requireWholeInstanceField(fieldValues.keySet) 141 | 142 | override def value = materialize(new EventLogger).value 143 | override def valueOption = Some(value) 144 | 145 | override def thisRef: ClassRef.Extend = 146 | klass.ref 147 | 148 | def superRef: ClassRef = 149 | klass.`super`.ref 150 | 151 | def setFieldValues(vs: Map[(ClassRef, FieldRef), Data.Concrete]): Duplicate[A] = { 152 | require(vs.keySet subsetOf fields.keySet) 153 | 154 | val (thisValues, superValues) = 155 | vs.partition { case ((cr, fr), f) => cr == thisRef } 156 | 157 | new Duplicate[A]( 158 | klass, 159 | fieldValues ++ vs 160 | ) 161 | } 162 | 163 | override def toString = s"Instance.Duplicate(${thisRef})" 164 | override def pretty: String = 165 | klass.pretty // TODO: add field values 166 | 167 | def addMethod(mr: MethodRef, body: MethodBody): Duplicate[A] = 168 | new Duplicate(klass.addMethod(mr, body), fieldValues) 169 | 170 | def addMethods(ms: Map[MethodRef, MethodBody]): Duplicate[A] = 171 | ms.foldLeft(this) { case (i, (mr, b)) => i.addMethod(mr, b) } 172 | 173 | private[this] def modifyKlass(f: klass.type => Klass.Modified): Duplicate[A] = 174 | new Duplicate(f(klass), fieldValues) 175 | 176 | def addField(fr: FieldRef, field: Field): Duplicate[A] = { 177 | new Duplicate( 178 | klass.addField(fr, field.attribute), 179 | fieldValues + ((thisRef, fr) -> field.data) 180 | ) 181 | } 182 | 183 | def addFields(fs: Map[FieldRef, Field]): Duplicate[A] = 184 | fs.foldLeft(this) { case (i, (fr, f)) => i.addField(fr, f) } 185 | 186 | override def duplicate1(el: EventLogger) = 187 | rewriteThisRef(thisRef.anotherUniqueName) 188 | 189 | override def duplicate[B >: A <: AnyRef: ClassTag](el: EventLogger): Duplicate[B] = { 190 | val newSuperRef = ClassRef.of(implicitly[ClassTag[B]].runtimeClass) 191 | Instance.duplicate(this, newSuperRef, el) 192 | } 193 | 194 | // TODO: replace 
thisRef in method/field 195 | def rewriteThisRef(newRef: ClassRef.Extend): Duplicate[A] = 196 | new Duplicate[A]( 197 | klass.changeRef(newRef), 198 | fieldValues.map { 199 | case (k @ (cr, fr), v) => 200 | if (cr == thisRef) ((newRef -> fr) -> v) 201 | else k -> v 202 | } 203 | ) 204 | 205 | override lazy val fields: Map[(ClassRef, FieldRef), Field] = 206 | klass.instanceFieldAttributes.map { 207 | case (k @ (cr, fr), fa) => 208 | k -> Field(fr.descriptor, fa, fieldValues(k)) 209 | } 210 | 211 | override def materialize(el: EventLogger): Original[A] = 212 | klass.materialize(fieldValues, el) 213 | .newInstance[A]() 214 | } 215 | 216 | class New[A <: AnyRef](override val klass: Klass.Native, constructor: MethodDescriptor) extends Abstract[A] with Equality.Reference { 217 | def constructorDataFlow: DataFlow = 218 | dataflow(klass.ref, MethodRef.constructor(constructor)) 219 | 220 | override lazy val fieldValues = 221 | klass.instanceFieldAttributes.map { case (k @ (cr, fr), a) => k -> Data.Unknown(fr.typeRef) } 222 | 223 | override def pretty: String = s"new ${klass.ref}(${constructor.argsStr})" 224 | override def toString = pretty 225 | override def thisRef = klass.ref 226 | } 227 | 228 | class Given[A <: AnyRef](override val klass: Klass, valueOverrides: Map[(ClassRef, FieldRef), Data]) extends Abstract[A] with Equality.Reference { 229 | override lazy val fieldValues = 230 | klass.instanceFieldAttributes.map { case (k@(cr, fr), a) => 231 | k -> valueOverrides.get(k).getOrElse(Data.Unknown(fr.typeRef)) 232 | } 233 | override def thisRef = klass.ref 234 | override def pretty = s"" 235 | } 236 | 237 | private def duplicate[A <: AnyRef, B >: A <: AnyRef](o: Instance.Concrete[A], superRef: ClassRef.Concrete, el: EventLogger): Duplicate[B] = { 238 | el.section("Instance.duplicate") { el => 239 | el.logCFields("base instance fields", o.fields.keySet) 240 | val (klass, fieldRenaming) = o.klass.duplicate(superRef, el) 241 | val fieldValues = 242 | o.fields.flatMap { 243 | 
case (k @ (cr, fr), field) => 244 | fieldRenaming.get(k).fold { 245 | if (cr < superRef) Map.empty[(ClassRef, FieldRef), Data.Concrete] 246 | else Map(k -> field.data) 247 | } { newFr => 248 | if (cr < superRef) 249 | Map((klass.ref.asInstanceOf[ClassRef] -> newFr) -> field.data) 250 | else 251 | Map( 252 | k -> field.data, 253 | (klass.ref.asInstanceOf[ClassRef] -> newFr) -> field.data 254 | ) 255 | } 256 | } 257 | el.logCFields("valued fields", fieldValues.keys) 258 | new Duplicate[B]( 259 | klass, 260 | fieldValues 261 | ) 262 | } 263 | } 264 | 265 | // TODO: Weaken CL 266 | private[this] val materializedClasses = mutable.HashMap.empty[(ClassLoader, String), Array[Byte]] 267 | def registerMaterialized(cl: ClassLoader, name: String, bytes: Array[Byte]): Unit = synchronized { 268 | if (materializedClasses.contains(cl -> name)) 269 | throw new IllegalArgumentException(s"${name} is already defined in ${cl}") 270 | materializedClasses(cl -> name) = bytes 271 | } 272 | // TODO: Resolve name conflict 273 | def findMaterializedClasses(cl: ClassLoader): Seq[(String, Array[Byte])] = synchronized { 274 | if (cl == null) { 275 | Seq.empty 276 | } else { 277 | materializedClasses.collect { case ((l, n), b) if l == cl => (n -> b) }.toSeq ++ 278 | findMaterializedClasses(cl.getParent) 279 | } 280 | } 281 | } 282 | 283 | -------------------------------------------------------------------------------- /core/src/main/scala/Klass.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | sealed abstract class Klass { 4 | def name: String = ref.name 5 | def ref: ClassRef 6 | def methodBody(cr: ClassRef, mr: MethodRef): MethodBody 7 | def instanceMethods: Map[(ClassRef, MethodRef), MethodAttribute] 8 | def instanceFieldAttributes: Map[(ClassRef, FieldRef), FieldAttribute] 9 | def virtualMethods: Set[MethodRef] = 10 | instanceMethods.filter(_._2.isVirtual).map { case ((cr, mr), a) => mr }.toSet 11 | def 
hasVirtualMethod(ref: MethodRef): Boolean = 12 | instanceMethods.exists { case ((c, m), a) => m == ref && a.isVirtual } 13 | def hasVirtualMethod(mr: String): Boolean = 14 | hasVirtualMethod(MethodRef.parse(mr, ref.classLoader)) 15 | def resolveVirtualMethod(mr: MethodRef): ClassRef 16 | def superKlass: Option[Klass] 17 | def hasInstanceField(cr: ClassRef, fr: FieldRef): Boolean = 18 | instanceFieldAttributes.contains(cr -> fr) 19 | def dataflow(cr: ClassRef, mr: MethodRef): DataFlow = 20 | methodBody(cr, mr).dataflow(this) 21 | def pretty: String = Pretty.format_Klass(this) 22 | 23 | // TODO: interface field??? 24 | def resolveInstanceField(cr: ClassRef, fr: FieldRef): ClassRef = 25 | if (hasInstanceField(cr, fr)) cr 26 | else Reflect.superClassOf(cr).map { sc => resolveInstanceField(sc, fr) } getOrElse { 27 | throw new IllegalArgumentException(s"Instance field resolution failed: $cr.$fr") 28 | } 29 | 30 | // TODO: rename this method 31 | // TODO: [BUG] if `this` is leaked...?? 32 | def extendMethods(seed: Set[(ClassRef, MethodRef)]): Set[(ClassRef, MethodRef)] = { 33 | // TODO: this is very inefficient 34 | var ms: Set[(ClassRef, MethodRef)] = null 35 | var ext: Set[(ClassRef, MethodRef)] = seed 36 | do { 37 | ms = ext 38 | ext = ms ++ 39 | ms.filter { case (cr, _) => cr < ClassRef.Object } 40 | .flatMap { case (cr, mr) => dataflow(cr, mr).usedMethodsOf(DataSource.This, this) } 41 | } while (ext.size > ms.size) 42 | ext 43 | } 44 | 45 | def requireWholeInstanceField(fvs: Set[(ClassRef, FieldRef)]): Unit = { 46 | val nonStaticFields = instanceFieldAttributes.keySet 47 | if ((nonStaticFields -- fvs).nonEmpty) { 48 | throw new IllegalArgumentException(s"Field value missing: ${nonStaticFields -- fvs}") 49 | } else if ((fvs -- nonStaticFields).nonEmpty) { 50 | throw new IllegalArgumentException(s"Unknown field value: ${fvs -- nonStaticFields}") 51 | } 52 | } 53 | 54 | // return: field -> new field name 55 | def duplicate(superRef: ClassRef.Concrete, el: EventLogger): 
(Klass.Modified, Map[(ClassRef, FieldRef), FieldRef]) = 56 | el.section("Klass.duplicate") { el => 57 | el.log(s"from = ${ref}") 58 | el.log(s"new superclass = ${superRef}") 59 | require(superRef >= ref) 60 | 61 | // TODO: reject if dependent method is not moveable 62 | // TODO[BUG]: use same-runtime-package accessor class 63 | // TODO[BUG]: check resolved class reference is same as super class class loader's result 64 | 65 | val thisRef = superRef.extend(new AccessibleClassLoader(superRef.classLoader)) 66 | val overridableVirtualMethods = 67 | this.virtualMethods 68 | .map { mr => this.resolveVirtualMethod(mr) -> mr } 69 | .filter { case (cr, _) => cr < ClassRef.Object } 70 | .filter { case k @ (cr, mr) => (cr < superRef) || !this.instanceMethods(k).isFinal } 71 | .filterNot { case k @ (cr, mr) => this.instanceMethods(k).isNative } 72 | el.logCMethods("overridable virtual instanceMethods", overridableVirtualMethods) 73 | 74 | val requiredMethods = 75 | this.extendMethods(overridableVirtualMethods) 76 | .filter { case (cr, _) => cr < ClassRef.Object } 77 | .filterNot { case k @ (cr, mr) => this.instanceMethods(k).isNative } 78 | .map { case k @ (cr, mr) => k -> this.dataflow(cr, mr) } 79 | .toMap 80 | el.logCMethods("required instanceMethods", requiredMethods.keys) 81 | 82 | val methodRenaming = 83 | requiredMethods.collect { 84 | case (k @ (cr, mr), df) if !overridableVirtualMethods.contains(k) => 85 | (k -> mr.anotherUniqueName()) 86 | } 87 | el.logCMethods("renamed instanceMethods", methodRenaming.keys) 88 | 89 | val requiredFields = 90 | requiredMethods.values.flatMap { df => df.usedFieldsOf(DataSource.This, this) } 91 | .map { case (cr, fr) => this.resolveInstanceField(cr, fr) -> fr } 92 | .toSet 93 | el.logCFields("required fields", requiredFields) 94 | 95 | requiredFields foreach { 96 | case k @ (cr, fr) => 97 | val fa = this.instanceFieldAttributes(k) 98 | if (cr >= superRef && fa.isPrivate && !fa.isFinal) 99 | throw new TransformException(s"Required field 
is non-final private: $cr.$fr") 100 | } 101 | 102 | val fieldRenaming = 103 | requiredFields 104 | .filter { case k @ (cr, fr) => cr < superRef || this.instanceFieldAttributes(k).isPrivate } 105 | .map { case k @ (cr, fr) => k -> fr.anotherUniqueName() } 106 | .toMap 107 | el.logCFields("renamed fields", fieldRenaming.keys) 108 | 109 | val thisMethods = 110 | requiredMethods.map { 111 | case (k @ (cr, mr), df) => 112 | val newMr = methodRenaming.get(k).getOrElse(mr) 113 | import Bytecode._ 114 | newMr -> df.body.rewrite { 115 | case (label, bc @ invokevirtual(cr, mr)) if df.mustThis(label, bc.objectref) => 116 | val vcr = this.resolveVirtualMethod(mr) 117 | methodRenaming.get(vcr -> mr).fold { 118 | bc.rewriteClassRef(thisRef) 119 | } { newMr => 120 | bc.rewriteMethodRef(thisRef, newMr) 121 | } 122 | case (label, bc @ invokeinterface(cr, mr, _)) if df.mustThis(label, bc.objectref) => 123 | val vcr = this.resolveVirtualMethod(mr) 124 | methodRenaming.get(vcr -> mr).fold { 125 | bc.rewriteClassRef(thisRef) 126 | } { newMr => 127 | bc.rewriteMethodRef(thisRef, newMr) 128 | } 129 | case (label, bc @ invokespecial(cr, mr)) if df.mustThis(label, bc.objectref) => 130 | // TODO: resolve special 131 | methodRenaming.get(cr -> mr).fold { 132 | bc 133 | } { newMr => 134 | bc.rewriteMethodRef(thisRef, newMr) 135 | } 136 | case (label, bc: InstanceFieldAccess) if df.mustThis(label, bc.objectref) => 137 | fieldRenaming.get(this.resolveInstanceField(bc.classRef, bc.fieldRef) -> bc.fieldRef).fold(bc) { newFr => 138 | bc.rewriteFieldRef(thisRef, newFr) 139 | } 140 | }.makeNonFinal 141 | } 142 | el.logMethods("thisMethods", thisMethods.keys) 143 | 144 | val thisFields = 145 | fieldRenaming.map { 146 | case (k @ (cr, fr), newFr) => 147 | newFr -> this.instanceFieldAttributes(k) 148 | } 149 | el.logFields("thisFields", thisFields.keys) 150 | 151 | new Klass.Modified(superRef.loadKlass, thisRef, thisMethods, thisFields) -> fieldRenaming 152 | } 153 | } 154 | object Klass { 155 | def 
from(j: Class[_]): Native = 156 | new Native(j) 157 | 158 | class Native(val javaClass: Class[_]) extends Klass with Equality.Delegate { 159 | override def canEqual(rhs: Any): Boolean = rhs.isInstanceOf[Native] 160 | override val equalityObject = javaClass 161 | 162 | override lazy val superKlass = 163 | if (javaClass.getSuperclass == null) None 164 | else Some(Klass.from(javaClass.getSuperclass)) 165 | 166 | override def ref: ClassRef.Concrete = ClassRef.of(javaClass) 167 | 168 | override def resolveVirtualMethod(mr: MethodRef): ClassRef = 169 | Reflect.resolveVirtualMethod(javaClass, mr) 170 | 171 | override def methodBody(cr: ClassRef, mr: MethodRef) = 172 | if (mr.isInit) MethodBody.parse(allJConstructors(cr -> mr)) 173 | else MethodBody.parse(allJMethods(cr -> mr)) 174 | 175 | override lazy val instanceMethods: Map[(ClassRef, MethodRef), MethodAttribute] = 176 | allJMethods.map { case (k, m) => k -> MethodAttribute.from(m) }.filterNot(_._2.isStatic) 177 | 178 | override def instanceFieldAttributes = 179 | allJFields.mapValues(FieldAttribute.from(_)).filter(!_._2.isStatic) 180 | 181 | def readField(obj: AnyRef, cr: ClassRef, fr: FieldRef): Field = 182 | Field.from(allJFields(cr -> fr), obj) 183 | 184 | private[this] lazy val virtualJMethods = Reflect.virtualJMethods(javaClass) 185 | private[this] lazy val allJMethods = Reflect.allJMethods(javaClass) 186 | private[this] lazy val allJFields = Reflect.allJFields(javaClass) 187 | private[this] lazy val allJConstructors = Reflect.allJConstructors(javaClass) 188 | 189 | } 190 | 191 | class MaterializedNative( 192 | jc: Class[_], 193 | val constructorArgs: Seq[(TypeRef.Public, Any)] 194 | ) extends Native(jc) { 195 | // TODO: add type parameter 196 | def newInstance[A <: AnyRef](): Instance.Original[A] = { 197 | val value = 198 | try { 199 | javaClass 200 | .getDeclaredConstructor(constructorArgs.map(_._1.javaClass).toArray: _*) 201 | .newInstance(constructorArgs.map(_._2.asInstanceOf[Object]).toArray: _*) 202 | } 
catch { 203 | case e: LinkageError => throw new InvalidClassException(this, e) 204 | } 205 | Instance.of(value.asInstanceOf[A]) 206 | } 207 | } 208 | 209 | class Modified( 210 | val `super`: Klass.Native, 211 | override val ref: ClassRef.Extend, 212 | val declaredMethods: Map[MethodRef, MethodBody], 213 | val declaredFields: Map[FieldRef, FieldAttribute] 214 | ) extends Klass with Equality.Reference { 215 | require(!declaredFields.exists(_._2.isStatic)) 216 | 217 | override def instanceFieldAttributes: Map[(ClassRef, FieldRef), FieldAttribute] = 218 | `super`.instanceFieldAttributes ++ declaredFields.map { case (fr, fa) => (ref, fr) -> fa } 219 | 220 | override def superKlass: Option[Klass] = Some(`super`) 221 | 222 | // TODO: support default interface method 223 | override def resolveVirtualMethod(mr: MethodRef): ClassRef = 224 | declaredMethods.get(mr).map { body => 225 | if (body.attribute.isVirtual) ref 226 | else throw new IllegalArgumentException(s"Not virtual: ${mr} ${body.attribute}") 227 | } getOrElse { 228 | `super`.resolveVirtualMethod(mr) 229 | } 230 | 231 | override def methodBody(cr: ClassRef, mr: MethodRef) = 232 | if (cr == ref) declaredMethods(mr) 233 | else if (ref < cr) `super`.methodBody(cr, mr) 234 | else throw new IllegalArgumentException(s"Method not found: ${cr.pretty}.${mr.str}") 235 | 236 | override lazy val instanceMethods = 237 | `super`.instanceMethods ++ declaredMethods.map { 238 | case (k, v) => (ref -> k) -> v.attribute 239 | } 240 | 241 | def addMethod(mr: MethodRef, body: MethodBody): Modified = { 242 | require(mr.descriptor == body.descriptor) 243 | require(!body.attribute.isStatic) 244 | new Modified(`super`, ref, declaredMethods + (mr -> body), declaredFields) 245 | } 246 | 247 | def addField(fr: FieldRef, attribute: FieldAttribute): Modified = { 248 | require(!attribute.isStatic) 249 | new Modified(`super`, ref, declaredMethods, declaredFields + (fr -> attribute)) 250 | } 251 | 252 | def changeRef(newRef: ClassRef.Extend): 
Modified = { 253 | require(newRef.superClassRef == ref.superClassRef) 254 | new Modified( 255 | `super`, 256 | newRef, 257 | declaredMethods 258 | .map { case (mr, body) => mr -> body.rewriteClassRef(ref, newRef) }, 259 | declaredFields 260 | ) 261 | } 262 | 263 | def materialize(fieldValues: Map[(ClassRef, FieldRef), Data.Concrete], el: EventLogger): MaterializedNative = 264 | ClassCompiler.compile(this, fieldValues, el) 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /core/src/test/scala/spec.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import com.todesking.scalapp.syntax._ 4 | 5 | import org.scalatest.{ FunSpec, Matchers, Failed } 6 | 7 | class Spec extends FunSpec with Matchers { 8 | def dotBody(filename: String, self: Instance[_ <: AnyRef], b: MethodBody): Unit = { 9 | import java.nio.file._ 10 | // TODO: DataFlow.toDot 11 | // Files.write(Paths.get(filename), b.dataflow(self).toDot.getBytes("UTF-8")) 12 | } 13 | 14 | def withThe[A](i: Instance[_])(f: => A): A = { 15 | try { f } catch { case e: Throwable => println("==== TEST FAILED\n" + i.pretty); throw e } 16 | } 17 | 18 | val el = Transformer.newEventLogger() 19 | 20 | override def withFixture(test: NoArgTest) = { 21 | def printEvents(): Unit = { 22 | println("=== EVENT LOG:") 23 | println(el.pretty) 24 | } 25 | val result = try { 26 | super.withFixture(test) 27 | } catch { 28 | case e: Throwable => 29 | println("=== TEST FAILED") 30 | printEvents() 31 | throw e 32 | } 33 | result match { 34 | case Failed(t: UnveilException.HasMethodBody) => 35 | println(s"=== FAILED($t)") 36 | println(t) 37 | println(t.methodBody.pretty) 38 | printEvents() 39 | case o @ Failed(t: InvalidClassException) => 40 | println("=== INVALID CLASS") 41 | println(t.klass.pretty) 42 | printEvents() 43 | case Failed(t: UnveilBugException) => 44 | println(t.detail) 45 | printEvents() 46 | case 
Failed(t) => 47 | println(s"=== TEST FAILED($t)") 48 | printEvents() 49 | case _ => 50 | } 51 | el.clear() 52 | result 53 | } 54 | 55 | describe("opt") { 56 | val defaultCL = ClassLoader.getSystemClassLoader 57 | it("duplicate") { 58 | class Const { 59 | def intMethod(): Int = 1 60 | def longMethod(): Long = 0L 61 | } 62 | val obj = new Const 63 | val i = Instance.of(obj) 64 | i.klass.hasVirtualMethod("intMethod()I") should be(true) 65 | 66 | val intMethod = MethodRef.parse("intMethod()I", defaultCL) 67 | val longMethod = MethodRef.parse("longMethod()J", defaultCL) 68 | val ri = i.duplicate[Const](el) 69 | withThe(ri) { 70 | val m = ri.materialize(el) 71 | m.value.intMethod() should be(1) 72 | m.value.longMethod() should be(0L) 73 | } 74 | } 75 | it("invokeVirtual with no arguments") { 76 | class InvokeVirtual0 { 77 | def foo(): Int = bar() 78 | def bar(): Int = 1 79 | } 80 | val d = new InvokeVirtual0 81 | d.foo() should be(1) 82 | 83 | val i = Instance.of(d) 84 | val foo = MethodRef.parse("foo()I", defaultCL) 85 | 86 | val ri = i.duplicate(el).materialize(el) 87 | 88 | ri.value.foo() should be(1) 89 | } 90 | it("invokeVirtual1") { 91 | class InvokeVirtual1 { 92 | def foo(): Int = bar(1) 93 | def bar(n: Int): Int = n 94 | } 95 | val d = new InvokeVirtual1 96 | d.foo() should be(1) 97 | 98 | val i = Instance.of(d) 99 | val foo = MethodRef.parse("foo()I", defaultCL) 100 | 101 | val ri = i.duplicate(el).materialize(el) 102 | 103 | ri.value.foo() should be(1) 104 | } 105 | it("if") { 106 | class If { 107 | def foo(a: Int): Int = 108 | if (a > 0) 100 109 | else if (a > -10) -10 110 | else -100 111 | } 112 | val d = new If 113 | d.foo(1) should be(100) 114 | d.foo(-1) should be(-10) 115 | d.foo(-11) should be(-100) 116 | 117 | val i = Instance.of(d) 118 | val foo = MethodRef.parse("foo(I)I", defaultCL) 119 | 120 | val ri = i.duplicate(el).materialize(el) 121 | ri.value.foo(1) should be(100) 122 | ri.value.foo(-1) should be(-10) 123 | ri.value.foo(-11) should be(-100) 
124 | } 125 | it("other method") { 126 | object OtherMethod { 127 | abstract class A { 128 | def foo(): Int 129 | def bar(): Int = 10 130 | } 131 | class B extends A { 132 | override def foo() = baz() 133 | override def bar() = 99 134 | def baz() = bar() 135 | } 136 | } 137 | val obj = new OtherMethod.B 138 | obj.foo() should be(99) 139 | val i = Instance.of(obj) 140 | val foo = MethodRef.parse("foo()I", defaultCL) 141 | val ri = i.duplicate(el).materialize(el) 142 | ri.value.foo() should be(99) 143 | } 144 | it("real upcast") { 145 | abstract class A { 146 | def foo(): Int 147 | def bar(): Int = 10 148 | } 149 | final class B extends A { 150 | override def foo() = baz() 151 | override def bar() = 99 152 | def baz() = bar() 153 | } 154 | val obj = new B 155 | obj.foo() should be(99) 156 | val i = Instance.of[A](obj) 157 | val foo = MethodRef.parse("foo()I", defaultCL) 158 | val ri = i.duplicate[A](el).materialize(el) 159 | dotBody("real_upcast.dot", ri, ri.methodBody(foo)) 160 | classOf[A].isAssignableFrom(ri.value.getClass) should be(true) 161 | classOf[B].isAssignableFrom(ri.value.getClass) should be(false) 162 | ri.value.foo() should be(99) 163 | } 164 | it("simple dataflow compile") { 165 | class A { 166 | def foo(): Int = if (bar > 20) 1 else 2 167 | def bar(): Int = 10 168 | } 169 | 170 | val i = Instance.of(new A) 171 | i.value.foo() should be(2) 172 | 173 | val foo = MethodRef.parse("foo()I", defaultCL) 174 | 175 | val ri = i.duplicate(el).materialize(el) 176 | 177 | dotBody("s.dot", ri, ri.methodBody(foo)) 178 | 179 | ri.value.foo() should be(2) 180 | } 181 | it("primitive field") { 182 | class A { 183 | val foo = 10 184 | } 185 | val foo = MethodRef.parse("foo()I", defaultCL) 186 | 187 | val i = Instance.of(new A) 188 | i.value.foo should be(10) 189 | 190 | val ri = i.duplicate(el).materialize(el) 191 | ri.value.foo should be(10) 192 | } 193 | it("field duplicate") { 194 | abstract class Base { 195 | def foo: Int 196 | } 197 | class A extends Base { 198 | 
val x = 1000 199 | override def foo = x 200 | } 201 | class B extends A { 202 | } 203 | val i = Instance.of(new B) 204 | i.value.foo should be(1000) 205 | 206 | val ri = i.duplicate[Base](el).materialize(el) 207 | ri.value.foo should be(1000) 208 | } 209 | it("dupdup") { 210 | abstract class Base { 211 | def foo: Int 212 | } 213 | class A extends Base { 214 | val x = 1000 215 | override def foo = x 216 | } 217 | class B extends A { 218 | } 219 | val i = Instance.of(new B) 220 | i.value.foo should be(1000) 221 | 222 | val ri = i.duplicate[Base](el).duplicate[Base](el).materialize(el) 223 | ri.value.foo should be(1000) 224 | } 225 | it("dup and accessor") { 226 | abstract class Base { def foo: Int } 227 | class A extends Base { override val foo = 1 } 228 | val dup = Instance.of(new A).duplicate[Base](el) 229 | dup.materialize(el).value.foo should be(1) 230 | } 231 | describe("field fusion") { 232 | it("when empty") { 233 | class A { 234 | def foo(): Int = 1 235 | } 236 | val expected = 1 237 | 238 | val i = Instance.of(new A) 239 | i.value.foo() should be(expected) 240 | 241 | val fi = Transformer.fieldFusion.apply(i, el).get 242 | fi.materialize(el).value.foo should be(expected) 243 | } 244 | it("simple") { 245 | class A { 246 | def foo(): Int = 1 247 | } 248 | class B { 249 | def bar(): Int = 10 250 | } 251 | val expected = 10 252 | 253 | val foo = MethodRef.parse("foo()I", defaultCL) 254 | val bar = MethodRef.parse("bar()I", defaultCL) 255 | val fieldB = FieldRef("b", FieldDescriptor(TypeRef.Reference(ClassRef.of(classOf[B])))) 256 | 257 | val b = Instance.of(new B) 258 | val i0 = Instance.of(new A) 259 | .duplicate1(el) 260 | .addField(fieldB, Field(fieldB.descriptor, FieldAttribute.Final, Data.ConcreteReference(b))) 261 | val i = 262 | i0.addMethod( 263 | foo, 264 | MethodBody( 265 | false, 266 | foo.descriptor, 267 | MethodAttribute.Public, 268 | CodeFragment.bytecode( 269 | Bytecode.aload(0), 270 | Bytecode.getfield(i0.thisRef, fieldB), 271 | 
Bytecode.invokevirtual(b.thisRef, bar), 272 | Bytecode.ireturn() 273 | ) 274 | ) 275 | ) 276 | withThe(i) { 277 | i.materialize(el).value.foo() should be(expected) 278 | } 279 | 280 | val fi = withThe(i) { 281 | Transformer.fieldFusion(i, el).get 282 | } 283 | withThe(fi) { 284 | fi.materialize(el).value.foo() should be(expected) 285 | } 286 | } 287 | it("nested") { 288 | class A(b: B) { 289 | def bbb = b 290 | def foo(): Int = b.bar() + 1000 291 | } 292 | class B(c: C) { 293 | def bar(): Int = c.baz() + 1 294 | } 295 | class C { 296 | def baz(): Int = 999 297 | } 298 | val expected = 2000 299 | val foo = MethodRef.parse("foo()I", defaultCL) 300 | val bar = MethodRef.parse("bar()I", defaultCL) 301 | 302 | val c = new C 303 | val b = new B(c) 304 | val a = new A(b) 305 | a.foo() should be(expected) 306 | 307 | val i = Instance.of(a) 308 | 309 | val fused = Transformer.fieldFusion(i, el).get 310 | withThe(fused) { 311 | fused.dataflow(foo).usedFieldsOf(fused) should be('empty) 312 | fused.usedMethodsOf(Instance.of(c)) should be('empty) 313 | fused.materialize(el).value.foo() should be(expected) 314 | } 315 | } 316 | 317 | it("Function1") { 318 | val f1 = { n: Int => n + 1 } 319 | val f2 = { n: Int => n * 2 } 320 | val f3 = { n: Int => n + 1 } 321 | val f = f1 andThen f2 andThen f3 322 | val i = Instance.of(f) 323 | val n = 1 324 | val expected = 5 325 | i.materialize(el).value.apply(n) should be(expected) 326 | 327 | val dup = i.duplicate[Int => Int](el) 328 | withThe(dup) { 329 | dup.materialize(el).value.apply(n) should be(expected) 330 | } 331 | 332 | val ti = Transformer.fieldFusion(i.duplicate[Int => Int](el), el).get 333 | withThe(ti) { 334 | ti.materialize(el).value.apply(n) should be(expected) 335 | } 336 | } 337 | } 338 | 339 | describe("method inlining") { 340 | it("no control") { 341 | class A { 342 | def foo(): Int = bar() + baz() 343 | def bar(): Int = 10 344 | private[this] def baz(): Int = 90 345 | } 346 | val expected = 100 347 | val foo = 
MethodRef.parse("foo()I", defaultCL) 348 | 349 | val i = Instance.of(new A) 350 | i.value.foo() should be(expected) 351 | 352 | val ri = Transformer.methodInlining(i, el).get 353 | withThe(ri) { 354 | ri.materialize(el).value.foo() should be(expected) 355 | ri.dataflow(foo).usedMethodsOf(ri) should be('empty) 356 | } 357 | } 358 | it("control") { 359 | class A { 360 | def foo(): Int = bar(10) + baz(1) + baz(0) 361 | def bar(n: Int): Int = if (n > 0) 10 else 20 362 | private[this] def baz(n: Int): Int = if (n < 0) n + 1 else if (n > 0) n - 1 else 10 363 | } 364 | val expected = 10 + 0 + 10 365 | val foo = MethodRef.parse("foo()I", defaultCL) 366 | 367 | val i = Instance.of(new A) 368 | i.value.foo() should be(expected) 369 | 370 | val ri = Transformer.methodInlining(i, el).get 371 | withThe(ri) { 372 | ri.materialize(el).value.foo() should be(expected) 373 | ri.dataflow(foo).usedMethodsOf(ri) should be('empty) 374 | } 375 | } 376 | it("Function1") { 377 | pending 378 | } 379 | } 380 | 381 | describe("new instance") { 382 | it("handle new insn") { 383 | class A(val value: Int) { 384 | def foo(): A = new A(2) 385 | } 386 | val x = Instance.of(new A(1)).duplicate[A](el) 387 | x.materialize(el).value.foo.value should be(2) 388 | } 389 | it("inline local instance") { 390 | class A(val value: Int) { 391 | def foo(): Int = new A(2).value + value 392 | } 393 | val i = Instance.of(new A(1)) 394 | println(i.duplicate1(el).pretty) 395 | val ri = Transformer.localInstanceInlining(i, el).get 396 | withThe(ri) { 397 | ri.materialize(el).value.foo() should be(3) 398 | ri.methodBody(MethodRef.parse("foo()I", defaultCL)) 399 | .bytecode.collect { case x@(l, Bytecode.new_(_)) => x } should be('empty) 400 | } 401 | } 402 | } 403 | 404 | it("double values") { 405 | pending 406 | } 407 | 408 | // TODO: Exception handler rejection test 409 | // TODO: accessor inlining 410 | // TODO: accept new instance as constant in SetterConstructor 411 | } 412 | } 413 | 
-------------------------------------------------------------------------------- /core/src/main/scala/DataFlow.scala: --------------------------------------------------------------------------------
package com.todesking.unveil

import scala.language.existentials
import scala.language.higherKinds

import scala.reflect.{ classTag, ClassTag }
import scala.collection.mutable

import java.lang.reflect.{ Method => JMethod, Constructor => JConstructor }

/**
 * Data-flow queries over a single method `body`.
 *
 * `instance` is kept alongside the body; presumably it is the receiver the body was analyzed
 * for (it is what `Instance.dataflow` passes in) — NOTE(review): confirm against MethodBody.dataflow.
 */
class DataFlow(val body: MethodBody, val instance: Instance[_ <: AnyRef]) {

  /** Values that may appear on port `p` at label `l`; currently always the single entry from `dataValues`. */
  // TODO: detailed values from merge
  def possibleValues(l: Bytecode.Label, p: DataPort): Seq[Data] =
    Seq(dataValues(l -> p))

  /** Output ports whose value is an abstract reference to an `Instance.New`, keyed by (label, port). */
  lazy val newInstances: Map[(Bytecode.Label, DataPort.Out), Instance.New[_ <: AnyRef]] =
    dataValues.collect {
      case ((l, p: DataPort.Out), Data.AbstractReference(n: Instance.New[_])) =>
        (l -> p) -> n.asInstanceOf[Instance.New[_ <: AnyRef]]
    }.toMap

  /**
   * True if the value produced at (`label`, `p`) may leave this method body.
   * That is, it may be passed as an invocation argument, stored by a field setter,
   * returned, or thrown.
   */
  // TODO: [BUG] check leakage in ctor
  def escaped(label: Bytecode.Label, p: DataPort.Out): Boolean =
    escapesWhere(_.mayProducedBy(label, p))

  /** Same check as the (label, port) overload, but tracks the single data source `ds`. */
  def escaped(ds: DataSource.Single): Boolean =
    escapesWhere(_.may(ds))

  // Shared scan for both `escaped` overloads. `mayBeTracked` decides whether the data source
  // consumed by an escaping instruction may be the value being tracked.
  private[this] def escapesWhere(mayBeTracked: DataSource => Boolean): Boolean = {
    import Bytecode._
    body.bytecode.exists {
      case (l, bc: InvokeMethod) =>
        bc.args.exists { arg => mayBeTracked(dataSource(l, arg)) }
      case (l, bc: FieldSetter) =>
        mayBeTracked(dataSource(l, bc.value))
      case (l, bc: XReturn) =>
        mayBeTracked(dataSource(l, bc.retval))
      case (l, bc @ athrow()) =>
        mayBeTracked(dataSource(l, bc.objectref))
      case _ => false
    }
  }

  /** The single known value at (`l`, `p`); None if several values are possible or it is not `Data.Known`. */
  def onlyValue(l: Bytecode.Label, p: DataPort): Option[Data.Known] =
    possibleValues(l, p) match {
      case Seq(d: Data.Known) => Some(d)
      case _ => None
    }

  /** The data source feeding port `p` at `l`; throws IllegalArgumentException (naming the bytecode) if none is recorded. */
  def dataSource(l: Bytecode.Label, p: DataPort): DataSource =
    dataSources.get(l -> p) getOrElse {
      throw new IllegalArgumentException(s"DataSource not found: $l ${body.labelToBytecode.get(l)}, $p")
    }

  // Reverse index: producing (label, out-port) -> consuming label -> in-ports read at that label.
  // Built strictly with `map`; `mapValues` would return a lazy view that re-runs `.toSet` on every lookup.
  // TODO: [BUG] track merge
  private[this] lazy val useSitesMap: Map[(Bytecode.Label, DataPort.Out), Map[Bytecode.Label, Set[DataPort.In]]] = {
    val edges =
      body.bytecode.flatMap {
        case (l, bc) =>
          bc.inputs
            .map { p => p -> dataSource(l, p) }
            .collect { case (p1, DataSource.HasLocation(l2, p2)) => (l2, p2) -> (l, p1) }
      }
    edges.groupBy(_._1).map {
      case (producer, pairs) =>
        val byUseLabel =
          pairs.map(_._2).groupBy(_._1).map {
            case (useLabel, uses) => useLabel -> uses.map(_._2).toSet
          }.toMap
        producer -> byUseLabel
    }.toMap
  }

  /** Every (label, bytecode, in-ports) that consumes the value produced at (`l`, `p`); empty if unused. */
  def useSites(l: Bytecode.Label, p: DataPort.Out): Seq[(Bytecode.Label, Bytecode, Set[DataPort.In])] =
    useSitesMap.getOrElse(l -> p, Map.empty[Bytecode.Label, Set[DataPort.In]]).toSeq.map {
      case (useLabel, ports) => (useLabel, body.bytecodeFromLabel(useLabel), ports)
    }

  /** Not implemented yet: always throws `NotImplementedError` (`???`). */
  def constructor(l: Bytecode.Label, p: DataPort): Option[MethodRef] = ???
// ---- DataFlow: argument/instance queries, used-member scans, exit analysis ----

  /** The argument index `n` when the value at (`l`, `p`) comes from `DataSource.Argument(n)`, else None. */
  def argNum(l: Bytecode.Label, p: DataPort): Option[Int] =
    dataSources(l -> p) match {
      case DataSource.Argument(n) => Some(n)
      case _ => None
    }

  // Some(true): data has single value that point the instance
  // Some(false): data is not point the instance
  // None: not sure
  def isInstance(l: Bytecode.Label, p: DataPort, i: Instance[_ <: AnyRef]): Option[Boolean] =
    if (!dataType(l, p).isAssignableFrom(TypeRef.Reference(i.thisRef))) Some(false)
    else onlyValue(l, p).map(_.isInstance(i))

  /** False only when the value at (`l`, `p`) definitely does not point to `i`. */
  def mayInstance(l: Bytecode.Label, p: DataPort, i: Instance[_ <: AnyRef]): Boolean =
    isInstance(l, p, i) != Some(false)

  /** Definite answer to `isInstance`; throws RuntimeException when it cannot be decided. */
  def mustInstance(l: Bytecode.Label, p: DataPort, i: Instance[_ <: AnyRef]): Boolean =
    isInstance(l, p, i) getOrElse { throw new RuntimeException(s"ambiguous: $l $p") }

  /** Whether the data source at (`l`, `p`) is `DataSource.This`; None when undecidable. */
  def isThis(l: Bytecode.Label, p: DataPort): Option[Boolean] =
    dataSource(l, p).is(DataSource.This)

  /** True when the data source at (`l`, `p`) must be `DataSource.This`. */
  def mustThis(l: Bytecode.Label, p: DataPort): Boolean =
    dataSource(l, p).must(DataSource.This)

  /** Instance fields accessed on a receiver that must be `i`, resolved through `i`. */
  def usedFieldsOf(i: Instance[_ <: AnyRef]): Set[(ClassRef, FieldRef)] =
    usedFieldsWhere((l, p) => mustInstance(l, p, i), (cr, fr) => i.resolveField(cr, fr))

  /** Instance fields accessed on a receiver that may come from `src`, resolved through `klass`. */
  def usedFieldsOf(src: DataSource.Single, klass: Klass): Set[(ClassRef, FieldRef)] =
    usedFieldsWhere((l, p) => dataSource(l, p).may(src), (cr, fr) => klass.resolveInstanceField(cr, fr))

  // Shared scan for the `usedFieldsOf` overloads. `receiverMatches` selects field accesses by their
  // objectref; `resolveField` maps the static (class, field) to the declaring class.
  private[this] def usedFieldsWhere(
    receiverMatches: (Bytecode.Label, DataPort) => Boolean,
    resolveField: (ClassRef, FieldRef) => ClassRef
  ): Set[(ClassRef, FieldRef)] =
    body.bytecode.foldLeft(Set.empty[(ClassRef, FieldRef)]) {
      case (agg, (label, bc)) =>
        import Bytecode._
        bc match {
          case bc: InstanceFieldAccess if receiverMatches(label, bc.objectref) =>
            agg + (resolveField(bc.classRef, bc.fieldRef) -> bc.fieldRef)
          case _ => agg
        }
    }

  /**
   * Methods invoked on a receiver that may be `i`.
   * Virtual and interface calls are resolved through `i`; invokespecial keeps its static class ref.
   */
  // TODO: DirectUsedMethod
  def usedMethodsOf(i: Instance[_ <: AnyRef]): Set[(ClassRef, MethodRef)] =
    usedMethodsWhere((l, p) => mayInstance(l, p, i), mr => i.resolveVirtualMethod(mr))

  /** Like the instance overload, but for receivers that may come from `src`, resolved through `klass`. */
  def usedMethodsOf(src: DataSource.Single, klass: Klass): Set[(ClassRef, MethodRef)] =
    usedMethodsWhere((l, p) => dataSource(l, p).may(src), mr => klass.resolveVirtualMethod(mr))

  // Shared scan for the `usedMethodsOf` overloads. `receiverMatches` selects invocations by their
  // objectref; `resolveVirtual` maps a virtual/interface MethodRef to its implementing class.
  private[this] def usedMethodsWhere(
    receiverMatches: (Bytecode.Label, DataPort) => Boolean,
    resolveVirtual: MethodRef => ClassRef
  ): Set[(ClassRef, MethodRef)] =
    body.bytecode.foldLeft(Set.empty[(ClassRef, MethodRef)]) {
      case (agg, (label, bc)) =>
        import Bytecode._
        bc match {
          case bc @ invokevirtual(_, mr) if receiverMatches(label, bc.objectref) =>
            agg + (resolveVirtual(mr) -> mr)
          case bc @ invokeinterface(_, mr, _) if receiverMatches(label, bc.objectref) =>
            agg + (resolveVirtual(mr) -> mr)
          case bc @ invokespecial(cr, mr) if receiverMatches(label, bc.objectref) =>
            // TODO: Special method resolution
            agg + (cr -> mr)
          case _ => agg
        }
    }

  /** `Bytecode.Return` instructions reachable from `l`. */
  def possibleReturns(l: Bytecode.Label): Seq[Bytecode.Return] =
    possibleExits(l).collect { case bc: Bytecode.Return => bc }

  /** `Bytecode.Exit` instructions reachable from `l`, following `jumpDestinations`. */
  def possibleExits(l: Bytecode.Label): Seq[Bytecode.Exit] =
    possibleExits0(l, Set.empty)
      .toSeq
      .map(body.labelToBytecode(_))
      .map(_.asInstanceOf[Bytecode.Exit])

  // Labels of every exit reachable from `l` without entering a label in `ignore` or passing
  // through another exit. This is an iterative DFS with a single visited set shared by all branches.
  // The former path-local `ignore + l` recursion produced the same set, but it re-walked shared
  // suffixes once per path (exponential on diamond-shaped control flow) and recursed as deep as
  // the longest path.
  private[this] def possibleExits0(l: Bytecode.Label, ignore: Set[Bytecode.Label]): Set[Bytecode.Label] = {
    val visited = mutable.HashSet.empty[Bytecode.Label]
    visited ++= ignore
    val exits = Set.newBuilder[Bytecode.Label]
    var pending: List[Bytecode.Label] = l :: Nil
    while (pending.nonEmpty) {
      val cur = pending.head
      pending = pending.tail
      if (!visited.contains(cur)) {
        visited += cur
        body.labelToBytecode(cur) match {
          case _: Bytecode.Exit => exits += cur
          case _ => jumpDestinations(cur).foreach { next => pending = next :: pending }
        }
      }
    }
    exits.result()
  }
/**
 * Reverse control-flow edges: maps a label to the set of labels that can transfer control to it,
 * either by falling through or by an explicit jump. Labels that nothing transfers control to have
 * no entry in the map.
 */
lazy val jumpOrigins: Map[Bytecode.Label, Set[Bytecode.Label]] = {
  val m = Collections.newMultiMap[Bytecode.Label, Bytecode.Label]
  fallThroughs.foreach { case (from, to) => m.addBinding(to, from) }
  body.bytecode.foreach {
    case (label, bc: Bytecode.HasJumpTargets) =>
      // Key by the jump *destination*, exactly like the fall-through edges above; the value is the
      // instruction that jumps there.
      bc.jumpTargets.foreach { t => m.addBinding(body.jumpTargets(label -> t), label) }
    case _ =>
  }
  m.mapValues(_.toSet).toMap
}

/** Labels control may move to right after `l`: its fall-through successor plus any jump targets. */
def jumpDestinations(l: Bytecode.Label): Set[Bytecode.Label] =
  fallThroughs.get(l).toSet ++
    (body.labelToBytecode(l) match {
      case bc: Bytecode.HasJumpTargets => bc.jumpTargets.map { jt => body.jumpTargets(l -> jt) }
      case _ => Set.empty
    })

/**
 * Frame on entry to the method body. For non-static methods, local 0 holds `this`
 * (`Data.Uninitialized` inside a constructor, a reference to `instance` otherwise).
 * Arguments follow in declaration order; a double-word argument type occupies two consecutive
 * slots, the second holding `secondWordData`. The operand stack starts empty.
 */
lazy val initialFrame: Frame = {
  val thisData =
    if (body.isStatic) None
    else if (body.isInit) Some(FrameItem(DataSource.This, Data.Uninitialized(instance.klass.ref)))
    else Some(FrameItem(DataSource.This, Data.reference(instance)))
  val argData = body.descriptor.args.zipWithIndex.flatMap {
    case (t, i) =>
      val source = DataSource.Argument(i)
      val data = Data.Unknown(t)
      if (t.isDoubleWord)
        Seq(
          FrameItem(source, data),
          FrameItem(source, data.secondWordData)
        )
      else
        Seq(FrameItem(source, data))
  }
  Frame((thisData.toSeq ++ argData).zipWithIndex.map(_.swap).toMap, List.empty)
}

/** Abstract value computed for port `p` of the instruction at `l`. */
def dataValue(l: Bytecode.Label, p: DataPort): Data =
  dataValues(l -> p)

/** Static type of the abstract value at (`l`, `p`). */
def dataType(l: Bytecode.Label, p: DataPort): TypeRef = dataValue(l, p).typeRef

/**
 * Maps each `Bytecode.FallThrough` instruction to the label that immediately follows it in the
 * body. The last instruction, and instructions that are not fall-throughs, have no entry.
 */
lazy val fallThroughs: Map[Bytecode.Label, Bytecode.Label] =
  body.bytecode.sliding(2).collect {
    case Seq((l1, _: Bytecode.FallThrough), (l2, _)) => l1 -> l2
  }.toMap
// Yes, this is a pattern-matching val definition rather than a type ascription; the explicit types
// are written out only for readability.
//
// Forward dataflow analysis over the body. Starting from `initialFrame` at the first instruction,
// a worklist of (label, incoming frame) pairs is processed. Each incoming frame is merged into the
// label's stored entry frame, and the instruction is re-interpreted only when that entry frame
// changes. Results:
//   dataValues    - per (label, port) abstract values, taken from each FrameUpdate's dataValues and initializes
//   maxLocals     - one past the highest local slot seen in any frame (0 when no slot is ever used)
//   maxStackDepth - the deepest operand stack seen in any frame
//   beforeFrames  - entry frame of every label that was reached
//   dataSources   - per (label, port) data sources
// Exception handlers are not followed (see the athrow TODO below).
lazy val (
  dataValues: Map[(Bytecode.Label, DataPort), Data],
  maxLocals: Int,
  maxStackDepth: Int,
  beforeFrames: Map[Bytecode.Label, Frame],
  dataSources: Map[(Bytecode.Label, DataPort), DataSource]
) = {
  // NOTE(review): not referenced anywhere in this block; kept in case constructing it matters.
  val effectMerges = new AbstractLabel.Merger[Effect](Effect.fresh())
  def mergeData(d1: FrameItem, d2: FrameItem): FrameItem =
    d1.merge(d2)
  // Join of two frames: keep only locals present in both, merge stack items pairwise.
  def merge(f1: Frame, f2: Frame): Frame = {
    Frame(
      (f1.locals.keySet intersect f2.locals.keySet)
        .map { k => (k -> mergeData(f1.locals(k), f2.locals(k))) }.toMap,
      f1.stack.zip(f2.stack).map { case (a, b) => mergeData(a, b) }
    )
  }

  val preFrames = mutable.HashMap.empty[Bytecode.Label, Frame]
  val updates = mutable.HashMap.empty[Bytecode.Label, FrameUpdate]

  val tasks = mutable.Set.empty[(Bytecode.Label, Frame)]
  tasks += (body.bytecode.head._1 -> initialFrame)

  while (tasks.nonEmpty) {
    val (pos, frame) = tasks.head
    tasks.remove(pos -> frame)
    val merged = preFrames.get(pos).map(merge(_, frame)) getOrElse frame
    // Fixpoint check: only re-interpret when the entry frame actually changed.
    if (preFrames.get(pos).map(_ != merged) getOrElse true) {
      preFrames(pos) = merged
      val bseq = body.bytecode.dropWhile(_._1 != pos)
      val (label, bc) = bseq.head
      assert(label == pos)
      val u = try {
        bc.nextFrame(label, merged)
      } catch {
        case scala.util.control.NonFatal(e) =>
          // Report the frame that was actually interpreted (`merged`), not only the incoming one.
          throw new RuntimeException(s"Error during dataflow analysis: ${e.getMessage}: ${label.format("L%d")} $bc, frame={\n${merged.pretty}\n}", e)
      }
      updates(label) = u
      bc match {
        case r: Bytecode.Return =>
        case j: Bytecode.Jump =>
          tasks += (body.jumpTargets(label -> j.target) -> u.newFrame)
        case b: Bytecode.Branch =>
          tasks += (body.jumpTargets(label -> b.target) -> u.newFrame)
          tasks += (bseq(1)._1 -> u.newFrame)
        case _: Bytecode.Procedure | _: Bytecode.Shuffle =>
          tasks += (bseq(1)._1 -> u.newFrame)
        case Bytecode.athrow() =>
        // TODO: Exception handler
      }
    }
  }

  val allFrames = preFrames.values ++ updates.values.map(_.newFrame)
  // A method may never use a local slot (static, no arguments, no stores). `.max` would throw on
  // the empty collection, so that case yields 0 locals.
  val localIndices = allFrames.flatMap(_.locals.keys)
  val maxLocals = if (localIndices.isEmpty) 0 else localIndices.max + 1
  val maxStackDepth = allFrames.map(_.stack.size).max
  val dataValues: Map[(Bytecode.Label, DataPort), Data] =
    updates.values.flatMap(_.dataValues).toMap ++ updates.values.flatMap(_.initializes)
  val dataSources = updates.values.flatMap(_.dataSources).toMap

  (dataValues, maxLocals, maxStackDepth, preFrames.toMap, dataSources)
}

/**
 * Human-readable dump of the body. Prints one line per instruction, followed by one `#` line per
 * input/output port showing its data source and possible values. A value is shown as `this` when
 * `isThis` is definitely true, and with a `(this?)` suffix when it is undetermined.
 */
def pretty: String = {
  val format = "L%03d"
  def formatData(l: Bytecode.Label, p: DataPort, d: Data): String = {
    val typeStr = if (d.typeRef == instance.klass.ref.toTypeRef) "this.class" else d.typeRef.toString
    val data = s"$typeStr = ${d.valueString}"
    isThis(l, p).fold {
      s"$data(this?)"
    } { yes =>
      if (yes) s"this"
      else data
    }
  }
  body.bytecode.map {
    case (label, bc) =>
      val base = s"${label.format(format)} ${bc.pretty}"
      val in = bc.inputs.map { in => s" # ${in.name}: ${dataSource(label, in)}, ${possibleValues(label, in).map(formatData(label, in, _)).mkString(", ")}" }
      val out = bc.output.map { out => s" # ${out.name}: ${dataSource(label, out)}, ${possibleValues(label, out).map(formatData(label, out, _)).mkString(", ")}" }.toSeq
      (Seq(base) ++ in ++ out).mkString("\n")
  }.mkString("\n")
}
}

--------------------------------------------------------------------------------
/core/src/main/scala/Transformer.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import scala.language.existentials 4 | 5 | import scala.collection.mutable 6 | import scala.util.{ Try, Success, Failure } 7 | 8 | trait Transformer { self => 9 | def name: String 10 | def params: Map[String, String] = Map.empty 11 | 12 | def apply[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger): Try[Instance.Duplicate[A]] = 13 | try { 14 | el.enterTransformer(this, orig) { el => Success(apply0(orig, el)) } 15 | } catch { 16 | case e: UnveilException => 17 | el.fail(e) 18 | Failure(e) 19 | } 20 | protected[this] def apply0[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger): Instance.Duplicate[A] 21 | 22 | def andThen(next: Transformer): Transformer = 23 | new Transformer { 24 | override def name = s"${self.name} >>> ${next.name}" 25 | override def apply[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger): Try[Instance.Duplicate[A]] = { 26 | el.enterTransformer(this, orig) { el => 27 | self.apply(orig, el).flatMap { i2 => next.apply(i2, el) } 28 | } 29 | } 30 | override def apply0[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger) = throw new AssertionError() 31 | } 32 | 33 | def >>>(next: Transformer): Transformer = 34 | this andThen next 35 | } 36 | object Transformer { 37 | // TODO: remove this 38 | def newEventLogger(): EventLogger = 39 | new EventLogger 40 | 41 | // TODO: support instance-stateful fields(need leakage detection) 42 | // TODO: support mutable fields(if fref eq original then optimized else original) 43 | object fieldFusion extends Transformer { 44 | override def name = s"fieldFusion" 45 | override def apply0[A <: AnyRef](instance: Instance.Concrete[A], el: EventLogger): Instance.Duplicate[A] = { 46 | val dupInstance = instance.duplicate1(el) 47 | fuse( 48 | "", 49 | dupInstance, 50 | dupInstance 51 | .rewritableVirtualMethods 52 | .keySet 53 | .filterNot { mr 
=> dupInstance.resolveVirtualMethod(mr) == ClassRef.Object } 54 | .map { mr => dupInstance.resolveVirtualMethod(mr) -> mr } ++ ( 55 | dupInstance.klass 56 | .declaredMethods 57 | .keySet 58 | .map { mr => dupInstance.thisRef -> mr } 59 | ), 60 | el 61 | ) 62 | } 63 | // TODO: prevent inf loop 64 | private[this] def fuse[A <: AnyRef]( 65 | memberPrefix: String, 66 | self: Instance.Duplicate[A], 67 | methods: Set[(ClassRef, MethodRef)], 68 | el: EventLogger 69 | ): Instance.Duplicate[A] = { 70 | val dfs = 71 | methods 72 | .toSeq 73 | .map { case k @ (cr, mr) => k -> self.dataflow(cr, mr) } 74 | .toMap 75 | val usedFields: Set[(ClassRef, FieldRef)] = 76 | dfs.values 77 | .map(_.usedFieldsOf(self)) 78 | .reduceLeftOption(_ ++ _).getOrElse(Set.empty) 79 | .filter { case (cr, fr) => self.fields(cr -> fr).attribute.isFinal } 80 | el.logCFields("used fields from target methods", usedFields) 81 | 82 | usedFields.foldLeft(self) { 83 | case (self, (fcr, fr)) => 84 | el.enterField(fcr, fr) { el => 85 | self.fields(fcr -> fr).data match { 86 | // TODO[refactor]: instance.isInstanceStateful 87 | case Data.ConcreteReference(fieldInstance) if !fieldInstance.fields.forall(_._2.attribute.isFinal) => 88 | el.log("Pass: This field is instance-stateful") 89 | self 90 | case Data.ConcreteReference(fieldInstance) => 91 | // TODO: log 92 | val usedMethods = fieldInstance.klass.extendMethods( 93 | methods.flatMap { 94 | case (cr, mr) => 95 | self.dataflow(cr, mr).usedMethodsOf(fieldInstance) 96 | } 97 | ) 98 | el.logCMethods("used methods in the field", usedMethods) 99 | 100 | val methodRenaming = 101 | usedMethods.map { 102 | case k @ (cr, mr) => 103 | k -> mr.anotherUniqueName(memberPrefix + fr.name, mr.name) 104 | }.toMap 105 | val fieldRenaming = 106 | fieldInstance.fields.keys.map { 107 | case k @ (cr, fr1) => 108 | k -> fr1.anotherUniqueName(memberPrefix + fr.name, fr1.name) 109 | }.toMap 110 | val newFields = fieldRenaming.map { case ((cr, fr), nfr) => nfr -> 
fieldInstance.fields(cr, fr) } 111 | val newMethods = 112 | usedMethods 113 | .toIterable 114 | .map { case k @ (cr, mr) => k -> fieldInstance.dataflow(cr, mr) } 115 | .toMap 116 | .map { 117 | case ((cr, mr), df) => 118 | methodRenaming(cr -> mr) -> df.body.rewrite { 119 | case (label, bc: Bytecode.InvokeInstanceMethod) if df.mustThis(label, bc.objectref) => 120 | Bytecode.invokespecial(self.thisRef, methodRenaming(bc.classRef, bc.methodRef)) 121 | case (label, bc: Bytecode.InstanceFieldAccess) if df.mustThis(label, bc.objectref) => 122 | bc.rewriteFieldRef(self.thisRef, fieldRenaming(bc.classRef, bc.fieldRef)) 123 | }.makePrivate 124 | } 125 | val rewrittenMethods = 126 | methods 127 | .map { 128 | case (cr, mr) => 129 | val df = self.dataflow(cr, mr) 130 | import Bytecode._ 131 | val leaked = 132 | df.body.bytecode 133 | .exists { 134 | case (label, bc: Bytecode.Shuffle) => false 135 | case (label, bc: Bytecode.Control) => false 136 | case (label, bc @ getfield(_, _)) => false 137 | case (label, bc @ putfield(_, _)) => 138 | df.dataSource(label, bc.value).mayFieldAccess(fcr, fr) 139 | case (label, bc: InvokeMethod) => 140 | bc.args.exists { arg => df.dataSource(label, arg).mayFieldAccess(fcr, fr) } 141 | case _ => false 142 | } 143 | if (leaked) { 144 | el.log(s"[SKIP] the field is leaked in method $mr") 145 | mr -> df.body 146 | } else { 147 | // TODO: use df.mustFieldRef instead of df.mustInstance when rewriting 148 | mr -> df.body.rewrite { 149 | case (label, bc @ getfield(cr1, fr1)) if df.mustThis(label, bc.objectref) && self.resolveField(cr1, fr1) == cr && fr1 == fr => 150 | nop() 151 | case (label, bc: InvokeInstanceMethod) if df.mustInstance(label, bc.objectref, fieldInstance) => 152 | methodRenaming.get(fieldInstance.resolveVirtualMethod(bc.methodRef) -> bc.methodRef).fold { 153 | throw new AssertionError(s"Can't find renamed method for ${bc.classRef}.${bc.methodRef}") 154 | } { mr => 155 | invokespecial(self.thisRef, mr) 156 | } 157 | case (label, bc: 
InstanceFieldAccess) if df.mustInstance(label, bc.objectref, fieldInstance) => 158 | fieldRenaming.get(fieldInstance.resolveField(bc.classRef, bc.fieldRef) -> bc.fieldRef).fold(bc) { 159 | case fr => 160 | bc.rewriteFieldRef(self.thisRef, fr) 161 | } 162 | } 163 | } 164 | }.toMap 165 | val newSelf = self.addFields(newFields).addMethods(newMethods).addMethods(rewrittenMethods) 166 | el.section(s"Fuse new methods from ${fr.name}") { el => 167 | fuse(memberPrefix + fr.name + "__", newSelf, newMethods.keys.map { case mr => (self.thisRef -> mr) }.toSet, el) 168 | } 169 | case _ => 170 | el.log("Pass") 171 | self 172 | } 173 | } 174 | } 175 | } 176 | } 177 | 178 | object methodInlining extends Transformer { 179 | override def name = "methodInlining" 180 | override def apply0[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger): Instance.Duplicate[A] = { 181 | orig 182 | .rewritableVirtualMethods 183 | .keys 184 | .filterNot { mr => orig.resolveVirtualMethod(mr) == ClassRef.Object } 185 | .foldLeft(orig.duplicate1(el)) { 186 | case (self, mr) => 187 | val cr = self.resolveVirtualMethod(mr) 188 | el.log(s"Inlining $mr") 189 | val inlined = 190 | el.enterMethod(cr, mr) { el => inline(self.dataflow(cr, mr), Set(cr -> mr), el) } 191 | self.addMethod(mr, inlined) 192 | } 193 | } 194 | 195 | private[this] def inline(df: DataFlow, ignore: Set[(ClassRef, MethodRef)], el: EventLogger): MethodBody = { 196 | // TODO: if(df.localModified(0)) df.body 197 | var localOffset = df.maxLocals 198 | import Bytecode._ 199 | df.body.rewrite_* { 200 | case (label, bc: InvokeInstanceMethod) if df.mustThis(label, bc.objectref) => 201 | el.section(s"Inline invocation of ${bc.classRef}.${bc.methodRef}") { el => 202 | val mr = bc.methodRef 203 | val cr = 204 | bc match { 205 | case invokespecial(cr, mr) => 206 | // TODO[BUG]: resolve special 207 | cr 208 | case invokevirtual(cr, mr) => 209 | df.instance.resolveVirtualMethod(mr) 210 | case invokeinterface(cr, mr, _) => 211 | 
df.instance.resolveVirtualMethod(mr) 212 | } 213 | val calleeDf = 214 | if (ignore.contains(cr -> mr)) df.instance.dataflow(cr, mr) 215 | else inline(df.instance.dataflow(cr, mr), ignore + (cr -> mr), el).dataflow(df.instance) 216 | 217 | // TODO[BUG]: if(calleeDf.localModified(0)) ... 218 | val argOffset = if (calleeDf.body.isStatic) localOffset else localOffset + 1 219 | // TODO: support exception 220 | val cf = 221 | calleeDf.body.rewrite_* { 222 | case (_, bc: LocalAccess) => 223 | CodeFragment.bytecode(bc.rewriteLocalIndex(bc.localIndex + localOffset)) 224 | case (label, bc: XReturn) => 225 | // TODO: [BUG] goto tail 226 | val resultLocal = localOffset + calleeDf.maxLocals 227 | CodeFragment.bytecode( 228 | Seq(autoStore(bc.returnType, resultLocal)) ++ ( 229 | calleeDf.beforeFrames(label).stack.drop(bc.returnType.wordSize).map { 230 | case FrameItem(src, d) => 231 | autoPop(d.typeRef) 232 | } 233 | ) ++ Seq( 234 | autoLoad(bc.returnType, resultLocal) 235 | ) 236 | : _* 237 | ) 238 | case (label, bc: VoidReturn) => 239 | CodeFragment.bytecode( 240 | calleeDf.beforeFrames(label).stack.map { 241 | case FrameItem(src, d) => 242 | autoPop(d.typeRef) 243 | }: _* 244 | ) 245 | }.codeFragment 246 | .prepend( 247 | CodeFragment.bytecode( 248 | mr.descriptor.args.reverse.zipWithIndex.map { case (t, i) => autoStore(t, i + argOffset) } ++ 249 | (if (calleeDf.body.isStatic) Seq.empty else Seq(astore(localOffset))) 250 | : _* 251 | ) 252 | ) 253 | localOffset += calleeDf.maxLocals + 1 // TODO: inefficient if static method 254 | cf 255 | } 256 | } 257 | } 258 | } 259 | 260 | object localInstanceInlining extends Transformer { 261 | override def name = "localInstanceInlining" 262 | override def apply0[A <: AnyRef](orig: Instance.Concrete[A], el: EventLogger): Instance.Duplicate[A] = { 263 | orig 264 | .rewritableVirtualMethods 265 | .keys 266 | .filterNot { mr => orig.resolveVirtualMethod(mr) == ClassRef.Object } 267 | .foldLeft(orig.duplicate1(el)) { 268 | case (self, mr) => 
269 | val cr = self.resolveVirtualMethod(mr) 270 | el.log(s"Local instance inlining: $mr") 271 | val inlined = 272 | el.enterMethod(cr, mr) { el => inline(self.dataflow(cr, mr), el) } 273 | self.addMethod(mr, inlined) 274 | } 275 | } 276 | 277 | private[this] def inline(df: DataFlow, el: EventLogger): MethodBody = { 278 | import Bytecode._ 279 | 280 | el.log("New instances:") 281 | df.newInstances.foreach { case ((l, p), i) => 282 | el.log(s" - ${l},${p}: ${i}") 283 | } 284 | 285 | // TODO: check all required method is inlinable 286 | val inlinables: Map[(Bytecode.Label, DataPort.Out), (Instance.New[_ <: AnyRef], Map[(ClassRef, FieldRef), Int])] = 287 | // TODO: check used method only in ni.escaped 288 | df.newInstances 289 | .filter { case ((l, p), ni) => !df.escaped(l, p) && !ni.escaped } 290 | .filter { case ((l, p), ni) => 291 | df.useSites(l, p).forall { case (ul, ubc, ups) => 292 | ups.forall { up => df.dataSource(ul, up).unambiguous } 293 | } 294 | }.map { 295 | case (v, d) => 296 | v -> (d -> d.klass.instanceFieldAttributes.keys.zipWithIndex.toMap) 297 | }.toMap 298 | 299 | el.log("Inlinable:") 300 | df.newInstances.foreach { case ((l, p), i) => 301 | el.log(s" - ${l},${p}: ${i}") 302 | } 303 | 304 | // TODO: [BUG] manage local offset for each inlinable instance 305 | def toInlineForm( 306 | base: DataFlow, 307 | fieldMap: Map[(ClassRef, FieldRef), Int], 308 | inlined: Set[(ClassRef, MethodRef)] = Set() 309 | ): CodeFragment = { 310 | require(!base.body.isStatic) 311 | val retValIndex = df.maxLocals 312 | val fieldLocalOffset = retValIndex + 1 313 | val localOffset = fieldLocalOffset + fieldMap.size 314 | 315 | val prepareArgs = CodeFragment.bytecode( 316 | base.body.descriptor.args.zipWithIndex.reverse.map { case (t, i) => 317 | autoStore(t, 1 + i + localOffset) 318 | } :+ astore(localOffset): _* 319 | ) 320 | val inlinableBody = base.body.codeFragment 321 | .rewrite_* { 322 | case (l, bc @ getfield(cr, fr)) if base.mustThis(l, bc.objectref) => 323 | val 
index = fieldLocalOffset + fieldMap(base.instance.resolveField(cr, fr) -> fr) 324 | CodeFragment.bytecode( 325 | autoLoad(fr.typeRef, index) 326 | ) 327 | case (l, bc @ putfield(cr, fr)) if base.mustThis(l, bc.objectref) => 328 | val index = fieldLocalOffset + fieldMap(base.instance.resolveField(cr, fr) -> fr) 329 | CodeFragment.bytecode( 330 | autoStore(fr.typeRef, index) 331 | ) 332 | case (_, bc: LocalAccess) => 333 | CodeFragment.bytecode(bc.rewriteLocalIndex(bc.localIndex + localOffset)) 334 | case (l, bc: InvokeInstanceMethod) if base.mustThis(l, bc.objectref) => 335 | val cr = bc.resolveMethod(base.instance) 336 | if (inlined.contains(cr -> bc.methodRef)) 337 | throw new RuntimeException("recursive method not supported") 338 | toInlineForm( 339 | base.instance.dataflow(cr, bc.methodRef), 340 | fieldMap, 341 | inlined + (cr -> bc.methodRef) 342 | ) 343 | case (l, bc: Return) => 344 | val cleanupStack = 345 | CodeFragment.bytecode( 346 | base.beforeFrames(l).stack 347 | .drop(base.body.descriptor.ret.wordSize).map { 348 | case FrameItem(src, d) => 349 | autoPop(d.typeRef) 350 | } 351 | : _* 352 | ) 353 | val gotoExit = 354 | CodeFragment.abstractJump(goto(), "exit") 355 | 356 | if(base.body.descriptor.isVoid) { 357 | cleanupStack + gotoExit 358 | } else { 359 | val saveRetVal = 360 | CodeFragment.bytecode(autoStore(base.body.descriptor.ret, retValIndex)) 361 | saveRetVal + cleanupStack + gotoExit 362 | } 363 | } 364 | val exit = 365 | (if(base.body.descriptor.isVoid) { 366 | CodeFragment.empty 367 | } else { 368 | CodeFragment.bytecode(autoLoad(base.body.descriptor.ret, retValIndex)) 369 | }).name("exit") 370 | (prepareArgs + inlinableBody + exit).complete() 371 | } 372 | 373 | def inlinable(l: Bytecode.Label, p: DataPort): Option[(Instance.New[_ <: AnyRef], Map[(ClassRef, FieldRef), Int])] = 374 | df.dataSource(l, p).single.collect { case DataSource.New(l, p) => inlinables.get(l -> p) }.flatten 375 | 376 | df.body.rewrite_* { 377 | case (l, bc @ new_(cr)) if 
inlinables.contains(l -> bc.objectref) => 378 | CodeFragment.bytecode(aconst_null()) // dummy value 379 | case (l, bc: InvokeInstanceMethod) if inlinable(l, bc.objectref).nonEmpty => 380 | val Some((newInstance, fieldMap)) = inlinable(l, bc.objectref) 381 | toInlineForm(newInstance.dataflow(bc.resolveMethod(newInstance), bc.methodRef), fieldMap) 382 | case (l, bc @ getfield(cr, fr)) if inlinable(l, bc.objectref).nonEmpty => 383 | val Some((newInstance, fieldMap)) = inlinable(l, bc.objectref) 384 | CodeFragment.bytecode(pop(), autoLoad(fr.typeRef, fieldMap(cr -> fr))) 385 | case (l, bc @ putfield(cr, fr)) if inlinable(l, bc.objectref).nonEmpty => 386 | val Some((newInstance, fieldMap)) = inlinable(l, bc.objectref) 387 | CodeFragment.bytecode(autoStore(fr.typeRef, fieldMap(cr -> fr)), pop()) 388 | } 389 | } 390 | } 391 | 392 | // TODO: eliminate load-pop pair etc 393 | // TODO: check access rights and resolve members 394 | } 395 | -------------------------------------------------------------------------------- /core/src/main/scala/Javassist.scala: -------------------------------------------------------------------------------- 1 | package com.todesking.unveil 2 | 3 | import javassist.{ ClassPool, CtClass, CtBehavior } 4 | import javassist.bytecode.{ CodeAttribute, ConstPool, Bytecode => JABytecode, MethodInfo } 5 | import java.lang.reflect.{ Method => JMethod, Constructor } 6 | 7 | import scala.collection.mutable 8 | 9 | object Javassist { 10 | def ctClass(tr: TypeRef): CtClass = { 11 | tr match { 12 | case TypeRef.Int => CtClass.intType 13 | case TypeRef.Reference(ClassRef.Concrete(name, cl)) => 14 | val pool = buildPool(cl) 15 | pool.get(name) 16 | case unk => throw new NotImplementedError(s"${unk}") 17 | } 18 | } 19 | 20 | def compile(classPool: ClassPool, constPool: ConstPool, df: DataFlow): CodeAttribute = { 21 | val body = df.body 22 | val ctObject = classPool.get("java.lang.Object") 23 | val out = new JABytecode(constPool, 0, 0) 24 | val jumps = 
mutable.HashMap.empty[Int, (Int, JumpTarget)] // jump operand address -> (insn addr -> target) 25 | val label2addr = mutable.HashMap.empty[Bytecode.Label, Int] 26 | val addr2label = mutable.HashMap.empty[Int, Bytecode.Label] 27 | import Bytecode._ 28 | body.bytecode foreach { 29 | case (label, bc) => 30 | label2addr(label) = out.getSize 31 | addr2label(out.getSize) = label 32 | bc match { 33 | case nop() => 34 | out.add(0x00) 35 | case aconst_null() => 36 | out.addConstZero(ctObject) 37 | case vreturn() => 38 | out.addReturn(null) 39 | case ireturn() => 40 | out.addReturn(CtClass.intType) 41 | case lreturn() => 42 | out.addReturn(CtClass.longType) 43 | case areturn() => 44 | out.add(0xB0) 45 | case freturn() => 46 | out.add(0xAE) 47 | case dreturn() => 48 | out.add(0xAF) 49 | case iload(n) => 50 | out.addIload(n) 51 | case aload(n) => 52 | out.addAload(n) 53 | case fload(n) => 54 | out.addFload(n) 55 | case dload(n) => 56 | out.addDload(n) 57 | case lload(n) => 58 | out.addLload(n) 59 | case istore(n) => 60 | out.addIstore(n) 61 | case astore(n) => 62 | out.addAstore(n) 63 | case dstore(n) => 64 | out.addDstore(n) 65 | case iconst(c) => 66 | out.addIconst(c) 67 | case lconst(c) => 68 | out.addLconst(c) 69 | case goto(target) => 70 | out.add(0xA7) 71 | jumps(out.getSize) = (out.getSize - 1) -> target 72 | out.add(0x00, 0x03) 73 | case dup() => 74 | out.add(0x59) 75 | case ldc2_double(value) => 76 | out.addLdc2w(value) 77 | case pop() => 78 | out.add(0x57) 79 | case pop2() => 80 | out.add(0x58) 81 | case iadd() => 82 | out.add(0x60) 83 | case dadd() => 84 | out.add(0x63) 85 | case dsub() => 86 | out.add(0x67) 87 | case imul() => 88 | out.add(0x68) 89 | case isub() => 90 | out.add(0x64) 91 | case dmul() => 92 | out.add(0x6B) 93 | case i2d() => 94 | out.add(0x87) 95 | case d2i() => 96 | out.add(0x8E) 97 | case if_acmpne(target) => 98 | out.add(0xA6) 99 | jumps(out.getSize) = (out.getSize - 1) -> target 100 | out.add(0x00, 0x03) 101 | case invokevirtual(classRef, 
methodRef) => 102 | // TODO: check resolved class 103 | out.addInvokevirtual(classRef.binaryName, methodRef.name, methodRef.descriptor.str) 104 | case invokespecial(classRef, methodRef) => 105 | // TODO: check resolved class 106 | out.addInvokespecial(classRef.binaryName, methodRef.name, methodRef.descriptor.str) 107 | case invokestatic(classRef, methodRef) => 108 | out.addInvokestatic(classRef.binaryName, methodRef.name, methodRef.descriptor.str) 109 | case invokeinterface(classRef, methodRef, count) => 110 | out.addInvokeinterface(classRef.binaryName, methodRef.name, methodRef.descriptor.str, count) 111 | case if_icmpge(target) => 112 | out.add(0xA2) 113 | jumps(out.getSize) = (out.getSize - 1) -> target 114 | out.add(0x00, 0x03) 115 | case if_icmple(target) => 116 | out.add(0xA4) 117 | jumps(out.getSize) = (out.getSize - 1) -> target 118 | out.add(0x00, 0x03) 119 | case ifnonnull(target) => 120 | out.add(0xC7) 121 | jumps(out.getSize) = (out.getSize - 1) -> target 122 | out.add(0x00, 0x03) 123 | case athrow() => 124 | out.add(0xBF) 125 | case getfield(classRef, fieldRef) => 126 | out.addGetfield(classRef.binaryName, fieldRef.name, fieldRef.descriptor.str) 127 | case getstatic(classRef, fieldRef) => 128 | out.addGetstatic(classRef.binaryName, fieldRef.name, fieldRef.descriptor.str) 129 | case putfield(classRef, fieldRef) => 130 | out.addPutfield(classRef.binaryName, fieldRef.name, fieldRef.descriptor.str) 131 | case new_(classRef) => 132 | out.addNew(classRef.binaryName) 133 | } 134 | } 135 | jumps foreach { 136 | case (dataIndex, (index, target)) => 137 | val label = body.jumpTargets(addr2label(index) -> target) 138 | val targetIndex = label2addr(label) 139 | out.write16bit(dataIndex, targetIndex - index) 140 | } 141 | out.setMaxLocals(df.maxLocals) 142 | out.setMaxStack(df.maxStackDepth) 143 | out.toCodeAttribute 144 | } 145 | 146 | def decompile(m: JMethod): Option[MethodBody] = { 147 | require(m != null) 148 | 149 | val jClass = m.getDeclaringClass 150 | val 
classPool = buildPool(jClass) 151 | 152 | val ctClass = classPool.get(jClass.getName) 153 | val mRef = MethodRef.from(m) 154 | 155 | val ctMethod = ctClass.getMethod(mRef.name, mRef.descriptor.str) 156 | 157 | decompile0(jClass, mRef, ctMethod) 158 | } 159 | 160 | def decompile(m: Constructor[_]): Option[MethodBody] = { 161 | val classPool = buildPool(m.getDeclaringClass) 162 | val jClass = m.getDeclaringClass 163 | val ctClass = classPool.get(jClass.getName) 164 | val mRef = MethodRef.from(m) 165 | val ctMethod = ctClass.getConstructor(mRef.descriptor.str) 166 | decompile0(jClass, mRef, ctMethod) 167 | } 168 | 169 | private[this] def buildPool(jClass: Class[_]): ClassPool = { 170 | import javassist.{ ClassClassPath, ByteArrayClassPath } 171 | 172 | val classPool = new ClassPool(null) 173 | Instance.findMaterializedClasses(jClass.getClassLoader).foreach { 174 | case (name, bytes) => 175 | classPool.appendClassPath(new ByteArrayClassPath(name, bytes)) 176 | } 177 | classPool.appendClassPath(new ClassClassPath(jClass)) 178 | classPool 179 | } 180 | 181 | private[this] def buildPool(cl: ClassLoader): ClassPool = { 182 | import javassist.{ LoaderClassPath, ByteArrayClassPath } 183 | 184 | val classPool = new ClassPool(null) 185 | Instance.findMaterializedClasses(cl).foreach { 186 | case (name, bytes) => 187 | classPool.appendClassPath(new ByteArrayClassPath(name, bytes)) 188 | } 189 | classPool.appendClassPath(new LoaderClassPath(if (cl == null) ClassLoader.getSystemClassLoader else cl)) 190 | classPool 191 | } 192 | 193 | private[this] def getField(obj: AnyRef, fr: FieldRef): Any = { 194 | val f = obj.getClass.getDeclaredFields.find(_.getName == fr.name).get 195 | f.setAccessible(true) 196 | f.get(obj) 197 | } 198 | 199 | private[this] object cpools { 200 | private[this] val cl = getClass.getClassLoader 201 | val getItemMethodRef = MethodRef.parse("getItem(I)Ljavassist/bytecode/ConstInfo;", cl) 202 | val constPoolClassRef = ClassRef.of(classOf[ConstPool]) 203 | val 
doubleValue = FieldRef("value", FieldDescriptor(TypeRef.Double))
  }

  /**
   * Reads the constant-pool entry at `index` by reflectively calling Javassist's
   * `ConstPool.getItem` (made accessible with `setAccessible`).
   *
   * Only `DoubleInfo` entries are supported; for those the entry's `value` field is returned.
   *
   * @throws NotImplementedError for any other kind of constant-pool entry
   */
  private[this] def getConstantFromCpool(cpool: ConstPool, index: Int): Any = {
    val getItem = Reflect.allJMethods(cpool.getClass).apply(cpools.constPoolClassRef -> cpools.getItemMethodRef)
    getItem.setAccessible(true)
    val item = getItem.invoke(cpool, index.asInstanceOf[Object])
    item.getClass.getSimpleName match {
      case "DoubleInfo" =>
        getField(item, cpools.doubleValue)
      case unk =>
        throw new NotImplementedError(s"Constant pool item $unk")
    }
  }

  /**
   * Decodes the bytecode of `ctMethod` into the analyzer's [[Bytecode]] representation.
   *
   * @param jClass   class whose class loader resolves classes and descriptors referenced by the code
   * @param mRef     the method being decoded; used for the resulting body and for error reports
   * @param ctMethod Javassist view of the method
   * @return `None` when the method has no Code attribute, otherwise the decoded body whose
   *         branch targets are resolved to instruction labels
   * @throws UnsupportedOpcodeException when an opcode outside the supported subset is encountered
   */
  private[this] def decompile0(jClass: Class[_], mRef: MethodRef, ctMethod: CtBehavior): Option[MethodBody] = {
    val codeAttribute = ctMethod.getMethodInfo2.getCodeAttribute
    if (codeAttribute == null) {
      None
    } else {
      val it = codeAttribute.iterator
      val cpool = ctMethod.getDeclaringClass.getClassFile.getConstPool
      val loader = jClass.getClassLoader
      val bcs = mutable.ArrayBuffer.empty[Bytecode]
      // bytecode offset -> label of the instruction decoded at that offset
      val addr2label = mutable.HashMap.empty[Int, Bytecode.Label]
      // (label of a branch instruction, its jump target) -> bytecode offset of the destination
      val jumps = mutable.HashMap.empty[(Bytecode.Label, JumpTarget), Int]

      // Appends `bc`; its label is its position in `bcs`.
      def onInstruction(index: Int, bc: Bytecode): Unit = {
        val label = Bytecode.Label(bcs.size)
        addr2label(index) = label
        bcs += bc
      }

      // Appends a branch whose signed 16-bit offset (relative to the opcode address) follows the opcode.
      def onBranch(index: Int, jt: JumpTarget, bc: Bytecode): Unit = {
        onInstruction(index, bc)
        jumps(addr2label(index) -> jt) = index + it.s16bitAt(index + 1)
      }

      // Resolves the constant-pool field reference operand of getstatic/getfield/putfield.
      def fieldOperand(index: Int): (ClassRef, FieldRef) = {
        val constIndex = it.u16bitAt(index + 1)
        val className = cpool.getFieldrefClassName(constIndex)
        val classRef = ClassRef.of(loader.loadClass(className))
        val fieldName = cpool.getFieldrefName(constIndex)
        val fieldDescriptor = FieldDescriptor.parse(cpool.getFieldrefType(constIndex), loader)
        (classRef, FieldRef(fieldName, fieldDescriptor))
      }

      // Resolves the method reference operand of invoke* instructions. The entry is either a
      // CONSTANT_Methodref or a CONSTANT_InterfaceMethodref (always the latter for invokeinterface);
      // Javassist has separate accessors for each kind, so dispatch on the entry's tag.
      def methodOperand(index: Int): (ClassRef, MethodRef) = {
        val constIndex = it.u16bitAt(index + 1)
        val (className, methodName, methodType) =
          if (cpool.getTag(constIndex) == ConstPool.CONST_InterfaceMethodref)
            (
              cpool.getInterfaceMethodrefClassName(constIndex),
              cpool.getInterfaceMethodrefName(constIndex),
              cpool.getInterfaceMethodrefType(constIndex)
            )
          else
            (
              cpool.getMethodrefClassName(constIndex),
              cpool.getMethodrefName(constIndex),
              cpool.getMethodrefType(constIndex)
            )
        val classRef = ClassRef.of(className, loader)
        (classRef, MethodRef(methodName, MethodDescriptor.parse(methodType, loader)))
      }

      while (it.hasNext) {
        val index = it.next()
        import Bytecode._
        it.byteAt(index) match {
          case 0x00 => // nop
            onInstruction(index, nop())
          case 0x01 => // aconst_null
            onInstruction(index, aconst_null())

          case 0x03 => // iconst_0
            onInstruction(index, iconst(0))
          case 0x04 => // iconst_1
            onInstruction(index, iconst(1))
          case 0x05 => // iconst_2
            onInstruction(index, iconst(2))
          case 0x06 => // iconst_3
            onInstruction(index, iconst(3))
          case 0x07 => // iconst_4
            onInstruction(index, iconst(4))
          case 0x08 => // iconst_5
            onInstruction(index, iconst(5))
          case 0x09 => // lconst_0
            onInstruction(index, lconst(0))
          case 0x10 => // bipush
            onInstruction(index, iconst(it.signedByteAt(index + 1)))
          case 0x11 => // sipush
            onInstruction(index, iconst(it.s16bitAt(index + 1)))

          case 0x14 => // ldc2_w
            // constant-pool indices are unsigned 16-bit values
            getConstantFromCpool(cpool, it.u16bitAt(index + 1)) match {
              case d: Double =>
                onInstruction(index, ldc2_double(d))
              case other =>
                throw new NotImplementedError(s"ldc2_w constant $other")
            }

          case 0x16 => // lload
            onInstruction(index, lload(it.byteAt(index + 1)))
          case 0x17 => // fload
            onInstruction(index, fload(it.byteAt(index + 1)))
          case 0x18 => // dload
            onInstruction(index, dload(it.byteAt(index + 1)))
          case 0x19 => // aload
            onInstruction(index, aload(it.byteAt(index + 1)))
          case 0x1A => // iload_0
            onInstruction(index, iload(0))
          case 0x1B => // iload_1
            onInstruction(index, iload(1))
          case 0x1C => // iload_2
            onInstruction(index, iload(2))
          case 0x1D => // iload_3
            onInstruction(index, iload(3))
          case 0x1E => // lload_0
            onInstruction(index, lload(0))
          case 0x1F => // lload_1
            onInstruction(index, lload(1))
          case 0x20 => // lload_2
            onInstruction(index, lload(2))
          case 0x21 => // lload_3
            onInstruction(index, lload(3))
          case 0x22 => // fload_0
            onInstruction(index, fload(0))
          case 0x23 => // fload_1
            onInstruction(index, fload(1))
          case 0x24 => // fload_2
            onInstruction(index, fload(2))
          case 0x25 => // fload_3
            onInstruction(index, fload(3))
          case 0x26 => // dload_0
            onInstruction(index, dload(0))
          case 0x27 => // dload_1
            onInstruction(index, dload(1))
          case 0x28 => // dload_2
            onInstruction(index, dload(2))
          case 0x29 => // dload_3
            onInstruction(index, dload(3))
          case 0x2A => // aload_0
            onInstruction(index, aload(0))
          case 0x2B => // aload_1
            onInstruction(index, aload(1))
          case 0x2C => // aload_2
            onInstruction(index, aload(2))
          case 0x2D => // aload_3
            onInstruction(index, aload(3))

          case 0x39 => // dstore
            onInstruction(index, dstore(it.byteAt(index + 1)))

          case 0x3C => // istore_1
            onInstruction(index, istore(1))

          case 0x59 => // dup
            onInstruction(index, dup())

          case 0x60 => // iadd
            onInstruction(index, iadd())

          case 0x63 => // dadd
            onInstruction(index, dadd())
          case 0x64 => // isub
            onInstruction(index, isub())

          case 0x67 => // dsub
            onInstruction(index, dsub())
          case 0x68 => // imul
            onInstruction(index, imul())

          case 0x6B => // dmul
            onInstruction(index, dmul())

          case 0x87 => // i2d
            onInstruction(index, i2d())

          case 0x8E => // d2i
            onInstruction(index, d2i())

          case 0xA2 => // if_icmpge
            val jt = JumpTarget("branch")
            onBranch(index, jt, if_icmpge(jt))

          case 0xA4 => // if_icmple
            val jt = JumpTarget("branch")
            onBranch(index, jt, if_icmple(jt))

          case 0xA6 => // if_acmpne
            val jt = JumpTarget("branch")
            onBranch(index, jt, if_acmpne(jt))
          case 0xA7 => // goto
            val jt = JumpTarget("branch")
            onBranch(index, jt, goto(jt))

          case 0xAC => // ireturn
            onInstruction(index, ireturn())
          case 0xAD => // lreturn
            onInstruction(index, lreturn())
          case 0xAE => // freturn
            onInstruction(index, freturn())
          case 0xAF => // dreturn
            onInstruction(index, dreturn())

          case 0xB0 => // areturn
            onInstruction(index, areturn())
          case 0xB1 => // return
            onInstruction(index, vreturn())
          case 0xB2 => // getstatic
            val (classRef, fieldRef) = fieldOperand(index)
            onInstruction(index, getstatic(classRef, fieldRef))

          case 0xB4 => // getfield
            val (classRef, fieldRef) = fieldOperand(index)
            onInstruction(index, getfield(classRef, fieldRef))
          case 0xB5 => // putfield
            val (classRef, fieldRef) = fieldOperand(index)
            onInstruction(index, putfield(classRef, fieldRef))
          case 0xB6 => // invokevirtual
            val (classRef, methodRef) = methodOperand(index)
            onInstruction(index, invokevirtual(classRef, methodRef))
          case 0xB7 => // invokespecial
            val (classRef, methodRef) = methodOperand(index)
            onInstruction(index, invokespecial(classRef, methodRef))
          case 0xB8 => // invokestatic
            val (classRef, methodRef) = methodOperand(index)
            onInstruction(index, invokestatic(classRef, methodRef))
          case 0xB9 => // invokeinterface
            // operand layout: indexbyte1, indexbyte2, count, 0
            val count = it.byteAt(index + 3)
            val (classRef, methodRef) = methodOperand(index)
            onInstruction(index, invokeinterface(classRef, methodRef, count))

          case 0xBB => // new
            val className = cpool.getClassInfo(it.u16bitAt(index + 1))
            onInstruction(index, new_(ClassRef.of(className, loader)))

          case 0xBF => // athrow
            onInstruction(index, athrow())

          case 0xC7 => // ifnonnull
            val jt = JumpTarget("branch")
            onBranch(index, jt, ifnonnull(jt))

          case unk =>
            throw new UnsupportedOpcodeException(ClassRef.of(jClass), mRef, unk)
        }
      }
      val jumpTargets: Map[(Bytecode.Label, JumpTarget), Bytecode.Label] =
        jumps.map { case ((l, jt), index) => (l -> jt) -> addr2label(index) }.toMap
      Some(MethodBody(
        mRef.isInit,
        mRef.descriptor,
        MethodAttribute.from(ctMethod.getModifiers),
        new CodeFragment.Complete(bcs.toSeq, jumpTargets)
      ))
    }
  }

  // TODO: Make javassist getItem to public
  /** Debug helper: prints every constant-pool entry of `cfile` to standard output. */
  def printConstPool(cfile: javassist.bytecode.ClassFile): Unit = {
    val cop = cfile.getConstPool
    val gi = cop.getClass.getDeclaredMethods.find(_.getName == "getItem").get
    gi.setAccessible(true)
    val pw = new java.io.PrintWriter(System.out)
    (1 until cop.getSize) foreach { i =>
      val a = gi.invoke(cop, i.asInstanceOf[java.lang.Integer])
      val x = a.getClass.getMethods.find(_.getName == "print").get
      x.setAccessible(true)
      println(s"${i} -> ${a.getClass}")
      print(" ")
      x.invoke(a, pw)
      pw.flush()
    }
  }
}

--------------------------------------------------------------------------------
/core/src/main/scala/Bytecode.scala:
--------------------------------------------------------------------------------
package com.todesking.unveil

import scala.language.existentials
import scala.language.higherKinds

import scala.reflect.{ classTag, ClassTag }
import scala.collection.mutable

import java.lang.reflect.{ Method => JMethod }

import com.todesking.scalapp.syntax._

/**
 * A single JVM instruction in the analyzer's intermediate representation.
 *
 * Each instruction declares the data ports it consumes (`inputs`) and produces (`output`),
 * and describes its effect on an abstract [[Frame]] via `nextFrame`.
 */
sealed abstract class Bytecode {
  /** The concrete instruction type, so rewrite helpers can return the precise subtype. */
  type Self <: Bytecode
  protected final def self: Self = this.asInstanceOf[Self] // :(

  def inputs: Seq[DataPort.In]
  def output: Option[DataPort.Out]
  def nextFrame(label: Bytecode.Label, frame: Frame): FrameUpdate
  def pretty: String = toString

  protected def update(label: Bytecode.Label, frame: Frame): FrameUpdate =
    new FrameUpdate(label, this, frame)
}
object Bytecode {
  /** Identifies an instruction within a method body by its index. */
  case class Label(index: Int) {
    def format(f: String): String =
      f.format(index)

    def offset(n: Int): Label =
      Label(index + n)

    override def toString = s"L$index"
  }

  /** Returns the local-variable load instruction matching `t`. */
  def autoLoad(t: TypeRef, n: Int): Bytecode =
    t match {
      case TypeRef.Int => iload(n)
      case TypeRef.Double => dload(n)
      case TypeRef.Float => fload(n)
      case TypeRef.Long => lload(n)
      case TypeRef.Reference(_) => aload(n)
      case unk =>
        throw new IllegalArgumentException(s"Unsupported load instruction for ${unk}")
    }

  /** Returns the local-variable store instruction matching `t`. */
  def autoStore(t: TypeRef, n: Int): Bytecode =
    t match {
      case TypeRef.Int => istore(n)
      case TypeRef.Double => dstore(n)
      case TypeRef.Reference(_) => astore(n)
      case unk =>
        throw new IllegalArgumentException(s"Unsupported store instruction for ${unk}")
    }

  /** Returns the pop instruction that discards one value of type `t`. */
  def autoPop(t: TypeRef): Bytecode =
    if (t.isDoubleWord) pop2()
    else pop()

  sealed trait FallThrough extends Bytecode

  sealed abstract class Control extends Bytecode
  sealed abstract class Procedure extends Bytecode with FallThrough
  sealed abstract class Shuffle extends Bytecode with FallThrough {
    override type Self <: Shuffle
    override final def inputs = Seq.empty
    override final def output: Option[DataPort.Out] = None
  }

  sealed trait HasClassRef extends Bytecode {
    def classRef: ClassRef
    def withNewClassRef(newRef: ClassRef): Self
    final def rewriteClassRef(newRef: ClassRef): Self =
      if (classRef == newRef) self
      else withNewClassRef(newRef)
  }

  sealed trait HasMethodRef extends HasClassRef {
    override type Self <: HasMethodRef
    def methodRef: MethodRef
    def withNewMehtodRef(newRef: MethodRef): Self
    final def rewriteMethodRef(newRef: MethodRef): Self =
      if (methodRef == newRef) self
      else withNewMehtodRef(newRef)
    final def rewriteMethodRef(cr: ClassRef, mr: MethodRef): Self =
      rewriteClassRef(cr).rewriteMethodRef(mr).asInstanceOf[Self]
  }

  sealed trait HasFieldRef extends HasClassRef {
    override type Self <: HasFieldRef
    def fieldRef: FieldRef
    def withNewFieldRef(newRef: FieldRef): Self
    final def rewriteFieldRef(newRef: FieldRef): Self =
      if (fieldRef == newRef) self
      else withNewFieldRef(newRef)
    final def rewriteFieldRef(cr: ClassRef, fr: FieldRef): Self =
      rewriteClassRef(cr).rewriteFieldRef(fr).asInstanceOf[Self]
  }

  sealed trait HasJumpTargets extends Control {
    def jumpTargets: Set[JumpTarget]
  }

  sealed trait HasAJumpTarget extends HasJumpTargets {
    override def jumpTargets = Set(jumpTarget)
    def target: JumpTarget = jumpTarget // TODO: remove this
    def jumpTarget: JumpTarget
  }

  sealed abstract class Jump extends Control with HasAJumpTarget {
    override final def inputs = Seq.empty
    override final def output = None
    override final def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f)
  }
  sealed abstract class Branch extends Control with HasAJumpTarget with FallThrough
  sealed abstract class Exit extends Control
  sealed abstract class Return extends Exit
  sealed abstract class Throw extends Exit

  sealed abstract class XReturn extends Return {
    val retval: DataPort.In = DataPort.In("retval")
    override final val inputs = Seq(retval)
    override final def output = None
    override final def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).ret(retval)
    def returnType: TypeRef.Public
  }
  // Void return
  sealed abstract class VoidReturn extends Return {
    override def inputs = Seq.empty
    override def output = None
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f)
  }

  /** Two-operand conditional branch: pops value2 (top of stack) and then value1. */
  sealed abstract class if_X1cmpXX extends Branch {
    val value1: DataPort.In = DataPort.In("value1")
    val value2: DataPort.In = DataPort.In("value2")
    override def inputs = Seq(value1, value2)
    override def output = None
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).pop1(value2).pop1(value1)
  }

  sealed abstract class LocalAccess extends Shuffle {
    override type Self <: LocalAccess
    def localIndex: Int
    def rewriteLocalIndex(n: Int): Self
  }

  sealed abstract class Load1 extends LocalAccess {
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).load1(localIndex)
  }
  sealed abstract class Load2 extends LocalAccess {
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).load2(localIndex)
  }

  sealed abstract class Store1 extends LocalAccess {
    def storeType: TypeRef.SingleWord
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).store1(storeType, localIndex)
  }

  sealed abstract class Store2 extends LocalAccess {
    def storeType: TypeRef.DoubleWord
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).store2(storeType, localIndex)
  }

  sealed abstract class ConstX extends Procedure {
    def out: DataPort.Out
    def data: Data.Concrete
    override def inputs = Seq.empty
    override def output = Some(out)
  }

  sealed abstract class Const1 extends ConstX {
    final val out: DataPort.Out = DataPort.Out("const(1word)")
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      update(l, f).push1(output, FrameItem(DataSource.Constant(l, out, data), data))
  }

  sealed abstract class Const2 extends ConstX {
    final val out: DataPort.Out = DataPort.Out("const(2word)")
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      update(l, f).push2(output, FrameItem(DataSource.Constant(l, out, data), data))
  }

  sealed abstract class InvokeMethod extends Procedure with HasClassRef with HasMethodRef {
    override type Self <: InvokeMethod
    val args: Seq[DataPort.In] = methodRef.args.zipWithIndex.map { case (_, i) => DataPort.In(s"arg${i}") }
    val ret: Option[DataPort.Out] = if (methodRef.isVoid) None else Some(DataPort.Out("ret"))
    override final def output = ret

    /**
     * Number of operand-stack slots taken by the arguments. The frame stack is indexed by
     * words, so double-word arguments (long/double) occupy two slots each.
     */
    protected final def argWords: Int =
      methodRef.args.map { t => if (t.isDoubleWord) 2 else 1 }.sum

    /** Pops all arguments (the last, topmost one first), binding each to its `argN` port. */
    protected final def popArgs(u: FrameUpdate): FrameUpdate =
      args.zip(methodRef.args).foldRight(u) {
        case ((a, t), acc) =>
          if (t.isDoubleWord) acc.pop2(a)
          else acc.pop1(a)
      }
  }
  sealed abstract class InvokeClassMethod extends InvokeMethod {
    override type Self <: InvokeClassMethod
    override final def inputs = args
    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      require(f.stack.size >= argWords)
      val popped = popArgs(update(l, f))
      ret.fold(popped) { rlabel =>
        popped.push(Some(rlabel), FrameItem(DataSource.MethodInvocation(l, rlabel), Data.Unknown(methodRef.ret)))
      }
    }
  }
  sealed abstract class InvokeInstanceMethod extends InvokeMethod {
    override type Self <: InvokeInstanceMethod
    val objectref: DataPort.In = DataPort.In("objectref")
    override final def inputs = objectref +: args
    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      // arguments plus the receiver beneath them
      require(f.stack.size >= argWords + 1)
      val popped = popArgs(update(l, f)).pop1(objectref)
      ret.fold(popped) { rlabel =>
        popped.push(Some(rlabel), FrameItem(DataSource.MethodInvocation(l, rlabel), Data.Unknown(methodRef.ret)))
      }
    }
    def resolveMethod(instance: Instance[_ <: AnyRef]): ClassRef
  }

  sealed abstract class FieldAccess extends Procedure with HasClassRef with HasFieldRef {
    override type Self <: FieldAccess
  }
  sealed abstract class StaticFieldAccess extends FieldAccess {
    override type Self <: StaticFieldAccess
  }
  sealed abstract class InstanceFieldAccess extends FieldAccess {
    override type Self <: InstanceFieldAccess
    val objectref: DataPort.In = DataPort.In("objectref")
  }
  sealed trait FieldSetter extends FieldAccess {
    final val value = DataPort.In("value")
    final override def output = None
  }
  sealed trait FieldGetter extends FieldAccess {
    final val out = DataPort.Out("out")
    final override def output = Some(out)
  }

  case class nop() extends Shuffle {
    override type Self = nop
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f)
  }
  case class dup() extends Shuffle {
    override type Self = dup
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).push(None, f.stack.head)
  }
  case class pop() extends Shuffle {
    override type Self = pop
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).pop1()
  }
  case class pop2() extends Shuffle {
    override type Self = pop2
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).pop2()
  }
  case class vreturn() extends VoidReturn {
    override type Self = vreturn
  }
  case class iload(override val localIndex: Int) extends Load1 {
    override type Self = iload
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else iload(m)
  }
  case class aload(override val localIndex: Int) extends Load1 {
    override type Self = aload
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else aload(m)
  }
  case class fload(override val localIndex: Int) extends Load1 {
    override type Self = fload
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else fload(m)
  }
  case class dload(override val localIndex: Int) extends Load2 {
    override type Self = dload
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else dload(m)
  }
  case class lload(override val localIndex: Int) extends Load2 {
    override type Self = lload
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else lload(m)
  }
  case class istore(override val localIndex: Int) extends Store1 {
    override type Self = istore
    override def storeType = TypeRef.Int
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else istore(m)
  }
  case class astore(override val localIndex: Int) extends Store1 {
    override type Self = astore
    override def storeType = TypeRef.Object
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else astore(m)
  }
  case class dstore(override val localIndex: Int) extends Store2 {
    override type Self = dstore
    override def storeType = TypeRef.Double
    override def rewriteLocalIndex(m: Int) = if (localIndex == m) self else dstore(m)
  }

  case class ireturn() extends XReturn {
    override type Self = ireturn
    override def returnType = TypeRef.Int
  }
  case class freturn() extends XReturn {
    override type Self = freturn
    override def returnType = TypeRef.Float
  }
  case class dreturn() extends XReturn {
    override type Self = dreturn
    override def returnType = TypeRef.Double
  }
  case class lreturn() extends XReturn {
    override type Self = lreturn
    override def returnType = TypeRef.Long
  }
  case class areturn() extends XReturn {
    override type Self = areturn
    override def returnType = TypeRef.Reference(ClassRef.Object)
  }
  case class iconst(value: Int) extends Const1 {
    override type Self = iconst
    override def data = Data.ConcretePrimitive(TypeRef.Int, value)
  }
  case class lconst(value: Long) extends Const2 {
    override type Self = lconst
    override def data = Data.ConcretePrimitive(TypeRef.Long, value)
  }
  case class aconst_null() extends Const1 {
    override type Self = aconst_null
    override def data = Data.Null
  }
  case class ldc2_double(value: Double) extends Const2 {
    override type Self = ldc2_double
    override def data = Data.ConcretePrimitive(TypeRef.Double, value)
  }
  case class goto(override val jumpTarget: JumpTarget = JumpTarget("jump")) extends Jump {
    override type Self = goto
  }
  case class if_icmple(override val jumpTarget: JumpTarget) extends if_X1cmpXX {
    override type Self = if_icmple
  }
  case class if_icmpge(override val jumpTarget: JumpTarget) extends if_X1cmpXX {
    override type Self = if_icmpge
  }
  case class if_acmpne(override val jumpTarget: JumpTarget) extends if_X1cmpXX {
    override type Self = if_acmpne
  }
  case class ifnonnull(override val jumpTarget: JumpTarget) extends Branch {
    override type Self = ifnonnull
    val value: DataPort.In = DataPort.In("value")
    override def pretty = "ifnonnull"
    override def inputs = Seq(value)
    override def output = None
    override def nextFrame(l: Bytecode.Label, f: Frame) = update(l, f).pop1(value)
  }

  /**
   * Binary arithmetic on two operands of `operandType`, computing `op(value1, value2)` as the JVM
   * does. value2 is on top of the stack and value1 lies `operandType.wordSize` slots beneath it.
   * When both operands are concrete, the result is constant-folded.
   */
  sealed abstract class PrimitiveBinOp[A <: AnyVal] extends Procedure {
    val value1 = DataPort.In("value1")
    val value2 = DataPort.In("value2")
    val result = DataPort.Out("result")

    def operandType: TypeRef.Primitive
    def op(value1: A, value2: A): A

    override def inputs = Seq(value1, value2)
    override def output = Some(result)
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      (f.stack(0), f.stack(operandType.wordSize)) match {
        case (top, below) if top.data.typeRef == operandType && below.data.typeRef == operandType =>
          // `top` is value2 and `below` is value1: fold as op(value1, value2) so that
          // non-commutative ops (isub, dsub) get the JVM operand order.
          update(l, f)
            .pop(operandType, value2)
            .pop(operandType, value1)
            .push(
              Some(result),
              FrameItem(
                DataSource.Generic(l, result),
                below.data.value.flatMap { v1 =>
                  top.data.value.map { v2 =>
                    Data.ConcretePrimitive(
                      operandType,
                      op(v1.asInstanceOf[A], v2.asInstanceOf[A])
                    )
                  }
                }.getOrElse { Data.Unknown(operandType) }
              )
            )
        case (d1, d2) => throw new AnalyzeException(s"$this: Type error: ${(d1, d2)}")
      }
  }

  /** Unary conversion from `operandType` to `resultType`; constant-folded when the operand is concrete. */
  sealed abstract class PrimitiveUniOp[A <: AnyVal, B <: AnyVal] extends Procedure {
    val value = DataPort.In("value")
    val result = DataPort.Out("result")

    override def inputs = Seq(value)
    override def output = Some(result)

    def operandType: TypeRef.Primitive
    def resultType: TypeRef.Primitive
    def op(value: A): B

    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      update(l, f)
        .pop(operandType)
        .push(
          Some(result),
          FrameItem(
            DataSource.Generic(l, result),
            f.stackTop.data.value.map { v =>
              Data.ConcretePrimitive(resultType, op(v.asInstanceOf[A]))
            }.getOrElse { Data.Unknown(resultType) }
          )
        )
    }
  }
  case class iadd() extends PrimitiveBinOp[Int] {
    override type Self = iadd
    override def operandType = TypeRef.Int
    override def op(value1: Int, value2: Int) = value1 + value2
  }
  case class dadd() extends PrimitiveBinOp[Double] {
    override type Self = dadd
    override def operandType = TypeRef.Double
    override def op(value1: Double, value2: Double) = value1 + value2
  }
  case class isub() extends PrimitiveBinOp[Int] {
    override type Self = isub
    override def operandType = TypeRef.Int
    override def op(value1: Int, value2: Int) = value1 - value2
  }
  case class dsub() extends PrimitiveBinOp[Double] {
    override type Self = dsub
    override def operandType = TypeRef.Double
    override def op(value1: Double, value2: Double) = value1 - value2
  }
  case class dmul() extends PrimitiveBinOp[Double] {
    override type Self = dmul
    override def operandType = TypeRef.Double
    override def op(value1: Double, value2: Double) = value1 * value2
  }
  case class imul() extends PrimitiveBinOp[Int] {
    override type Self = imul
    override def operandType = TypeRef.Int
    override def op(value1: Int, value2: Int) = value1 * value2
  }
  case class d2i() extends PrimitiveUniOp[Double, Int] {
    override type Self = d2i
    override def operandType = TypeRef.Double
    override def resultType = TypeRef.Int
    override def op(value: Double) = value.toInt
  }
  case class i2d() extends PrimitiveUniOp[Int, Double] {
    override type Self = i2d
    override def operandType = TypeRef.Int
    override def resultType = TypeRef.Double
    override def op(value: Int) = value.toDouble
  }
  case class invokevirtual(override val classRef: ClassRef, override val methodRef: MethodRef) extends InvokeInstanceMethod {
    override type Self = invokevirtual
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewMehtodRef(newRef: MethodRef) = copy(methodRef = newRef)
    override def pretty = s"invokevirtual ${classRef.pretty}.${methodRef.str}"
    override def resolveMethod(instance: Instance[_ <: AnyRef]): ClassRef =
      instance.resolveVirtualMethod(methodRef)
  }
  case class invokeinterface(override val classRef: ClassRef, override val methodRef: MethodRef, count: Int) extends InvokeInstanceMethod {
    override type Self = invokeinterface
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewMehtodRef(newRef: MethodRef) = copy(methodRef = newRef)
    override def pretty = s"invokeinterface ${classRef.pretty}.${methodRef.str}"
    override def resolveMethod(instance: Instance[_ <: AnyRef]): ClassRef =
      instance.resolveVirtualMethod(methodRef)
  }
  case class invokespecial(override val classRef: ClassRef, override val methodRef: MethodRef) extends InvokeInstanceMethod {
    override type Self = invokespecial
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewMehtodRef(newRef: MethodRef) = copy(methodRef = newRef)
    override def pretty = s"invokespecial ${classRef.pretty}.${methodRef.str}"
    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      require(f.stack.size >= argWords + 1)
      // The receiver sits directly beneath the arguments; the stack is word-indexed, so its
      // position is the arguments' word count, not their count.
      f.stack(argWords) match {
        case FrameItem(DataSource.New(sl, sp), Data.Uninitialized(t)) =>
          // should ctor
          if (!methodRef.isInit) {
            throw new IllegalArgumentException(s"expect ctor call but ${methodRef}")
          }
          if (t != classRef) {
            throw new IllegalArgumentException(s"Illegal ctor invocation: instance is $t and ctor is $classRef.$methodRef")
          }
          if (!classRef.isInstanceOf[ClassRef.Concrete]) {
            throw new IllegalArgumentException(s"Unexpected abstract ClassRef: ${classRef}")
          }
          assert(methodRef.ret == TypeRef.Void)

          val initialized =
            Data.AbstractReference(
              new Instance.New[AnyRef](
                classRef.asInstanceOf[ClassRef.Concrete].loadKlass, // TODO: make safer
                methodRef.descriptor
              )
            )

          popArgs(update(l, f))
            .pop1(objectref)
            .initializeInstance(sl, sp, initialized)
        case FrameItem(src, Data.Uninitialized(t)) if src.must(DataSource.This) =>
          // super ctor call in ctor
          super.nextFrame(l, f)
        case FrameItem(src, Data.Uninitialized(t)) =>
          // unexpected
          throw new IllegalArgumentException(s"Unexpected ctor call: data src = $src")
        case fi => // normal
          if (methodRef.isInit) {
            throw new IllegalArgumentException(s"Unexpected ctor call: data src = ${fi.source}")
          }
          super.nextFrame(l, f)
      }
    }
    override def resolveMethod(instance: Instance[_ <: AnyRef]): ClassRef =
      classRef // TODO: [BUG]
  }
  case class invokestatic(override val classRef: ClassRef, override val methodRef: MethodRef) extends InvokeClassMethod {
    override type Self = invokestatic
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewMehtodRef(newRef: MethodRef) = copy(methodRef = newRef)
    override def pretty = s"invokestatic ${classRef.pretty}.${methodRef.str}"
  }
  case class getfield(override val classRef: ClassRef, override val fieldRef: FieldRef) extends InstanceFieldAccess with FieldGetter {
    override type Self = getfield
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewFieldRef(newRef: FieldRef) = copy(fieldRef = newRef)

    override def inputs = Seq(objectref)

    override def pretty = s"getfield ${classRef}.${fieldRef}"
    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      val receiver = f.stack(0).data
      // A final field of a known instance yields its recorded value; anything else is unknown.
      val data =
        receiver match {
          case d: Data.Reference =>
            val key = d.instance.resolveField(classRef, fieldRef) -> fieldRef
            val field = d.instance.klass.instanceFieldAttributes(key)
            if (field.isFinal) d.instance.fieldValues(key)
            else Data.Unknown(fieldRef.descriptor.typeRef)
          case _ =>
            Data.Unknown(fieldRef.descriptor.typeRef)
        }
      update(l, f).pop1(objectref).push(output, FrameItem(DataSource.InstanceField(l, out, receiver, classRef, fieldRef), data))
    }
  }
  case class getstatic(override val classRef: ClassRef, override val fieldRef: FieldRef) extends StaticFieldAccess with FieldGetter {
    override type Self = getstatic
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewFieldRef(newRef: FieldRef) = copy(fieldRef = newRef)

    override def inputs = Seq()

    override def pretty = s"getstatic ${fieldRef}"
    override def nextFrame(l: Bytecode.Label, f: Frame) = {
      val data = Data.Unknown(fieldRef.descriptor.typeRef) // TODO: set static field value if it is final
      update(l, f).push(output, FrameItem(DataSource.StaticField(l, out, classRef, fieldRef), data))
    }
  }
  case class putfield(override val classRef: ClassRef, override val fieldRef: FieldRef) extends InstanceFieldAccess with FieldSetter {
    override type Self = putfield
    override def withNewClassRef(newRef: ClassRef) = copy(classRef = newRef)
    override def withNewFieldRef(newRef: FieldRef) = copy(fieldRef = newRef)
    // Both popped operands are inputs: the receiver and the value stored into the field.
    override def inputs = Seq(objectref, value)
    override def pretty = s"putfield ${classRef}.${fieldRef}"
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      update(l, f).pop(fieldRef.descriptor.typeRef, value).pop1(objectref)
  }
  case class athrow() extends Throw {
    override type Self = athrow
    val objectref = DataPort.In("objectref")
    override def pretty = s"athrow"
    override def inputs = Seq(objectref)
    override def output = None
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      update(l, f).athrow(objectref)
  }
  case class new_(override val classRef: ClassRef) extends Procedure with HasClassRef {
    val objectref = DataPort.Out("new")
    override type Self = new_
    override def withNewClassRef(cr: ClassRef) = copy(classRef = cr)
    override def inputs = Seq()
    override def output = Some(objectref)
    override def nextFrame(l: Bytecode.Label, f: Frame) =
      update(l, f).push(output, FrameItem(DataSource.New(l, objectref), Data.Uninitialized(classRef)))
  }
}

--------------------------------------------------------------------------------