├── .gitignore ├── .gitmodules ├── README.md ├── build.sbt ├── project └── build.properties └── src ├── main └── scala │ ├── Benchmark.scala │ ├── CCodeGen.scala │ ├── Utils.scala │ ├── circuit │ ├── QASMTranslator.scala │ └── Syntax.scala │ ├── feynman │ ├── EvalState.scala │ ├── QCompilerCPS.scala │ ├── QContSim1.scala │ └── QContSim2.scala │ └── schrodinger │ ├── Complex.scala │ ├── DecompMatrix.scala │ ├── Gate.scala │ ├── Matrix.scala │ ├── SchrodingerInterpreter.scala │ ├── Shonan.scala │ ├── StagedComplex.scala │ ├── StagedHLSchrodinger.scala │ ├── StagedSchrodinger.scala │ └── UnstagedSchrodinger.scala └── test └── scala ├── MatrixTest.scala └── Schrodinger.scala /.gitignore: -------------------------------------------------------------------------------- 1 | project/target/ 2 | target/ 3 | *.log 4 | lms-clean 5 | sbt.json -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lms-clean"] 2 | path = lms-clean 3 | url = git@github.com:TiarkRompf/lms-clean.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quantum Circuits Compiler with Staging 2 | 3 | This repository contains various experimental quantum circuit evaluators and compilers. 4 | They take a Toffoli-Hadamard quantum circuit as input and simulate the probability amplitudes of all possible outcomes. 5 | 6 | - The directory `src/main/scala/feynman` contains various Feynman-style simulators implemented using continuations. 7 | - `QContSim1.scala` contains a pure implementation and uses delimited continuations `shift`/`reset` (following the [Quantum Continuation](https://andykeep.com/SchemeWorkshop2022/scheme2022-final37.pdf) paper).
8 | - `QContSim2.scala` implements an evaluator written in CPS and uses side effects to perform path summarization. 9 | - `QCompilerCPS.scala` is a staged CPS evaluator (written with [Lightweight Modular 10 | Staging](https://github.com/TiarkRompf/lms-clean)) that can generate C++ code 11 | for simulation. 12 | 13 | - The directory `src/main/scala/schrodinger` contains various Schrodinger-style simulators implemented with linear-algebra computations. 14 | - `UnstagedSchrodinger.scala` is an unstaged implementation. 15 | - `StagedSchrodinger.scala` is a staged implementation that specializes over static gate matrices. 16 | 17 | ## Example 18 | 19 | To see an example of the compiler in action, run the following command 20 | in `sbt`: 21 | 22 | ``` 23 | sbt:quantum-lms-compiler> runMain quantum.feynman.staged.TestQC 24 | ``` 25 | 26 | This will take the circuit for [the Simon problem](https://en.wikipedia.org/wiki/Simon%27s_problem) 27 | as input and execute the generated C++ program `snippet.cpp`. 28 | The C++ program is compiled with `g++ -std=c++20 -O3`.
Running the generated program prints the elapsed time, followed by every state with a
non-negligible probability amplitude and the total number of results. For the Simon circuit the states are:

```
0.5|0000⟩
0.5|0011⟩
0.5|1100⟩
-0.5|1111⟩
```

--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
name := "quantum-lms-compiler"
organization := "edu.purdue"
scalaVersion := "2.12.10"
version := "0.1.0-SNAPSHOT"
autoCompilerPlugins := true

val paradiseVersion = "2.1.0"

resolvers += Resolver.sonatypeRepo("releases")
resolvers += Resolver.sonatypeRepo("snapshots")

libraryDependencies += "org.scala-lang.plugins" %% "scala-continuations-library" % "1.0.3"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.9" % Test

addCompilerPlugin("org.typelevel" %% "kind-projector" % "0.10.3")
addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full)
addCompilerPlugin("org.scala-lang.plugins" % "scala-continuations-plugin_2.12.2" % "1.0.3")
parallelExecution in Test := false

scalacOptions += "-P:continuations:enable"

lazy val lms = ProjectRef(file("./lms-clean"), "lms-clean")

lazy val root = (project in file(".")).dependsOn(lms % "test->test; compile->compile")

--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
sbt.version=1.4.6

--------------------------------------------------------------------------------
/src/main/scala/Benchmark.scala:
--------------------------------------------------------------------------------
package quantum

// Benchmarks the unstaged CPS evaluator (QuantumEvalCPS) and then runs the
// staged CPS compiler (QCompilerCPS), which generates and executes C++ code.

import lms.core._
import lms.core.stub._
import lms.core.Backend._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import scala.util.continuations._
import scala.collection.immutable.{List => SList}

import quantum.circuit.Syntax.{Exp => QExp, _}
import quantum.circuit.Examples._
import quantum.feynman._
import quantum.feynman.staged._
import quantum.feynman.EvalState.{prettyPrint, State}
import quantum.utils.Utils

object Benchmark {
  // (circuit, number of qubits) pairs to benchmark; uncomment entries to enable them.
  val benchmarks: List[(Circuit, Int)] = List(
    // (simon, 4),
    // (rand4, 4),
    // (rand8, 8),
    (rand16, 16)
  )

  /** Times the unstaged CPS evaluator on one circuit, then compiles and runs
    * the staged version of the same circuit.
    */
  def test(ci: (Circuit, Int)): Unit = {
    val (circuit, size) = ci
    println(s"circuit size: ${circuit.size}")
    // warm up
    for (i <- 0 to 5) {
      QuantumEvalCPS.runCircuit(circuit, State(size))
    }
    val (_, t) = Utils.time {
      QuantumEvalCPS.runCircuit(circuit, State(size))
    }
    println(s"$t sec")

    val snippet = new QCDriver[Int, Unit] with QCompilerCPS {
      val circuitSize: Int = size
      override val repeat: Int = 1
      def snippet(s: Rep[Int]): Rep[Unit] = runCircuit(circuit, State(circuitSize))
    }
    snippet.eval(0)
  }

  def main(args: Array[String]): Unit = {
    benchmarks.foreach(test)
  }

}

--------------------------------------------------------------------------------
/src/main/scala/CCodeGen.scala:
--------------------------------------------------------------------------------
package quantum.schrodinger.staged

import lms.core._
import lms.core.stub._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs
import lms.core.Backend._
import java.io.{ByteArrayOutputStream, PrintStream}

// Extends LMS C++ code generator with staticData

abstract class QCodeGen extends DslGenCPP {
  // A use of a "staticData" node is emitted as a reference to the global
  // `p<sym>` declared by emitStatics; the payload is echoed in a C comment.
  override def shallow(n: Node): Unit = n match {
    case n @ Node(s, "staticData", List(Backend.Const(a)), _) =>
      val q = a match {
        case x: Array[_] => "Array(" + x.mkString(",") + ")"
        case _ => a
      }
      emit("p" + quote(s)); emit(s" /* staticData $q */")
    case n =>
      super.shallow(n)
  }

  // Renders a "staticData" node as a C++ global definition, e.g. `int p[] = {1,2};`.
  // Note: so far only handles scalar values and flat arrays
  override def quoteStatic(n: Node) = n match {
    case Node(s, "staticData", List(Backend.Const(a)), _) =>
      val arg = "p" + quote(s)
      val m = typeMap.getOrElse(s, manifest[Unknown])
      val (tpe, postfix) = m.typeArguments match {
        case Nil => (remap(m), "")
        case List(inner) => (remap(inner), "[]")
      }
      val rhs = m.typeArguments match {
        case Nil => a.toString
        case List(inner) => "{" + a.asInstanceOf[Array[_]].mkString(",") + "}"
      }
      s"$tpe $arg$postfix = $rhs;"
  }

  // Writes one global definition per static node collected by dead-code elimination.
  def emitStatics(out: PrintStream): Unit = dce.statics.foreach { n => out.println(quoteStatic(n)) }

  registerHeader("")
  registerHeader("")
  registerHeader("")

  // Helper code placed before the generated function.
  lazy val prelude = """
  |using namespace std::chrono;
  |void printArray(int arr[], int size) {
  |  printf("[");
  |  for (int i = 0; i < size; i++) {
  |    printf("%d", arr[i]);
  |    if (i < size - 1) { printf(", "); }
  |  }
  |  printf("]\n");
  |}
  """.stripMargin

  // Hooks used by the generated `main`: how the input is built, how the output is printed.
  lazy val initInput: String = "int input[] = {1, 2, 3, 4, 5};"
  lazy val procOutput: String = "printArray(output, 5);";
  def declareOutput(m: Manifest[_]): String = {
    if (remap(m) == "void") ""
    else s"${remap(m)} output = "
  }

  /** Emits the whole C++ translation unit: headers, prelude, statics, the
    * generated function `name`, and a timing `main`.
    */
  override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = {
    val ng = init(g)
    val src = run(name, ng)
    emitDefines(stream)
    emitHeaders(stream)
    emit(prelude)
    emitStatics(stream)
    emitFunctionDecls(stream)
    emitDatastructures(stream)
    emitFunctions(stream)
    emitInit(stream)
    emitln(s"\n/**************** $name ****************/")
    emit(src)
    emitln(s"""
    |int main(int argc, char *argv[]) {
    |  $initInput
    |  auto start = high_resolution_clock::now();
    |  ${declareOutput(m2)}$name(input);
    |  auto end = high_resolution_clock::now();
    |  $procOutput
    |  auto duration = duration_cast(end - start);
    |  std::cout << std::fixed;
    |  std::cout << "time: ";
    |  std::cout << (duration_cast(duration).count() / 1.0e6) << "s\\n";
    |  return 0;
    |}""".stripMargin)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/Utils.scala:
--------------------------------------------------------------------------------
package quantum.utils

object Utils {
  /** Evaluates `block` once, prints and returns its wall-clock time in seconds. */
  def time[R](block: => R): (R, Double) = {
    val t0 = System.nanoTime()
    val result = block // call-by-name
    val t1 = System.nanoTime()
    // val t = (t1 - t0) / 1000000.0 //to ms
    val t = (t1 - t0) / 1000000000.0 // to s
    println("Elapsed time: " + t + "s")
    (result, t)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/circuit/QASMTranslator.scala:
--------------------------------------------------------------------------------
package quantum.circuit

// Translate to OpenQASM format

import scala.util.continuations._
import scala.collection.immutable.{List => SList}

import Syntax.{Exp => QExp, _}

object QASMTranslator {
  // Name of the quantum register that every wire index refers to.
  var qreg: String = "q"

  /** Prints the OpenQASM statement for a single gate.
    *
    * OpenQASM statements end with `;` and their operands are comma-separated.
    * Constant controls are folded away: a control fixed to 0 means the gate
    * never fires, so nothing is printed; a control fixed to 1 is dropped,
    * turning `ccx` into `cx` or `x`.
    */
  def emit(g: Gate): Unit =
    g match {
      case CCX(Bit(false), _, _) | CCX(_, Bit(false), _) =>
        () // identity gate: nothing to emit
      case CCX(Bit(true), Bit(true), z) =>
        println(s"x $qreg[$z];")
      case CCX(Bit(true), y, z) =>
        println(s"cx $qreg[$y], $qreg[$z];")
      case CCX(x, Bit(true), z) =>
        println(s"cx $qreg[$x], $qreg[$z];")
      case CCX(x, y, z) =>
        println(s"ccx $qreg[$x], $qreg[$y], $qreg[$z];")
      case H(x) =>
        println(s"h $qreg[$x];")
    }

  /** Prints the OpenQASM statements of every gate in `c`, in order. */
  def emit(c: Circuit): Unit = c.foreach(emit)
}

--------------------------------------------------------------------------------
/src/main/scala/circuit/Syntax.scala:
--------------------------------------------------------------------------------
package quantum.circuit

object Syntax {
  // An operand of a gate: either a qubit wire or a constant classical bit.
  abstract class Exp
  case class Wire(pos: Int) extends Exp {
    override def toString = pos.toString
  }
  case class Bit(b: Boolean) extends Exp {
    override def toString = s"""$$${if (b) "1" else "0"}"""
  }

  abstract class Gate
  // The Toffoli/Controlled-Controlled-Not gate
  case class CCX(x: Exp, y: Exp, z: Exp) extends Gate {
    override def toString =
      if (x == Bit(true))
        if (y == Bit(true)) s"X $z" else s"CX $y $z"
      else s"CCX $x $y $z"
  }
  // The Hadamard gate
  case class H(x: Exp) extends Gate {
    override def toString = s"H $x"
  }
  // The Not gate
  def X(x: Exp): Gate = CCX(Bit(true), Bit(true), x)
  // The Controlled-Not gate
  def CX(y: Exp, z: Exp): Gate = CCX(Bit(true), y, z)

  implicit def intToExp(i: Int): Exp = Wire(i)
  implicit def intToBool(i: Int): Boolean = i != 0

  type Circuit = List[Gate]
}

import Syntax._

object RandCircuit {
  // Not implemented yet: calling it throws scala.NotImplementedError.
  def randGate(size: Int): Gate = ???
}

object Examples {
  // 2 qubits
  val circuit1: Circuit = List(
    H(0),
    CX(0, 1)
  )

  // 1 qubits
  val circuit2: Circuit = List(
    H(0),
    X(0),
    H(0)
  )

  // 4 qubits
  val simon: Circuit = List(
    H(0),
    H(1),
    CX(0, 2),
    CX(0, 3),
    CX(1, 2),
    CX(1, 3),
    H(0),
    H(1)
  )

  // 4 qubits: the Simon gate sequence repeated five times
  val rand4: Circuit = List(
    H(0), H(1), CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), H(0), H(1),
    H(0), H(1), CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), H(0), H(1),
    H(0), H(1), CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), H(0), H(1),
    H(0), H(1), CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), H(0), H(1),
    H(0), H(1), CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), H(0), H(1)
  )

  // 8 qubits
  val rand8: Circuit = List(
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 7), CX(6, 7), CX(1, 3),
    H(5), H(6), H(7),
    CCX(1, 2, 7), CCX(3, 6, 7), X(1), CX(1, 3),
    H(1), H(3), H(4),
    CCX(2, 5, 3), CX(6, 7), CX(2, 4), CX(1, 3),
    H(1), H(2), H(3), H(4), H(5), H(6), H(7),
    X(0), CX(6, 7), CX(4, 1), CX(1, 3)
  )

  // 16 qubits
  val rand16: Circuit = List(
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7),
    H(8), H(9), H(10), H(11), H(12), H(13), H(14), H(15),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 10), CX(6, 7), CX(1, 3),
    CX(11, 8), CX(14, 3), CX(1, 12), CX(9, 3), CX(7, 10), CX(15, 4), CX(1, 13),
    H(11), H(12), H(13), H(14),
    CCX(11, 2, 7), CCX(3, 6, 8), X(1), CX(1, 3), CCX(2, 5, 10), CX(6, 7), CX(12, 4), CX(1, 3),
    H(1), H(2), H(3), H(4), H(5), H(6), H(7), H(8),
    X(10), CX(6, 7), CX(12, 4), CX(1, 3)
  )

  // 20 qubits
  val rand20: Circuit = List(
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7), H(8), H(9),
    H(10), H(11), H(12), H(13), H(14), H(15), H(16), H(17), H(18), H(19),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 10), CX(6, 7), CX(18, 4), CX(1, 3),
    CX(11, 8), CX(14, 3), CX(1, 12), CX(9, 3), CX(7, 10), CX(16, 5), CX(19, 4), CX(1, 13),
    H(11), H(12), H(13), H(14), H(15), H(16), H(17), H(18), H(19),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 10), CX(6, 7), CX(18, 4), CX(1, 3)
    /*
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7), H(8), H(9), H(10),
    CX(11, 8), CX(14, 3), CX(1, 12), CX(9, 3), CX(7, 10), CX(16, 5), CX(19, 4), CX(1, 13),
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7), H(8), H(9), H(10),
    H(11), H(12), H(13), H(14), H(15), H(16), H(17), H(18), H(19),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 10), CX(6, 7), CX(18, 4), CX(1, 3),
    CX(11, 8), CX(14, 3), CX(1, 12), CX(9, 3), CX(7, 10), CX(16, 5), CX(19, 4), CX(1, 13),
    H(11), H(12), H(13), H(14), H(15), H(16), H(17), H(18), H(19),
    CX(0, 2), CX(0, 3), CX(1, 2), CX(1, 3), CX(5, 10), CX(6, 7), CX(18, 4), CX(1, 3),
    H(0), H(1), H(2), H(3), H(4), H(5), H(6), H(7), H(8), H(9), H(10),
    CX(11, 8), CX(14, 3), CX(1, 12), CX(9, 3), CX(7, 10), CX(16, 5), CX(19, 4), CX(1, 13),
     */
  )

  // 2 qubits
  // XXX: it needs to _only_ measure q0,
  // but I don't know how to do that yet.
  val DeutschJozsa: Circuit = List(
    X(1),
    H(0),
    H(1),
    H(0)
  )
}

--------------------------------------------------------------------------------
/src/main/scala/feynman/EvalState.scala:
--------------------------------------------------------------------------------
package quantum.feynman

// Auxiliary definitions for the evaluators

import quantum.circuit.Syntax._

object EvalState {
  // One computational path: amplitude `d` attached to the basis state `bs`.
  case class State(d: Double, bs: Vector[Boolean]) {
    def toMap: Map[Vector[Boolean], Double] = Map(bs -> d)
  }
  object State {
    // Initial state |0...0⟩ on `i` qubits with amplitude 1.
    def apply(i: Int): State = State(1.0, Vector.fill(i)(false))
  }

  // Amplitude factor 1/sqrt(2) introduced by each Hadamard branch.
  val hscale: Double = 1.0 / math.sqrt(2.0)

  /** Value of operand `x` in `bs`: the wire's bit, or the constant itself. */
  def isSet(bs: Vector[Boolean], x: Exp): Boolean = x match {
    case Wire(pos) => bs(pos)
    case Bit(b) => b
  }

  /** Flips the wire `x` in `bs`.
    *
    * @throws IllegalArgumentException if `x` is not a wire (a constant bit cannot be flipped)
    */
  def neg(bs: Vector[Boolean], x: Exp): Vector[Boolean] = x match {
    case Wire(pos) => bs.updated(pos, !bs(pos))
    case other => throw new IllegalArgumentException(s"neg: gate target must be a wire, got $other")
  }

  /** Prints the states whose amplitude magnitude exceeds 0.001, on one line. */
  def prettyPrint(m: Map[Vector[Boolean], Double]): Unit = {
    m.filter(kv => math.abs(kv._2) > 0.001).foreach { case (k, v) =>
      val p = (if (v > 0) "+" else "") + f"$v%.3f"
      val vs = k.map(x => if (x) "1" else "0").mkString
      print(s"$p|$vs⟩ ")
    }
  }
}

--------------------------------------------------------------------------------
/src/main/scala/feynman/QCompilerCPS.scala:
--------------------------------------------------------------------------------
package quantum.feynman.staged

// Quantum circuit compiler in Feynman-style using continuations

import lms.core._
import lms.core.stub._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import lms.core.Backend._

import scala.util.continuations._
import scala.collection.immutable.{List => SList}

import quantum.circuit.Examples._
import quantum.circuit.Syntax.{Exp => QExp, _}

// Staged counterparts of EvalState: Bits and State are opaque IR nodes whose
// accessors are resolved at staging time when the node is a known constructor.
@virtualize
trait QCState extends Dsl {
  lazy val hscale: Rep[Double] = 1.0 / math.sqrt(2.0)

  abstract class Bits

  object Bits {
    // Builds a "new_bits" node whose operands are the constant bits.
    def apply(bs: List[Boolean]): Rep[Bits] = {
      val sbs = bs.map(Backend.Const(_))
      Wrap[Bits](Adapter.g.reflect("new_bits", sbs: _*))
    }
  }

  implicit class BitsOps(bs: Rep[Bits]) {
    // Reads bit `i`: folded to a constant when `bs` is a literal "new_bits" node.
    def apply(i: Int): Rep[Boolean] = Unwrap(bs) match {
      case Adapter.g.Def("new_bits", xs) =>
        Wrap[Boolean](Backend.Const(xs(i).asInstanceOf[Backend.Const].x.asInstanceOf[Boolean]))
      case _ => Wrap[Boolean](Adapter.g.reflect("bits_get", Unwrap(bs), Backend.Const(i)))
    }
    def set(i: Int, v: Rep[Boolean]): Rep[Bits] =
      Wrap[Bits](Adapter.g.reflect("bits_set", Unwrap(bs), Backend.Const(i), Unwrap(v)))
  }

  abstract class State

  object State {
    def apply(d: Rep[Double], bs: Rep[Bits]): Rep[State] =
      Wrap[State](Adapter.g.reflect("new_state", Unwrap(d), Unwrap(bs)))
    // Initial state |0...0⟩ on `size` qubits with amplitude 1.
    def apply(size: Int): Rep[State] = apply(1.0, Bits(List.fill(size)(false)))
  }

  implicit class StateOps(s: Rep[State]) {
    // Field accessors forward the constructor operands when `s` is a "new_state" node.
    def d: Rep[Double] = Unwrap(s) match {
      case Adapter.g.Def("new_state", SList(d: Backend.Exp, _)) => Wrap[Double](d)
      case _ => Wrap[Double](Adapter.g.reflect("state_d", Unwrap(s)))
    }
    def bs: Rep[Bits] = Unwrap(s) match {
      case Adapter.g.Def("new_state", SList(_, bs: Backend.Exp)) => Wrap[Bits](bs)
      case _ => Wrap[Bits](Adapter.g.reflect("state_bs", Unwrap(s)))
    }
  }

  /** Staged value of operand `x`: the wire's bit, or the constant itself. */
  def isSet(bs: Rep[Bits], x: QExp): Rep[Boolean] = x match {
    case Wire(pos) => bs(pos)
    case Bit(b) => b
  }

  /** Flips the wire `x` in `bs`.
    *
    * @throws IllegalArgumentException (at staging time) if `x` is not a wire
    */
  def neg(bs: Rep[Bits], x: QExp): Rep[Bits] = x match {
    case Wire(pos) => bs.set(pos, !bs(pos))
    case other => throw new IllegalArgumentException(s"neg: gate target must be a wire, got $other")
  }
}

@virtualize
trait QCompilerCPS extends QCState {
  type Ans = Unit

  // When true, each gate is preceded by a `// <gate>` comment in the generated code.
  val genCircuitInfo = true

  def info(s: String): Unit = if (genCircuitInfo) unchecked[Unit](s) else ()

  // Final continuation: a effectful "summarize" node accumulating the state's amplitude.
  def summarize(s: Rep[State]): Rep[Ans] =
    Wrap[Ans](Adapter.g.reflectWrite("summarize", Unwrap(s))(Adapter.CTRL))

  /** Evaluates one gate on `v` and passes each resulting path to `k`.
    *
    * `k` is lifted to a top-level generated function so the code after this
    * gate is emitted once and called from every branch. A Hadamard gate calls
    * it twice: once with bit `x` flipped and once unchanged, both scaled by
    * `hscale`; the unchanged branch gets a -1 factor when the bit was set.
    */
  def evalGate(g: Gate, v: Rep[State], k: Rep[State] => Rep[Ans]): Rep[Ans] = {
    info(s"// $g")
    val repK: Rep[State => Ans] = topFun(k)
    // val repK: Rep[State => Ans] = Wrap[State => Ans](__topFun(k, 1, xn => Unwrap(k(Wrap[State](xn(0)))), "inline"))
    g match {
      case CCX(x, y, z) =>
        if (isSet(v.bs, x) && isSet(v.bs, y)) repK(State(v.d, neg(v.bs, z))) else repK(v)
      case H(x) =>
        if (isSet(v.bs, x)) {
          repK(State(hscale * v.d, neg(v.bs, x)))
          repK(State(-1.0 * hscale * v.d, v.bs))
        } else {
          repK(State(hscale * v.d, neg(v.bs, x)))
          repK(State(hscale * v.d, v.bs))
        }
    }
  }

  // Threads the state through the gates left to right, then hands it to `k`.
  def evalCircuit(c: Circuit, v: Rep[State], k: Rep[State] => Rep[Ans]): Rep[Ans] =
    if (c.isEmpty) k(v) else evalGate(c.head, v, s => evalCircuit(c.tail, s, k))

  def runCircuit(c: Circuit, v: Rep[State]): Rep[Ans] = evalCircuit(c, v, summarize)

}

/** Compiles a staged circuit to `snippet.cpp`, builds it with g++ and runs it.
  * The generated program accumulates amplitudes per basis state in a global
  * `summary` map and prints the non-negligible entries.
  */
abstract class QCDriver[A: Manifest, B: Manifest] extends DslDriverCPP[A, B] { q =>
  val circuitSize: Int
  // Number of times the generated `main` re-runs the snippet (summary is cleared each time).
  val repeat: Int = 1

  override val compilerCommand = "g++ -std=c++20 -O3"
  override val sourceFile = "snippet.cpp"
  override val executable = "./snippet"
  override val codegen = new DslGenCPP {
    val IR: q.type = q
    registerHeader("")
    registerHeader("")
    registerHeader("")
    registerHeader("")
    registerHeader("")

    // TODO: consider using std::bitset
    // TODO: instead of using a map as summary, we could allocate
    // an array of 2^n elements to store the prob amplitudes of each
    // possible states.
    //
    // BitsCmp must be a strict weak ordering for std::map: compare bit by bit
    // and decide at the first differing position (lexicographic order).
    // `std::abs` is used so the double overload is selected, never `int abs(int)`.
    lazy val prelude: String = s"""
    |using namespace std::chrono;
    |typedef std::array Bits;
    |typedef struct State { double d; Bits bs; } State;
    |Bits bits_set(const Bits& bs, size_t i, bool v) { Bits res = bs; res[i] = v; return res; }
    |void print_state(State s) {
    |  std::cout << s.d << "|";
    |  for (int i = 0; i < $circuitSize; i++) std::cout << (s.bs[i] ? "1" : "0");
    |  std::cout << "⟩\\n";
    |}
    |struct BitsCmp {
    |  bool operator()(const Bits& lhs, const Bits& rhs) const {
    |    for (size_t i = 0; i < $circuitSize; i++) { if (lhs[i] != rhs[i]) return lhs[i] < rhs[i]; }
    |    return false;
    |  }
    |};
    |std::map summary;
    |void summarize(State s) { summary[s.bs] += s.d; }
    |void print_summary() {
    |  for (const auto& [bs, d] : summary) {
    |    if (std::abs(d) < 0.0001) continue;
    |    std::cout << d << "|";
    |    for (int i = 0; i < $circuitSize; i++) std::cout << (bs[i] ? "1" : "0");
    |    std::cout << "⟩\\n";
    |  }
    |  std::cout << \"#results: \" << summary.size() << \"\\n\";
    |}
    """.stripMargin

    // Map the staged opaque classes to the C++ typedefs declared in the prelude.
    override def remap(m: Manifest[_]): String = {
      if (m.toString.endsWith("$State")) "State"
      else if (m.toString.endsWith("$Bits")) "Bits"
      else super.remap(m)
    }

    // Inline C++ expressions for the custom IR nodes; other nodes use the default emitter.
    override def shallow(n: Node): Unit = n match {
      case Node(s, "new_bits", xs, _) => es"""{${xs.mkString(", ")}}"""
      case Node(s, "bits_get", bs :: i :: Nil, _) => es"$bs[$i]"
      case Node(s, "new_state", d :: bs :: Nil, _) => es"{ $d, $bs }"
      case Node(s, "state_d", st :: Nil, _) => es"$st.d"
      case Node(s, "state_bs", st :: Nil, _) => es"$st.bs"
      case _ => super.shallow(n)
    }

    override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = {
      val ng = init(g)
      // NOTE(review): `stt` is not used below.
      val stt = dce.statics.toList.map(quoteStatic).mkString(", ")
      val src = run(name, ng)
      emitDefines(stream)
      emitHeaders(stream)
      emit(prelude)
      emitFunctionDecls(stream)
      emitDatastructures(stream)
      emitFunctions(stream)
      emitInit(stream)
      emitln(s"\n/**************** $name ****************/")
      emit(src)
      emitln(s"""
      |int main(int argc, char *argv[]) {
      |  auto start = high_resolution_clock::now();
      |  for (size_t i = 0; i < $repeat; i++) { summary.clear(); $name(0); }
      |  auto end = high_resolution_clock::now();
      |  auto duration = duration_cast(end - start);
      |  std::cout << std::fixed;
      |  std::cout << "time: ";
      |  std::cout << (duration_cast(duration).count() / 1.0e6) << "s\\n";
      |  print_summary();
      |  return 0;
      |}""".stripMargin)
    }
  }
}

object TestQC {
  def main(args: Array[String]): Unit = {
    val snippet = new QCDriver[Int, Unit] with QCompilerCPS {
      val circuitSize: Int = 4
      override val repeat: Int = 1
      def snippet(s: Rep[Int]): Rep[Unit] = runCircuit(simon, State(circuitSize))
    }
    snippet.eval(0)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/feynman/QContSim1.scala:
--------------------------------------------------------------------------------
package quantum.feynman

// Unstaged quantum circuit evaluator in Feynman-style using continuations
// Following the Scheme Pearl paper on Quantum Continuation by Choudhury, Agapiev and Sabry

import scala.util.continuations._

import quantum.circuit.Syntax._
import quantum.circuit.Examples._
import quantum.utils.Utils
import EvalState._

object QuantumEvalCont {
  // Accumulate states and their probability amplitudes
  type Ans = Map[Vector[Boolean], Double]

  /** Forks the computation: runs the rest of the circuit on `x` and on `y`,
    * then merges the two result maps by summing amplitudes of equal states.
    */
  def collect(x: State, y: State): State @cps[Ans] = shift { k =>
    val a = k(x)
    val b = k(y)
    a.foldLeft(b) { case (m, (bits, amp)) => m + (bits -> (m.getOrElse(bits, 0.0) + amp)) }
  }

  def evalGate(g: Gate, v: State): State @cps[Ans] = {
    val State(d, bs) = v
    g match {
      case CCX(x, y, z) if isSet(bs, x) && isSet(bs, y) => State(d, neg(bs, z))
      case CCX(x, y, z) => v
      case H(x) if isSet(bs, x) => collect(State(hscale * d, neg(bs, x)), State(-1.0 * hscale * d, bs))
      case H(x) => collect(State(hscale * d, neg(bs, x)), State(hscale * d, bs))
    }
  }

  def evalCircuit(c: Circuit, v: State): State @cps[Ans] =
    if (c.isEmpty) v else evalCircuit(c.tail, evalGate(c.head, v))

  def runCircuit(c: Circuit, v: State): Ans = reset { evalCircuit(c, v).toMap }

}

object TestQContSim {
  // Comparing the performance of QuantumEvalCont and QuantumEvalCPS
  def main(args: Array[String]): Unit = {
    val N = 1000
    val (_, t1) = Utils.time {
      for (i <- 0 to N) {
        QuantumEvalCont.runCircuit(simon, State(4))
        // prettyPrint(QuantumEvalCont.runCircuit(simon, State(4)))
        // println()
      }
    }
    val (_, t2) = Utils.time {
      for (i <- 0 to N) {
        QuantumEvalCPS.runCircuit(simon, State(4))
        // prettyPrint(QuantumEvalCPS.summary.toMap)
        // println()
      }
    }
    println(s"$t1 sec; $t2 sec")
  }

}

--------------------------------------------------------------------------------
/src/main/scala/feynman/QContSim2.scala:
--------------------------------------------------------------------------------
package quantum.feynman

// Unstaged quantum circuit evaluator in Feynman-style using continuations
// Instead of using shift/reset, this one is directly written in CPS

import quantum.circuit.Syntax._
import quantum.utils.Utils
import EvalState._

import scala.collection.mutable.HashMap

object QuantumEvalCPS {
  type Ans = Unit

  // Amplitude accumulated per basis state by the last runCircuit call.
  val summary: HashMap[Vector[Boolean], Double] = HashMap()

  def evalGate(g: Gate, s: State, k: State => Ans): Ans =
    g match {
      case CCX(x, y, z) if isSet(s.bs, x) && isSet(s.bs, y) => k(State(s.d, neg(s.bs, z)))
      case CCX(x, y, z) => k(s)
      case H(x) if isSet(s.bs, x) =>
        k(State(hscale * s.d, neg(s.bs, x)))
        k(State(-1.0 * hscale * s.d, s.bs))
      case H(x) =>
        k(State(hscale * s.d, neg(s.bs, x)))
        k(State(hscale * s.d, s.bs))
    }

  def evalCircuit(c: Circuit, s: State, k: State => Ans): Ans =
    if (c.isEmpty) k(s) else evalGate(c.head, s, s => evalCircuit(c.tail, s, k))

  /** Resets `summary`, then evaluates every path of `c` starting from `s`. */
  def runCircuit(c: Circuit, s: State): Ans = {
    summary.clear()
    evalCircuit(c, s, summarize)
  }

  // Final continuation: adds the path's amplitude to its basis state's entry.
  def summarize(s: State): Unit =
    if (summary.contains(s.bs)) summary(s.bs) = summary(s.bs) + s.d
    else summary(s.bs) = s.d
}

--------------------------------------------------------------------------------
/src/main/scala/schrodinger/Complex.scala:
-------------------------------------------------------------------------------- 1 | package quantum.schrodinger 2 | 3 | import java.text.DecimalFormat 4 | import Math.abs 5 | 6 | // Unstaged complex numbers 7 | 8 | case class Complex(re: Double, im: Double) { 9 | override def toString = s"{$re, $im}" 10 | 11 | def prettyPrint: String = { 12 | val decFormat = new DecimalFormat("#.000") 13 | val reStr = decFormat.format(re) 14 | val imStr = decFormat.format(im) 15 | if (im == 0) reStr 16 | else reStr + (if (im > 0) "+" + imStr else imStr) + "*i" 17 | } 18 | 19 | def +(c: Complex) = Complex(re + c.re, im + c.im) 20 | def -(c: Complex) = Complex(re - c.re, im - c.im) 21 | def *(c: Complex) = Complex(re * c.re - im * c.im, re * c.im + im * c.re) 22 | 23 | def ≈(c: Complex): Boolean = { 24 | val eps = 0.001 25 | abs(im - c.im) <= eps && abs(re - c.re) <= eps 26 | } 27 | def !≈(c: Complex): Boolean = !this.≈(c) 28 | } 29 | 30 | object Complex { 31 | implicit def fromDouble(d: Double): Complex = Complex(d, 0) 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/schrodinger/DecompMatrix.scala: -------------------------------------------------------------------------------- 1 | package quantum.schrodinger.staged 2 | 3 | import math.pow 4 | 5 | import lms.core._ 6 | import lms.core.stub._ 7 | import lms.core.Backend._ 8 | import lms.core.virtualize 9 | import lms.macros.SourceContext 10 | import lms.thirdparty.CLibs 11 | import lms.thirdparty.CCodeGenLibs 12 | 13 | import quantum._ 14 | import quantum.circuit.Syntax._ 15 | import quantum.schrodinger._ 16 | import quantum.schrodinger.gate.{Gate, _} 17 | 18 | import scala.collection.immutable.{List => SList} 19 | 20 | trait DecompMatrix { q: Dsl => 21 | // basic matrix element, represented as opaque IR node 22 | abstract class MElem 23 | 24 | def zeros(cols: Int, rows: Int): AtomMat = 25 | AtomMat(Wrap[MElem](Adapter.g.reflect("zeros", Unwrap(unit(cols)), 
Unwrap(unit(rows)))))
  /** Opaque `cols` x `rows` matrix node named "rand" (presumably random entries; what the
    * node means is decided by the code generator, which is not shown here). */
  def rand(cols: Int, rows: Int): AtomMat =
    AtomMat(Wrap[MElem](Adapter.g.reflect("rand", Unwrap(unit(cols)), Unwrap(unit(rows)))))

  /** Square `n` x `n` zero matrix. */
  def sqZeros(n: Int): AtomMat = zeros(n, n)

  /** `n` x `n` identity matrix, kept symbolic as an "id" node. */
  def id(n: Int): AtomMat =
    AtomMat(Wrap[MElem](Adapter.g.reflect("id", Unwrap(unit(n)))))

  /** `n` x `n` anti-identity (ones on the anti-diagonal), kept symbolic as an "inv-id" node. */
  def invId(n: Int): AtomMat =
    AtomMat(Wrap[MElem](Adapter.g.reflect("inv-id", Unwrap(unit(n)))))


  // abstract vector, representing the initial input vector and final result
  abstract class AbsVec

  type State = Rep[AbsVec]

  /** Staged initial state vector of length `size` (an effectful "init-state" node). */
  def initState(size: Int): State =
    Wrap[AbsVec](Adapter.g.reflectWrite("init-state", Unwrap(unit(size)))(Adapter.CTRL))

  implicit class AbsVecOps(v: Rep[AbsVec]) {
    /** Static length of a staged vector, read from the last operand of its defining node.
      * That operand is the size constant for "init-state", "zeros_vec" and "matvecprod".
      * Throws a MatchError if it is not an Int constant. */
    def size: Int = Unwrap(v) match {
      case Adapter.g.Def(_, ops) =>
        val Backend.Const(n: Int) = ops.last
        n
    }
  }

  // abstract matrix, being either an atomic matrix holding an MElem,
  // or a large matrix that can be decomposed to smaller AbsMats
  abstract class AbsMat {
    type T = Int // TODO: should generalize
    /** Emits draw effects for this matrix; `offset` is (colOffset, rowOffset). */
    def draw(offset: (Int, Int) = (0, 0)): Unit
    /** Element at row `i`, column `j` (0-based). */
    def apply(i: Int, j: Int): Rep[T]
    def dim: (Int, Int) /* (cols, rows) */
    def ⊗(y: AbsMat): AbsMat
    def *(y: Rep[AbsVec]): Rep[AbsVec]
  }

  /** A matrix represented by one staged IR node. Nodes named "zeros", "id" and "inv-id"
    * are recognised structurally so products with them can be simplified while staging. */
  case class AtomMat(e: Rep[MElem]) extends AbsMat {
    def draw(offset: (Int, Int)): Unit =
      Adapter.g.reflectWrite("draw", Unwrap(e), Unwrap(unit(offset._1)), Unwrap(unit(offset._2)))(Adapter.CTRL)

    def apply(i: Int, j: Int): Rep[T] = Unwrap(e) match {
      case Adapter.g.Def("zeros", _) => unit(0)
      case Adapter.g.Def("id", _) => if (i == j) unit(1) else unit(0)
      // Anti-identity: ones exactly where i + j == n - 1 (0-based indices). The previous
      // test `j + j == n` ignored the row index and marked an entire column instead.
      case Adapter.g.Def("inv-id", SList(Backend.Const(n: Int))) => if (i + j == n - 1) unit(1) else unit(0)
      case _ => Wrap[T](Adapter.g.reflect("at", Unwrap(e), Unwrap(unit(i)), Unwrap(unit(j))))
    }

    def dim: (Int, Int) = Unwrap(e) match {
      // The last two operands are always dimensions
      case Adapter.g.Def("zeros", SList(Backend.Const(cols: Int), Backend.Const(rows: Int))) => (cols, rows)
      case Adapter.g.Def("rand", SList(Backend.Const(cols: Int), Backend.Const(rows: Int))) => (cols, rows)
      case Adapter.g.Def("id", SList(Backend.Const(n: Int))) => (n, n)
      case Adapter.g.Def("inv-id", SList(Backend.Const(n: Int))) => (n, n)
      case Adapter.g.Def("kron", SList(_, _, Backend.Const(cols: Int), Backend.Const(rows: Int))) =>
        (cols, rows)
      case _ => System.out.println(e); ???
    }

    def ⊗(y: AbsMat): AbsMat = Unwrap(e) match {
      case Adapter.g.Def("zeros", SList(Backend.Const(p: Int), Backend.Const(q: Int))) =>
        // 0 ⊗ Y is a zero matrix of the product dimensions
        val (m, n) = y.dim
        zeros(p*m, q*n)
      case Adapter.g.Def("id", SList(Backend.Const(h: Int))) =>
        // I_h ⊗ Y is block-diagonal with h copies of Y
        val (m, n) = y.dim
        val zs = List.fill(h)(List.fill(h)(zeros(m, n)))
        DecomposedMat(zs.zipWithIndex.map { case (z, i) => z.updated(i, y) })
      case _ =>
        val (m, n) = this.dim
        val (p, q) = y.dim
        y match {
          case AtomMat(y) => Unwrap(y) match {
            case Adapter.g.Def("zeros", SList(Backend.Const(p: Int), Backend.Const(q: Int))) =>
              zeros(m*p, q*n)
            case _ =>
              // the kronecker product of two atom matrix is still an atom matrix (represented as opaque IR node)
              val mat = Wrap[MElem](Adapter.g.reflect("kron", Unwrap(e), Unwrap(y), Unwrap(unit(m*p)), Unwrap(unit(q*n))))
              AtomMat(mat)
          }
          // NOTE(review): this distributes `this` (A) over the blocks of y (B), building
          // [[A ⊗ B_kl]]. For general A that is not A ⊗ B: with B split into 1x1 blocks it
          // yields [[b11 A, b12 A], ...] = B ⊗ A. Confirm before relying on this path.
          case DecomposedMat(m) => DecomposedMat(m.map { row => row.map(this ⊗ _) })
        }
    }

    /** Staged matrix-vector product. A matrix with dimensions (cols, rows) maps a vector
      * of length `cols` to one of length `rows`. */
    def *(y: Rep[AbsVec]): Rep[AbsVec] = {
      val (cols, rows) = this.dim
      val vecSize: Int = y.size
      // Contract over the columns (previously the row count was checked against |y|
      // and the result was sized by the column count; identical only for square matrices).
      require(cols == vecSize, "dimension error")
      Unwrap(e) match {
        case Adapter.g.Def("zeros", SList(Backend.Const(p: Int), Backend.Const(q: Int))) =>
          Wrap[AbsVec](Adapter.g.reflect("zeros_vec", Unwrap(unit(rows))))
        case Adapter.g.Def("id", SList(Backend.Const(h: Int))) => y
        case _ => Wrap[AbsVec](Adapter.g.reflect("matvecprod", Unwrap(e), Unwrap(y), Unwrap(unit(rows))))
      }
    }
  }

  /** A block matrix: `m` is a list of block rows, each a list of blocks. The height of a
    * block row is taken from its first block and the width of a block from the block itself. */
  case class DecomposedMat(m: List[List[AbsMat]]) extends AbsMat {
    // need additional indexing, generalizing CSR
    def draw(offset: (Int, Int)): Unit = {
      // offset is (colOffset, rowOffset), the same (cols, rows) order as `dim`. Children get
      // their offsets in that order too; they used to receive (row, col), which swapped the
      // placement of nested blocks.
      val (colStart, rowStart) = offset
      m.foldLeft(rowStart) { (rowAcc, row) =>
        row.foldLeft(colStart) { (colAcc, e) =>
          e.draw((colAcc, rowAcc))
          colAcc + e.dim._1
        }
        rowAcc + row.head.dim._2
      }
    }

    def apply(i: Int, j: Int): Rep[T] = {
      // Note(GW): This is not a very efficient "random access", some form of indexing would help.
      // `i` selects a block row by accumulated heights (dim._2 = rows) and `j` a block by
      // accumulated widths (dim._1 = cols), consistent with `dim` below.
      def accessRow(m: List[List[AbsMat]], rowAcc: Int): (List[AbsMat], Int) = m match {
        case Nil => (List(), rowAcc)
        case r :: rest =>
          val height = r.head.dim._2
          if (rowAcc <= i && i < rowAcc + height) (r, rowAcc)
          else accessRow(rest, rowAcc + height)
      }
      def accessCol(r: List[AbsMat], colAcc: Int): (AbsMat, Int) = r match {
        case Nil => throw new RuntimeException(s"Invalid index ($i, $j)")
        case m :: rest =>
          val width = m.dim._1
          if (colAcc <= j && j < colAcc + width) (m, colAcc)
          else accessCol(rest, colAcc + width)
      }
      val (row, rowOffset) = accessRow(m, 0)
      val (mat, colOffset) = accessCol(row, 0)
      mat(i-rowOffset, j-colOffset)
    }

    def dim: (Int, Int) = {
      val cols = m(0).foldLeft(0) { (acc, x) => acc + x.dim._1 }
      val rows = m.map(_.head).foldLeft(0) { (acc, x) => acc + x.dim._2 }
      (cols, rows)
    }

    // (A_kl) ⊗ Y = (A_kl ⊗ Y): the Kronecker product distributes over the left operand's blocks.
    def ⊗(y: AbsMat): AbsMat = DecomposedMat(m.map { row => row.map(_ ⊗ y) })
    def *(y: Rep[AbsVec]): Rep[AbsVec] = ???
  }
}

/** C++ code generation for the decomposed-matrix DSL: maps the abstract matrix/vector
  * types to C++ type names and prints "kron"/"matvecprod" nodes as infix expressions. */
trait CppCodeGen_DecompMatrix extends ExtendedCPPCodeGen {
  override def remap(m: Manifest[_]): String = {
    if (m.runtimeClass.getName.endsWith("AbsMat")) "AbsMat"
    else if (m.runtimeClass.getName.endsWith("AbsVec")) "AbsVec"
    else if (m.runtimeClass.getName.endsWith("MElem")) "AbsMat"
    else super.remap(m)
  }

  override def quote(s: Def): String = s match {
    case _ => super.quote(s)
  }

  override def shallow(n: Node): Unit = n match {
    case n @ Node(s, "kron", List(x, y, cols, rows), _) => es"$x ⊗ $y /* dim: ($cols, $rows) */"
    case n @ Node(s, "matvecprod", List(x, y, size), _) => es"$x * $y /* size: $size */"
    case _ => super.shallow(n)
  }
}

/** Experiment driver: builds a nested block matrix and emits its draw effects. */
abstract class StagedDecomposedMat extends DslDriverCPP[Int, Unit] with DecompMatrix { q =>
  override val codegen = new QCodeGen with CppCodeGen_DecompMatrix {
    val IR: q.type = q
  }

  def snippet(n: Rep[Int]): Rep[Unit] = {
    val id2 = id(2)
    val zr2 = sqZeros(2)
    val m1 = DecomposedMat(List(
      List(id(2), sqZeros(2)),
      List(sqZeros(2), id(2)))) // id(4)
    val m2 = rand(4, 4)
    /*
     m2 z4 z4 z4
     z4 m2 z4 z4
           m2 z4
              m2
    */
    val m3 = m1 ⊗ m2
    // m3.regroup() // XXX: would be great/interesting to regroup the structure
    val m4 = m3 ⊗ m1
    m4.draw()
    //m1.draw()
    //println((m2 ⊗ id2)(0, 0))

    //val v = zr2 * initState(2)
    //println(v)
    //(zr2 ⊗ m2).draw
    println("End")
  }
}

object TestStagedDecompMatrix {
  def main(args: Array[String]): Unit = {
    val driver = new StagedDecomposedMat {}
    println(driver.code)
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/Gate.scala: -------------------------------------------------------------------------------- 1 |
package quantum.schrodinger.gate

import math.{pow, log}
import quantum.schrodinger.Complex
import quantum.schrodinger.Complex._
import quantum.schrodinger.Matrix._

// Static gate matrix-representations

/** A quantum gate: a short name plus its square matrix of dimension 2^arity.
  *
  * @param id short gate name (e.g. "H", "CNOT")
  * @param m  the gate's matrix
  */
case class Gate(id: String, m: Matrix) {
  /** Number of qubits the gate acts on, i.e. log2 of the matrix dimension.
    *
    * Computed with integer bit operations. The former `(log(n) / log(2)).toInt` went
    * through floating point and truncated, so it silently rounded non-power-of-two
    * dimensions down to a wrong arity.
    *
    * @throws IllegalArgumentException if the dimension is not a positive power of two
    */
  def arity: Int = {
    val n = m.size
    require(n > 0 && (n & (n - 1)) == 0, s"gate $id: matrix dimension $n is not a power of two")
    Integer.numberOfTrailingZeros(n)
  }
}

object Gate {
  val isq2 = 1.0 / pow(2.0, 0.5)

  // Hadamard
  val H = Gate(
    "H",
    Array(
      Array(isq2, isq2),
      Array(isq2, -isq2)
    )
  )

  val SWAP = Gate(
    "SWAP",
    Array(
      Array(1, 0, 0, 0),
      Array(0, 0, 1, 0),
      Array(0, 1, 0, 0),
      Array(0, 0, 0, 1)
    )
  )

  val NOT = Gate(
    "NOT",
    Array(
      Array(0, 1),
      Array(1, 0)
    )
  )

  val CNOT = Gate(
    "CNOT",
    Array(
      Array(1, 0, 0, 0),
      Array(0, 1, 0, 0),
      Array(0, 0, 0, 1),
      Array(0, 0, 1, 0)
    )
  )

  // diag(1, i)
  val S = Gate(
    "S",
    Array(
      Array(1, 0),
      Array(0, Complex(0, 1))
    )
  )

  // diag(1, (1 + i)/sqrt(2))
  val T = Gate(
    "T",
    Array(
      Array(1, 0),
      Array(0, isq2 + isq2 * Complex(0, 1))
    )
  )

  val Z = Gate(
    "Z",
    Array(
      Array(1, 0),
      Array(0, -1)
    )
  )

  // NOTE(review): 0.5 * Z = diag(0.5, -0.5) is neither the phase gate nor a projector
  // such as (I + Z)/2; confirm this is the intended matrix for "P".
  val P = Gate(
    "P",
    0.5 * Z.m
  )

  val CZ = Gate(
    "CZ",
    Array(
      Array(1, 0, 0, 0),
      Array(0, 1, 0, 0),
      Array(0, 0, 1, 0),
      Array(0, 0, 0, -1)
    )
  )

  // Toffoli: flips the last wire when the first two are both 1
  val CCNOT = Gate(
    "CCNOT",
    Array(
      Array(1, 0, 0, 0, 0, 0, 0, 0),
      Array(0, 1, 0, 0, 0, 0, 0, 0),
      Array(0, 0, 1, 0, 0, 0, 0, 0),
      Array(0, 0, 0, 1, 0, 0, 0, 0),
      Array(0, 0, 0, 0, 1, 0, 0, 0),
      Array(0, 0, 0, 0, 0, 1, 0, 0),
      Array(0, 0, 0, 0, 0, 0, 0, 1),
      Array(0, 0, 0, 0, 0, 0, 1, 0)
    )
  )
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/Matrix.scala:
--------------------------------------------------------------------------------
package quantum.schrodinger

import Complex._

// Unstaged matrix operations

object Matrix {
  type Matrix = Array[Array[Complex]]

  /** `n` x `n` identity matrix. */
  def identity(n: Int): Matrix = {
    val result = zeros(n)
    for (i <- 0 until n) {
      result(i)(i) = 1
    }
    result
  }

  /** `n` x `n` matrix filled with complex zeros. */
  def zeros(n: Int): Matrix = {
    val result = Array.ofDim[Complex](n, n)
    for (i <- 0 until n) {
      for (j <- 0 until n) {
        result(i)(j) = 0
      }
    }
    result
  }

  /** Vector of `n` complex zeros. */
  def zerosVec(n: Int): Array[Complex] = Array.fill(n)(0)

  /** Renders `A` one row per line inside brackets. Only the row count is needed, so an
    * empty matrix renders as "[]\n" (an unused column count used to read `A(0)` and throw). */
  def prettyPrint(A: Matrix): String = {
    val sb = new StringBuilder
    val nRows = A.size
    sb ++= "["
    for (i <- 0 until nRows) {
      sb ++= (new ArrayOps(A(i))).toString
      sb ++= "\n"
    }
    sb ++= "]\n"
    sb.toString
  }

  implicit class ArrayOps(A: Array[Complex]) {
    override def toString: String = {
      val sb = new StringBuilder
      sb ++= "["
      for (i <- 0 until A.size) sb ++= s"${A(i)}, "
      sb ++= "]"
      sb.toString
    }
  }

  implicit class DoubleOps(scalar: Double) {
    /** Scalar-matrix product. */
    def *(A: Matrix): Matrix = {
      val nRowsA = A.size
      val nColsA = A(0).size
      val result = Array.ofDim[Complex](nRowsA, nColsA)
      for (i <- 0 until nRowsA) {
        for (j <- 0 until nColsA) {
          result(i)(j) = scalar * A(i)(j)
        }
      }
      result
    }
  }

  implicit class MatrixOps(A: Matrix) {
    def pPrint: String = prettyPrint(A)

    // Unstaged Kronecker product
    def ⊗(B: Matrix): Matrix = {
      val nRowsA = A.size
      val nColsA = A(0).size
      val nRowsB = B.size
      val nColsB = B(0).size
      val result = Array.ofDim[Complex](nRowsA * nRowsB, nColsA * nColsB)
      for (i <- 0 until nRowsA; j <- 0 until nColsA) {
        for (k <- 0 until nRowsB; l <- 0 until nColsB) {
          result(i * nRowsB + k)(j * nColsB + l) = A(i)(j) * B(k)(l)
        }
      }
      // println(s"${A.dim} ⊗ ${B.dim} = ${result.dim}")
      result
    }

    // matrix multiplication
    def *(B: Matrix): Matrix = {
      val nRowsA = A.size
      val nColsA = A(0).size
      val nRowsB = B.size
      val nColsB = B(0).size
      require(
        nColsA == nRowsB,
        s"dimension error nColsA=$nColsA, nRowsB=$nRowsB \n${prettyPrint(A)} * ${prettyPrint(B)}"
      )
      val result = Array.ofDim[Complex](nRowsA, nColsB)
      for (i <- 0 until nRowsA) {
        for (j <- 0 until nColsB) {
          var sum: Complex = 0.0
          for (k <- 0 until nColsA) {
            sum += A(i)(k) * B(k)(j)
          }
          result(i)(j) = sum
        }
      }
      // println(s"${A.dim} * ${B.dim} = ${result.dim}")
      result
    }

    // matrix-vector product
    def *(V: Array[Complex]): Array[Complex] = {
      val nRowsA = A.size
      val nColsA = A(0).size
      require(nColsA == V.size, s"dimension error, nColA: $nColsA, V.size: ${V.size}")
      val result = zerosVec(nRowsA)
      for (i <- 0 until nRowsA) {
        for (j <- 0 until nColsA) {
          result(i) += A(i)(j) * V(j)
        }
      }
      // println(s"${A.dim} * ${V.size} = ${result.size}")
      result
    }

    /** (rows, cols); requires at least one row. */
    def dim: (Int, Int) = (A.size, A(0).size)
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/SchrodingerInterpreter.scala: --------------------------------------------------------------------------------
package quantum.schrodinger

import quantum.schrodinger.gate.{Gate, _}

/** Common front end for Schrodinger-style simulators. Each named gate method forwards to
  * `op` with the matching static Gate; `i` is the first wire the gate acts on. */
trait SchrodingerInterpreter {
  type State
  val size: Int
  def op(g: Gate, i: Int): Unit
  def circuit: Unit
  def evalCircuit: Unit

  def H(i: Int): Unit = op(Gate.H, i)
  def SWAP(i: Int): Unit = op(Gate.SWAP, i)
  def NOT(i: Int): Unit = op(Gate.NOT, i)
  def CNOT(i: Int): Unit = op(Gate.CNOT, i)
  def CCNOT(i: Int): Unit = op(Gate.CCNOT, i)
  def S(i: Int): Unit = op(Gate.S, i)
  def T(i: Int): Unit = op(Gate.T, i)
  def Z(i: Int): Unit = op(Gate.Z, i)
  def CZ(i: Int): Unit = op(Gate.CZ, i)
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/Shonan.scala: --------------------------------------------------------------------------------
package quantum.schrodinger.staged

// The Shonan Challenge - matrix vector product
// https://scala-lms.github.io/tutorials/shonan.html

import lms.core._
import lms.core.stub._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import lms.core.Backend._
import java.io.{ByteArrayOutputStream, PrintStream}

object Shonan {
  val A = scala.Array
  val a =
    A(A(1, 1, 1, 1, 1), A(0, 0, 0, 0, 0), A(0, 0, 1, 0, 0), A(0, 0, 0, 0, 0), A(0, 0, 1, 0, 1))

  val snippet = new DslDriverCPP[Array[Int], Array[Int]] { q =>
    override val codegen = new QCodeGen {
      val IR: q.type = q
    }

    // Iterates `r` either unrolled at staging time (c = true) or as a staged loop.
    def unrollIf(c: Boolean, r: Range) = new {
      def foreach(f: Rep[Int] => Rep[Unit]) = {
        if (c) for (j <- (r.start until r.end): Range) f(j)
        else for (j <- (r.start until r.end): Rep[Range]) f(j)
      }
    }

    def snippet(v: Rep[Array[Int]]) = {
      def matVecProd(a0: Array[Array[Int]], v: Rep[Array[Int]]): Rep[Array[Int]] = {
        val n = a0.length
        val a = staticData(a0)
        // NOTE(review): v1(i) is read before it is first written below; this relies on
        // NewArray yielding zero-initialised memory -- confirm against the LMS backend.
        val v1 = NewArray[Int](n)

        for (i <- (0 until n): Range) {
          // rows with fewer than 3 non-zeros are unrolled
          val sparse = a0(i).count(_ != 0) < 3
          for (j <- unrollIf(sparse, 0 until n)) {
            v1(i) = v1(i) + a(i).apply(j) * v(j)
          }
        }
        v1
      }
      matVecProd(a, v)
    }
  }

  def main(args: Array[String]): Unit = {
    println(snippet.code)
    snippet.eval(Array(1, 2, 3, 4, 5))
    // assert(snippet.eval(Array(1,2,3,4,5)).toList == List(15, 0, 3, 0, 8))
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/StagedComplex.scala: --------------------------------------------------------------------------------
package quantum.schrodinger.staged

import math.pow
import quantum._

import lms.core._
import lms.core.stub._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import lms.core.Backend._
import quantum.schrodinger.Complex

// Staged complex numbers

/** Staged complex arithmetic: a complex value is a "complex-new" node built from staged
  * real and imaginary parts, and arithmetic is expanded into operations on those parts. */
trait ComplexOps { b: Dsl =>
  object Complex {
    def apply(re: Rep[Double], im: Rep[Double]): Rep[Complex] =
      Wrap[Complex](Adapter.g.reflect("complex-new", Unwrap(re), Unwrap(im)))
  }

  implicit def liftComplex(c: Complex): Rep[Complex] = Complex(c.re, c.im)
  implicit def liftDouble(d: Double): Rep[Complex] = Complex(d, 0)
  implicit def liftRepDouble(d: Rep[Double]): Rep[Complex] = Complex(d, 0)

  implicit class ComplexOps(c: Rep[Complex]) {
    def re: Rep[Double] = Wrap[Double](Adapter.g.reflect("complex-re", Unwrap(c)))
    def im: Rep[Double] = Wrap[Double](Adapter.g.reflect("complex-im", Unwrap(c)))
    def +(d: Rep[Complex]): Rep[Complex] = Complex(c.re + d.re, c.im + d.im)
    def +(d: Complex): Rep[Complex] = Complex(c.re + d.re, c.im + d.im)
    def -(d: Rep[Complex]): Rep[Complex] = Complex(c.re - d.re, c.im - d.im)
    def -(d: Complex): Rep[Complex] = Complex(c.re - d.re, c.im - d.im)
    def *(d: Rep[Complex]): Rep[Complex] = Complex(c.re * d.re - c.im * d.im, c.re * d.im + c.im * d.re)
    def *(d: Complex): Rep[Complex] = Complex(c.re * d.re - c.im * d.im, c.re * d.im + c.im * d.re)
  }

  implicit class StaticComplexOps(c: Complex) {
    def +(d: Rep[Complex]): Rep[Complex] = Complex(c.re + d.re, c.im + d.im)
    def -(d: Rep[Complex]): Rep[Complex] = Complex(c.re - d.re, c.im - d.im)
    def *(d: Rep[Complex]): Rep[Complex] = Complex(c.re * d.re - c.im * d.im, c.re * d.im + c.im * d.re)
  }
}

/** C++ code generation for staged complex numbers: emits brace initialisers and field reads. */
trait CppCodeGen_Complex extends ExtendedCPPCodeGen {
  override def remap(m: Manifest[_]): String = {
    if (m.runtimeClass.getName.endsWith("Complex")) "Complex"
    else super.remap(m)
  }

  override def quote(s: Def): String = s match {
    case Const(c: Complex) => s"{ ${c.re}, ${c.im} }"
    case _ => super.quote(s)
  }

  override def shallow(n: Node): Unit = n match {
    case Node(s, "complex-new", List(re, im), _) => es"{ ${re}, ${im} }"
    case Node(s, "complex-re", List(c), _) => es"$c.re"
    case Node(s, "complex-im", List(c), _) => es"$c.im"
    case _ => super.shallow(n)
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/StagedHLSchrodinger.scala: --------------------------------------------------------------------------------
package quantum.schrodinger.staged

import math.pow

import lms.core._
import lms.core.stub._
import lms.core.Backend._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import utils.time
import scala.sys.process._

import quantum._
import quantum.circuit.Syntax._
import quantum.schrodinger._
import quantum.schrodinger.Matrix._
import quantum.schrodinger.gate.{Gate, _}

abstract class AbsMat
abstract class AbsVec

/** Driver that emits the staged program as a Python script (snippet.py) and runs it with
  * `python3`, echoing the script's output lines. */
abstract class DslDriverPy[A:Manifest,B:Manifest] extends DslSnippet[A,B] with DslExp { q =>
  val codegen = new DslGen { val IR: q.type = q }
  lazy val (code, statics) = {
    val source = new java.io.ByteArrayOutputStream()
    val statics = codegen.emitSource[A,B](wrapper, "Snippet", new java.io.PrintStream(source))
    (source.toString, statics)
  }
  lazy val f = {
    val (c1,s1) = (code,statics);
    val out = new java.io.PrintStream("snippet.py")
    out.println(code)
    out.close
    val pb: ProcessBuilder = s"python3 snippet.py"
    (a: A) => { pb.lines.foreach(Console.println _) }
  }
  def precompile: Unit = f
  def precompileSilently: Unit = lms.core.utils.devnull(f)
  def eval(x: A): Unit = { val f1 = f; time("eval")(f1(x)) }
}

/** High-level Schrodinger simulator that stages whole-matrix operations and prints them as
  * NumPy calls (kron / matmul) in a generated Python script. */
abstract class StagedHLSchrodinger extends DslDriverPy[Int, Unit] with SchrodingerInterpreter { q =>
  override val codegen = new DslGen {
    val IR: q.type = q
    // Python runtime support. print_result prints ", " only *between* printed entries;
    // it used to print it after every entry except index size-1, leaving a trailing
    // separator whenever the last amplitude was zero.
    lazy val prelude = """import time
import numpy as np
from numpy import kron, matmul, array, eye, absolute, real, imag, log2, zeros, sqrt
SWAP = array([
    [1, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1]
])
CNOT = array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0]
])
isq2 = 1.0 / (2.0 ** 0.5)
H = isq2 * array([
    [1, 1],
    [1, -1],
])
def Id(i): return eye(i, dtype=complex)
def Init(i):
    s = zeros(2 ** i, dtype=complex)
    s[0] = 1
    return s
def print_binary(index, size_sqrt):
    bin_format = f"0{int(log2(size_sqrt**2))}b"
    binary = format(index, bin_format)
    print(f"{binary}", end="")
def print_result(arr, size):
    print("[", end="")
    first = True
    for i in range(size):
        if absolute(real(arr[i])) < 1e-18 and imag(arr[i]) == 0.0:
            continue
        if not first:
            print(", ", end="")
        first = False
        print(arr[i], end="")
        print("|", end="")
        print_binary(i, sqrt(size))
        print("⟩", end="")
    print("]")
"""
    lazy val initInput: String = ""
    lazy val procOutput: String = ""

    override def remap(m: Manifest[_]) = {
      if (m.runtimeClass.getName.endsWith("AbsVec")) "AbsVec"
      else if (m.runtimeClass.getName.endsWith("AbsMat")) "AbsMat"
      else super.remap(m)
    }

    override def shallow(n: Node): Unit = n match {
      case n @ Node(s, "kron", List(x, y), _) => es"kron($x, $y)"
      case n @ Node(s, "matvecprod", List(x, y), _) => es"matmul($x, $y)"
      case n @ Node(s, "init-state", List(Backend.Const(i: Int)), _) => es"Init($i)"
      case n @ Node(s, "H", _, _) => es"H"
      case n @ Node(s, "SWAP", _, _) => es"SWAP"
      case n @ Node(s, "CNOT", _, _) => es"CNOT"
      case _ => super.shallow(n)
    }
    override def traverse(n: Node): Unit = n match {
      case n @ Node(s, "copy", List(from, to), _) =>
        esln"$to = $from"
      case n @ Node(s,"P",List(x),_) => esln"print_result($x, 2 ** $size)"
      case _ => super.traverse(n)
    }
    override def quoteBlockP(prec: Int)(f: => Unit) = {
      def wraper(numStms: Int, l: Option[Node], y: Block)(f: => Unit) = {
        val paren = numStms == 0 && l.map(n => precedence(n) < prec).getOrElse(false)
        if (paren) emit("(") //else if (numStms > 0) emitln("{")
        f
        if (y.res != Const(())) { shallow(y.res) }
        emit(quoteEff(y.eff))
        if (paren) emit(")") //else if (numStms > 0) emit("\n}")
      }
      withWraper(wraper _)(f)
    }
    override def emitValDef(n: Node): Unit = {
      if (dce.live(n.n)) emit(s"${quote(n.n)} = ");
      shallow(n); emitln()
    }
    override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = {
      val ng = init(g)
      val arg = quote(g.block.in.head)
      val stt = dce.statics.toList.map(quoteStatic).mkString(", ")
      val (ms1, ms2) = (remap(m1), remap(m2))
      emitln(prelude)
      emitln("#############")
      emitln("start_time = time.time()")
      quoteBlock(apply(ng))
      emitln("print(\"--- %s seconds ---\" % (time.time() - start_time))")
    }
  }

  implicit class AbsMatOps(x: Rep[AbsMat]) {
    def ⊗(y: Rep[AbsMat]): Rep[AbsMat] =
      Wrap[AbsMat](Adapter.g.reflectWrite("kron", Unwrap(x), Unwrap(y))(Adapter.CTRL))
    def *(y: Rep[AbsVec]): Rep[AbsVec] =
      Wrap[AbsVec](Adapter.g.reflectWrite("matvecprod", Unwrap(x), Unwrap(y))(Adapter.CTRL))
  }

  type State = Rep[AbsVec]
  lazy val state: State = Wrap[AbsVec](Adapter.g.reflectWrite("init-state", Unwrap(unit(size)))(Adapter.CTRL))

  /** state := (I_{2^i} ⊗ g ⊗ I_{2^(size-i-arity)}) * state, staged as an explicit copy. */
  def op(g: Gate, i: Int): Unit = {
    val iLeft = Wrap[AbsMat](Adapter.g.reflect("Id", Unwrap(unit(pow(2, i).toInt))))
    val iRight = Wrap[AbsMat](Adapter.g.reflect("Id", Unwrap(unit(pow(2, size-i-g.arity).toInt))))
    val gate = Wrap[AbsMat](Adapter.g.reflectWrite(g.id)(Adapter.CTRL))
    // ⊗ binds tighter than * (symbolic-operator precedence), so this is ((iLeft ⊗ gate) ⊗ iRight) * state
    val buf = iLeft ⊗ gate ⊗ iRight * state
    Adapter.g.reflectWrite("copy", Unwrap(buf), Unwrap(state))(Adapter.CTRL)
  }

  def snippet(n: Rep[Int]): Rep[Unit] = {
    circuit
    println(state)
  }

  def evalCircuit: Unit = eval(0)
}

object TestStagedHLSchrodinger {
  def main(args: Array[String]): Unit = {

    val driver = new StagedHLSchrodinger {
      val size = 4
      def circuit: Unit = {
        // Simon's problem
        H(0)
        H(1)
        SWAP(0) // swap 0 and 1
        CNOT(1) // CNOT(1, 2)
        SWAP(2) // swap 2 and 3
        CNOT(1) // CNOT(1, 2)
        SWAP(0)
        SWAP(1)
        CNOT(2)
        SWAP(1)
        CNOT(1)
        H(0)
        H(1)
      }
    }

    driver.evalCircuit
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/StagedSchrodinger.scala: --------------------------------------------------------------------------------
package quantum.schrodinger.staged

import math.pow

import lms.core._
import lms.core.stub._
import lms.core.Backend._
import lms.core.virtualize
import lms.macros.SourceContext
import lms.thirdparty.CLibs
import lms.thirdparty.CCodeGenLibs

import quantum._
import quantum.circuit.Syntax._
import quantum.schrodinger._
import quantum.schrodinger.Matrix._
import
quantum.schrodinger.gate.{Gate, _}

/** Staged Schrodinger simulator with a C++ backend. The state is an array of 2^size
  * Complex values, updated by dense matrix-vector products with iLeft ⊗ gate ⊗ iRight,
  * whose static entries are baked into the generated code. */
abstract class StagedSchrodinger extends DslDriverCPP[Array[Complex], Array[Complex]] with ComplexOps with SchrodingerInterpreter { q =>
  override val codegen = new QCodeGen with CppCodeGen_Complex {
    // NOTE(review): the header name is empty here, but the prelude uses std::chrono; this
    // presumably should register <chrono> (possibly lost when this file was exported) -- confirm.
    registerHeader("")

    val IR: q.type = q
    // calloc, not malloc: every amplitude except input[0] must start at zero, and malloc
    // leaves them uninitialised (they are read by the first matrix-vector product).
    override lazy val initInput: String = s"""
      | Complex* input = (Complex*) calloc(${pow(2, size).toInt}, sizeof(Complex));
      | input[0] = {1, 0};
      |""".stripMargin
    override lazy val procOutput: String = s"printResult(input, ${pow(2, size).toInt});"
    // printResult prints each basis index with log2(size) bits (it used to pass sqrt(size),
    // which equals log2(size) only for 2 or 4 qubits) and separates entries with ", "
    // only between printed entries.
    override lazy val prelude = """
      |using namespace std::chrono;
      |typedef struct Complex { double re; double im; } Complex;
      |void printComplex(Complex* c) {
      |  if (c->im == 0.0) { printf("%.3f", c->re); }
      |  else { printf("%.3f + %.3fi", c->re, c->im); }
      |}
      |void printBinary(uint64_t n, size_t size) {
      |  for (int i = size - 1; i >= 0; i--) {
      |    int shifted = n >> i;
      |    int bit = shifted & 1;
      |    printf("%d", bit);
      |  }
      |}
      |void printResult(Complex arr[], size_t size) {
      |  size_t bits = 0;
      |  while (((size_t) 1 << bits) < size) bits++;
      |  bool first = true;
      |  printf("[");
      |  for (size_t i = 0; i < size; i++) {
      |    if ((arr+i)->re == 0.0 && (arr+i)->im == 0.0) continue;
      |    if (!first) { printf(", "); }
      |    first = false;
      |    printComplex(arr+i);
      |    printf("|");
      |    printBinary(i, bits);
      |    printf("⟩");
      |  }
      |  printf("]\n");
      |}
      """.stripMargin
    override def traverse(n: Node): Unit = n match {
      case n @ Node(s, "copy", List(from, to), _) =>
        esln"$to = $from;"
      case _ => super.traverse(n)
    }
  }
  override val compilerCommand = "g++ -std=c++20 -O3"
  override val sourceFile = "snippet.cpp"
  override val executable = "./snippet"

  // Iterates `r` either unrolled at staging time (c = true) or as a staged loop.
  def unrollIf(c: Boolean, r: Range) = new {
    def foreach(f: Rep[Int] => Rep[Unit]) = {
      if (c) for (j <- (r.start until r.end): Range) f(j)
      else for (j <- (r.start until r.end): Rep[Range]) f(j)
    }
  }

  /** des := a0 * v, with a0 embedded as static data. Rows are unrolled at staging time;
    * `sparse` is hard-coded to false, so the inner loop is always a staged loop. */
  def matVecProd(a0: Array[Array[Complex]], v: Rep[Array[Complex]], des: Rep[Array[Complex]]): Unit = {
    val n = a0.length
    val a = staticData(a0)
    for (i <- (0 until n): Range) {
      des(i) = 0.0
      val sparse = false //a0(i).count(_ != (0: Complex)) < 0.5 * a0(i).size
      // System.out.println(s"sparsity: ${a0(i).toList} $sparse")
      for (j <- unrollIf(sparse, 0 until a0(0).size)) {
        des(i) = des(i) + a(i).apply(j) * v(j)
      }
    }
  }

  /** Byte sizes of the C types used by the generated code. */
  def sizeof(s: String): Int = s match {
    case "int" => 4
    case "int64" => 8
    case "double" => 8
    case "Complex" => sizeof("double") * 2
  }

  type State = Rep[Array[Complex]]
  lazy val buf = NewArray[Complex](pow(2, size).toInt)
  var state: State = _

  @virtualize
  def op(g: Gate, i: Int): Unit = {
    val iLeft = Matrix.identity(pow(2, i).toInt)
    val iRight = Matrix.identity(pow(2, size - i - g.arity).toInt)
    matVecProd(iLeft ⊗ g.m ⊗ iRight, state, buf)
    // XXX: can we eliminate this copy? Could alternate buf and state
    // NOTE(review): the length is scaled by sizeof("Complex") (16); if LMS copyToArray takes
    // an element count rather than a byte count this copies past both arrays -- confirm.
    buf.copyToArray(state, 0, pow(2, size).toInt * sizeof("Complex"))
  }

  def snippet(input: State): State = {
    state = input
    circuit
    state
  }

  def evalCircuit: Unit = eval(scala.Array())
}

object TestStagedSchrodinger {
  def main(args: Array[String]): Unit = {
    val driver = new StagedSchrodinger {
      val size = 4
      def circuit: Unit = {
        // H(0)
        // CNOT(0)
        // S(0)
        // T(0)
        // Simon's problem
        H(0)
        H(1)
        SWAP(0) // swap 0 and 1
        CNOT(1) // CNOT(1, 2)
        SWAP(2) // swap 2 and 3
        CNOT(1) // CNOT(1, 2)
        SWAP(0)
        SWAP(1)
        CNOT(2)
        SWAP(1)
        CNOT(1)
        H(0)
        H(1)
      }
    }
    println(driver.code)
    driver.evalCircuit
  }
}

-------------------------------------------------------------------------------- /src/main/scala/schrodinger/UnstagedSchrodinger.scala: --------------------------------------------------------------------------------
package quantum.schrodinger

import math.pow
import quantum.utils._
import quantum.schrodinger.gate._
import quantum.schrodinger.Matrix._

// Unstaged Schrodinger-style simulation

/** Reference (unstaged) simulator: keeps the full 2^size state vector and applies each
  * gate as a dense product with iLeft ⊗ gate ⊗ iRight. */
abstract class UnstagedSchrodinger(val size: Int) extends SchrodingerInterpreter {
  type State = Array[Complex]

  var state: State = Matrix.zerosVec(pow(2, size).toInt)
  state(0) = 1 // start in the basis state |0...0⟩

  def evalCircuit: Unit = {
    circuit
    prettyPrintSummary
  }

  /** Replaces the state vector; it must have the same length as the current one. */
  def setState(s: State) = {
    assert(state.size == s.size, "incompatible size");
    state = s
  }

  def op(g: Gate, i: Int) = {
    // println(pow(2, i).toInt)
    val iLeft = Matrix.identity(pow(2, i).toInt)
    // println(iLeft.pPrint)
    // println(pow(2, size - i - g.arity).toInt)
    val iRight = Matrix.identity(pow(2, size - i - g.arity).toInt)
    // println(iRight.pPrint)
    state = iLeft ⊗ g.m ⊗ iRight * state
    // tiling, auto vec
  }

  /** Non-zero amplitudes paired with their basis state as a `size`-bit binary string. */
  def summary: List[(String, Complex)] = {
    state.toList.zipWithIndex
      .map({ case (s, i) =>
        val bin = Integer.toBinaryString(i)
        ("0" * (size - bin.length) + bin, s)
      })
      .filter(_._2 != (0: Complex))
  }

  def prettyPrintSummary: Unit = {
    summary.foreach { case (s, d) =>
      println(s"${d.prettyPrint}|$s⟩")
    }
  }
}

object TestUnstagedSchrodinger {
  def main(args: Array[String]): Unit = {
    val q = new UnstagedSchrodinger(4) {
      def circuit: Unit = {
        H(0)
        H(1)
        SWAP(0) // swap 0 and 1
        CNOT(1) // CNOT(1, 2)
        SWAP(2) // swap 2 and 3
        CNOT(1) // CNOT(1, 2)
        SWAP(0)
        SWAP(1)
        CNOT(2)
        SWAP(1)
        CNOT(1)
        H(0)
        H(1)
      }
    }

    Utils.time { q.evalCircuit }
  }
}

-------------------------------------------------------------------------------- /src/test/scala/MatrixTest.scala: --------------------------------------------------------------------------------
package quantum.schrodinger

import math.pow
import org.scalatest.funsuite.AnyFunSuite

import Matrix._

class MatrixTest extends AnyFunSuite {
  test("Identity") {
    val m1: Matrix = Array(
      Array(1, 0, 0),
      Array(0, 1, 0),
      Array(0, 0, 1)
    )
    assert(Matrix.identity(3).flatten.toList == m1.flatten.toList)
  }

  test("MatMul") {
    val m1: Matrix = Array(
      Array(1, 0, 1),
      Array(2, 1, 1),
      Array(0, 1, 1),
      Array(1, 1, 2)
    )
    val m2: Matrix = Array(
      Array(1, 2, 1),
      Array(2, 3, 1),
      Array(4, 2, 2)
    )
    val m3: Matrix = Array(
      Array(5, 4, 3),
      Array(8, 9, 5),
      Array(6, 5, 3),
      Array(11, 9, 6)
    )
    assert((m1 * m2).flatten.toList == m3.flatten.toList)
  }

  test("Kronecker") {
    val m1: Matrix = Array(
      Array(1, 2),
      Array(3, 4)
    )
    val m2: Matrix = Array(
      Array(0, 5),
      Array(6, 7)
    )
    val m3: Matrix = Array(
      Array(0, 5, 0, 10),
      Array(6, 7, 12, 14),
      Array(0, 15, 0, 20),
      Array(18, 21, 24, 28)
    )

    assert((m1 ⊗ m2).flatten.toList == m3.flatten.toList)

    val m4: Matrix = Array(
      Array(1, -4, 7),
      Array(-2, 3, 3)
    )
    val m5: Matrix = Array(
      Array(8, -9, -6, 5),
      Array(1, -3, -4, 7),
      Array(2, 8, -8, -3),
      Array(1, 2, -5, -1)
    )
    val m6: Matrix = Array(
      Array(8, -9, -6, 5, -32, 36, 24, -20, 56, -63, -42, 35),
      Array(1, -3, -4, 7, -4, 12, 16, -28, 7, -21, -28, 49),
      Array(2, 8, -8, -3, -8, -32, 32, 12, 14, 56, -56, -21),
      Array(1, 2, -5, -1, -4, -8, 20, 4, 7, 14, -35, -7),
      Array(-16, 18, 12, -10, 24, -27, -18, 15, 24, -27, -18, 15),
      Array(-2, 6, 8, -14, 3, -9, -12, 21, 3, -9, -12, 21),
      Array(-4, -16, 16, 6, 6, 24, -24, -9, 6, 24, -24, -9),
      Array(-2, -4, 10, 2, 3, 6, -15, -3, 3, 6, -15, -3)
    )
    assert((m4 ⊗ m5).flatten.toList == m6.flatten.toList)
  }
}

-------------------------------------------------------------------------------- /src/test/scala/Schrodinger.scala: --------------------------------------------------------------------------------
package quantum.schrodinger

import math.pow
import org.scalatest.funsuite.AnyFunSuite

import quantum.schrodinger.Matrix._
import quantum.schrodinger.gate.{Gate, _}

// NOTE(review): QState is not defined in any file visible here; presumably it is provided
// elsewhere in the project (e.g. alongside Complex) -- confirm.
class SchrodingerTest extends AnyFunSuite {
  // Lengths are compared first: `zip` truncates to the shorter input, so without this
  // a result with extra or missing entries could still pass.
  def checkEq(s1: Array[Complex], s2: Array[Complex]): Boolean = {
    s1.length == s2.length && s1.zip(s2).forall { case (c1, c2) => c1 ≈ c2 }
  }
  def checkEq(s1: List[(String, Complex)], s2: List[(String, Complex)]): Boolean = {
    s1.length == s2.length && s1.zip(s2).forall { case (c1, c2) => c1._1 == c2._1 && c1._2 ≈ c2._2 }
  }

  test("EPR") {
    val s = QState(2)
    s.H(0)
    s.CNOT(0)
    assert(s.state.toList == List[Complex](Gate.isq2, 0, 0, Gate.isq2))
  }

  test("S&T") {
    val s = QState(2)
    s.H(0)
    s.CNOT(0)
    s.S(0)
    s.T(0)
    assert(checkEq(s.state, Array[Complex](Gate.isq2, 0, 0, Complex(-0.5, 0.5))))
  }

  test("Simon") {
    // Note: because CNOT only works on two adjacent wires and stores the result
    // into the second wire, we have to swap the wires before applying CNOT to
    // two non-adjacent wires.
    val s = QState(4)
    s.H(0)
    s.H(1)
    s.SWAP(0) // swap 0 and 1
    s.CNOT(1) // CNOT(1, 2)
    s.SWAP(2) // swap 2 and 3
    s.CNOT(1) // CNOT(1, 2)
    s.SWAP(0)
    s.SWAP(1)
    s.CNOT(2)
    s.SWAP(1)
    s.CNOT(1)
    s.H(0)
    s.H(1)
    assert(
      checkEq(
        s.summary,
        List[(String, Complex)](
          ("0000", 0.5),
          ("0011", 0.5),
          ("1100", 0.5),
          ("1111", -0.5)
        )
      )
    )
  }

  test("Toffoli") {
    var s = QState(3)
    s.CCNOT(0)
    assert(checkEq(s.summary, List[(String, Complex)](("000", 1.0))))

    s = QState(3)
    val b110 = Matrix.zerosVec(pow(2, 3).toInt)
    b110(6) = 1
    s.setState(b110)
    s.CCNOT(0)
    assert(checkEq(s.summary, List[(String, Complex)](("111", 1.0))))

    s = QState(3)
    val b111 = Matrix.zerosVec(pow(2, 3).toInt)
    b111(7) = 1
    s.setState(b111)
    s.CCNOT(0)
    assert(checkEq(s.summary, List[(String, Complex)](("110", 1.0))))
  }
}

--------------------------------------------------------------------------------