├── .gitignore ├── README.md ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── scalastyle-config.xml ├── scalastyle-test-config.xml └── src ├── main └── scala │ ├── VPQC │ ├── CommonParameters.scala │ ├── KeccakCore.scala │ ├── KeccakParameters.scala │ ├── ModularArithmetic.scala │ ├── NTT.scala │ ├── NTTParameters.scala │ ├── PQCCoprocessor.scala │ ├── PQCDecode.scala │ ├── PQCExu.scala │ ├── SamplerParameters.scala │ ├── Samplers.scala │ └── VectorRegister.scala │ └── utility │ ├── ShiftRegs.scala │ ├── SyncFifo.scala │ └── SyncRam.scala └── test └── scala └── VPQC ├── KeccakTest.scala ├── NTTTest.scala ├── SamplersTest.scala └── procTest.scala /.gitignore: -------------------------------------------------------------------------------- 1 | ### Project Specific stuff 2 | test_run_dir/* 3 | *.fir 4 | *.anno.json 5 | *.v 6 | ### XilinxISE template 7 | # intermediate build files 8 | *.bgn 9 | *.bit 10 | *.bld 11 | *.cmd_log 12 | *.drc 13 | *.ll 14 | *.lso 15 | *.msd 16 | *.msk 17 | *.ncd 18 | *.ngc 19 | *.ngd 20 | *.ngr 21 | *.pad 22 | *.par 23 | *.pcf 24 | *.prj 25 | *.ptwx 26 | *.rbb 27 | *.rbd 28 | *.stx 29 | *.syr 30 | *.twr 31 | *.twx 32 | *.unroutes 33 | *.ut 34 | *.xpi 35 | *.xst 36 | *_bitgen.xwbt 37 | *_envsettings.html 38 | *_map.map 39 | *_map.mrp 40 | *_map.ngm 41 | *_map.xrpt 42 | *_ngdbuild.xrpt 43 | *_pad.csv 44 | *_pad.txt 45 | *_par.xrpt 46 | *_summary.html 47 | *_summary.xml 48 | *_usage.xml 49 | *_xst.xrpt 50 | 51 | # project-wide generated files 52 | *.gise 53 | par_usage_statistics.html 54 | usage_statistics_webtalk.html 55 | webtalk.log 56 | webtalk_pn.xml 57 | 58 | # generated folders 59 | iseconfig/ 60 | xlnx_auto_0_xdb/ 61 | xst/ 62 | _ngo/ 63 | _xmsgs/ 64 | ### Eclipse template 65 | *.pydevproject 66 | .metadata 67 | .gradle 68 | bin/ 69 | tmp/ 70 | *.tmp 71 | *.bak 72 | *.swp 73 | *~.nib 74 | local.properties 75 | .settings/ 76 | .loadpath 77 | 78 | # Eclipse Core 79 | .project 80 | 81 | # External tool builders 82 | 
.externalToolBuilders/ 83 | 84 | # Locally stored "Eclipse launch configurations" 85 | *.launch 86 | 87 | # CDT-specific 88 | .cproject 89 | 90 | # JDT-specific (Eclipse Java Development Tools) 91 | .classpath 92 | 93 | # Java annotation processor (APT) 94 | .factorypath 95 | 96 | # PDT-specific 97 | .buildpath 98 | 99 | # sbteclipse plugin 100 | .target 101 | 102 | # TeXlipse plugin 103 | .texlipse 104 | ### C template 105 | # Object files 106 | *.o 107 | *.ko 108 | *.obj 109 | *.elf 110 | 111 | # Precompiled Headers 112 | *.gch 113 | *.pch 114 | 115 | # Libraries 116 | *.lib 117 | *.a 118 | *.la 119 | *.lo 120 | 121 | # Shared objects (inc. Windows DLLs) 122 | *.dll 123 | *.so 124 | *.so.* 125 | *.dylib 126 | 127 | # Executables 128 | *.exe 129 | *.out 130 | *.app 131 | *.i*86 132 | *.x86_64 133 | *.hex 134 | 135 | # Debug files 136 | *.dSYM/ 137 | ### SBT template 138 | # Simple Build Tool 139 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 140 | 141 | target/ 142 | lib_managed/ 143 | src_managed/ 144 | project/boot/ 145 | .history 146 | .cache 147 | ### Emacs template 148 | # -*- mode: gitignore; -*- 149 | *~ 150 | \#*\# 151 | /.emacs.desktop 152 | /.emacs.desktop.lock 153 | *.elc 154 | auto-save-list 155 | tramp 156 | .\#* 157 | 158 | # Org-mode 159 | .org-id-locations 160 | *_archive 161 | 162 | # flymake-mode 163 | *_flymake.* 164 | 165 | # eshell files 166 | /eshell/history 167 | /eshell/lastdir 168 | 169 | # elpa packages 170 | /elpa/ 171 | 172 | # reftex files 173 | *.rel 174 | 175 | # AUCTeX auto folder 176 | /auto/ 177 | 178 | # cask packages 179 | .cask/ 180 | ### Vim template 181 | [._]*.s[a-w][a-z] 182 | [._]s[a-w][a-z] 183 | *.un~ 184 | Session.vim 185 | .netrwhist 186 | *~ 187 | ### JetBrains template 188 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio 189 | 190 | *.iml 191 | 192 | ## Directory-based project format: 193 | .idea/ 194 | # if you 
remove the above rule, at least ignore the following: 195 | 196 | # User-specific stuff: 197 | # .idea/workspace.xml 198 | # .idea/tasks.xml 199 | # .idea/dictionaries 200 | 201 | # Sensitive or high-churn files: 202 | # .idea/dataSources.ids 203 | # .idea/dataSources.xml 204 | # .idea/sqlDataSources.xml 205 | # .idea/dynamic.xml 206 | # .idea/uiDesigner.xml 207 | 208 | # Gradle: 209 | # .idea/gradle.xml 210 | # .idea/libraries 211 | 212 | # Mongo Explorer plugin: 213 | # .idea/mongoSettings.xml 214 | 215 | ## File-based project format: 216 | *.ipr 217 | *.iws 218 | 219 | ## Plugin-specific files: 220 | 221 | # IntelliJ 222 | /out/ 223 | 224 | # mpeltonen/sbt-idea plugin 225 | .idea_modules/ 226 | 227 | # JIRA plugin 228 | atlassian-ide-plugin.xml 229 | 230 | # Crashlytics plugin (for Android Studio and IntelliJ) 231 | com_crashlytics_export_strings.xml 232 | crashlytics.properties 233 | crashlytics-build.properties 234 | ### C++ template 235 | # Compiled Object files 236 | *.slo 237 | *.lo 238 | *.o 239 | *.obj 240 | 241 | # Precompiled Headers 242 | *.gch 243 | *.pch 244 | 245 | # Compiled Dynamic libraries 246 | *.so 247 | *.dylib 248 | *.dll 249 | 250 | # Fortran module files 251 | *.mod 252 | 253 | # Compiled Static libraries 254 | *.lai 255 | *.la 256 | *.a 257 | *.lib 258 | 259 | # Executables 260 | *.exe 261 | *.out 262 | *.app 263 | ### OSX template 264 | .DS_Store 265 | .AppleDouble 266 | .LSOverride 267 | 268 | # Icon must end with two \r 269 | Icon 270 | 271 | # Thumbnails 272 | ._* 273 | 274 | # Files that might appear in the root of a volume 275 | .DocumentRevisions-V100 276 | .fseventsd 277 | .Spotlight-V100 278 | .TemporaryItems 279 | .Trashes 280 | .VolumeIcon.icns 281 | 282 | # Directories potentially created on remote AFP share 283 | .AppleDB 284 | .AppleDesktop 285 | Network Trash Folder 286 | Temporary Items 287 | .apdisk 288 | ### Xcode template 289 | # Xcode 290 | # 291 | # gitignore contributors: remember to update Global/Xcode.gitignore, 
Objective-C.gitignore & Swift.gitignore 292 | 293 | ## Build generated 294 | build/ 295 | DerivedData 296 | 297 | ## Various settings 298 | *.pbxuser 299 | !default.pbxuser 300 | *.mode1v3 301 | !default.mode1v3 302 | *.mode2v3 303 | !default.mode2v3 304 | *.perspectivev3 305 | !default.perspectivev3 306 | xcuserdata 307 | 308 | ## Other 309 | *.xccheckout 310 | *.moved-aside 311 | *.xcuserstate 312 | ### Scala template 313 | *.class 314 | *.log 315 | 316 | # sbt specific 317 | .cache 318 | .history 319 | .lib/ 320 | dist/* 321 | target/ 322 | lib_managed/ 323 | src_managed/ 324 | project/boot/ 325 | project/plugins/project/ 326 | 327 | # Scala-IDE specific 328 | .scala_dependencies 329 | .worksheet 330 | ### Java template 331 | *.class 332 | 333 | # Mobile Tools for Java (J2ME) 334 | .mtj.tmp/ 335 | 336 | # Package Files # 337 | *.jar 338 | *.war 339 | *.ear 340 | 341 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 342 | hs_err_pid* 343 | 344 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### VPQC 2 | This work presents a vector processor for Ring-LWE and Module-LWE schemes in post-quantum cryptography; 3 | it is implemented in Chisel, an agile hardware construction language developed at UC Berkeley. 4 | The instruction set of the processor is a custom extension of RISC-V, so the processor can 5 | serve as a coprocessor for RISC-V cores. The custom instructions are defined as follows.
6 | 7 | | **Instr \ Bit field** | **31:25** | **24:20** | **19:15** | **14:12** | **11:7** | **6:0** | **Description** | 8 | | ----------------------- | :-----: | :-----: | :-----: | :-----: | :----: | :-----: | :-------------------------------------: | 9 | | fetchData | 0000000 | - | - | - | vd | 0001011 | v[vd] <- random data from prefetch FIFO | 10 | | binomial sampling | 0000001 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector binomial sample (v[vs1],v[vs2]) | 11 | | rejection sampling | 0000010 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector rejection sample (v[vs1],v[vs2]) | 12 | | DIT butterfly | 0000011 | vs2 | vs1 | - | - | 0001011 | (v[vs1], v[vs2]) <- vector DIT butterfly (v[vs1], v[vs2]) | 13 | | DIF butterfly | 0000100 | vs2 | vs1 | - | - | 0001011 | (v[vs1], v[vs2]) <- vector DIF butterfly (v[vs1], v[vs2]) | 14 | | csrrw | 0000101 | csridx | vs1 | - | - | 0001011 | swap value (csr[csridx], r[vs1]) | 15 | | csrrwi | 0000110 | csridx | imm[9:5] | - | imm[4:0] | 0001011 | csr[csridx] <- imm | 16 | | vld | 0000111 | - | - | - | vd | 0001011 | v[vd] <- memory (addr) | 17 | | vst | 0001000 | - | vs1 | - | - | 0001011 | memory (addr) <- v[vs1] | 18 | | vadd | 0001001 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector addition (v[vs1], v[vs2]) | 19 | | vsub | 0001010 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector subtraction (v[vs1], v[vs2]) | 20 | | vmul | 0001011 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector multiplication (v[vs1], v[vs2]) | 21 | 22 | This project follows the standard directory layout of *sbt* (the Scala Build Tool); see 23 | https://www.scala-sbt.org/ for more information about *sbt*. 24 | 25 | The source code is located in */src/main/scala/* and the test code is located in */src/test/scala/VPQC/*. 26 | 27 | #### Performance of NTT and sampling 28 | [IntelliJ IDEA](https://www.jetbrains.com/idea/) is a powerful IDE for Scala, and it is quite convenient to use IntelliJ IDEA to build this project.
29 | 30 | Run the object *TestTopTestSimple* in */src/test/scala/VPQC/procTest*, and the result shows: 31 | 32 | When the dimension = 256, binomial/rejection sampling can be finished in 411 cycles and 33 | NTT can be finished in 45 cycles. 34 | 35 | Or you could run the object *TestTopTestMain* to obtain more information about the waveform generated by Verilator. 36 | #### For more information about Chisel, see below: 37 | https://www.chisel-lang.org/ 38 | #### For project template for Chisel, please refer to: 39 | https://github.com/freechipsproject/chisel-template 40 | #### Our source code also used a third-party chisel library for Finite State Machine, please refer to: 41 | https://github.com/dai-pch/FSM 42 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // See README.md for license details. 2 | // Extra scalac flags: none for Scala 2.11 and older, -Xsource:2.11 for any other version 3 | def scalacOptionsVersion(scalaVersion: String): Seq[String] = { 4 | Seq() ++ { 5 | // If we're building with Scala > 2.11, enable the compile option 6 | // switch to support our anonymous Bundle definitions: 7 | // https://github.com/scala/bug/issues/10047 8 | CrossVersion.partialVersion(scalaVersion) match { 9 | case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq() 10 | case _ => Seq("-Xsource:2.11") 11 | } 12 | } 13 | } 14 | // javac -source/-target: 1.7 for Scala 2.11 and older, 1.8 for any other version 15 | def javacOptionsVersion(scalaVersion: String): Seq[String] = { 16 | Seq() ++ { 17 | // Scala 2.12 requires Java 8. We continue to generate 18 | // Java 7 compatible code for Scala 2.11 19 | // for compatibility with old clients.
20 | CrossVersion.partialVersion(scalaVersion) match { 21 | case Some((2, scalaMajor: Long)) if scalaMajor < 12 => 22 | Seq("-source", "1.7", "-target", "1.7") 23 | case _ => 24 | Seq("-source", "1.8", "-target", "1.8") 25 | } 26 | } 27 | } 28 | 29 | name := "chisel-module-template" // NOTE(review): looks like the chisel-template default name; consider renaming to the project name 30 | 31 | version := "3.1.0" 32 | 33 | scalaVersion := "2.11.12" 34 | 35 | crossScalaVersions := Seq("2.11.12", "2.12.4") 36 | 37 | resolvers ++= Seq( 38 | Resolver.sonatypeRepo("snapshots"), 39 | Resolver.sonatypeRepo("releases") 40 | ) 41 | 42 | // Provide a managed dependency on X if -DXVersion="" is supplied on the command line. 43 | val defaultVersions = Map( 44 | "chisel3" -> "3.1.+", 45 | "chisel-iotesters" -> "1.2.+" 46 | ) 47 | 48 | libraryDependencies ++= (Seq("chisel3","chisel-iotesters").map { 49 | dep: String => "edu.berkeley.cs" %% dep % sys.props.getOrElse(dep + "Version", defaultVersions(dep)) }) 50 | 51 | libraryDependencies += "org.daipch" %% "fsm" % "0.3.+" // presumably provides the fsm package imported by the sources (README links dai-pch/FSM) 52 | 53 | scalacOptions ++= scalacOptionsVersion(scalaVersion.value) 54 | 55 | javacOptions ++= javacOptionsVersion(scalaVersion.value) 56 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.3.2 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 2 | Scalastyle standard configuration 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 
51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | No lines ending with a ; 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | |\|\||&&|:=|<>|<=|>=|!=|===|<<|>>|##|unary_(~|\-%?|!))$]]> 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /scalastyle-test-config.xml: -------------------------------------------------------------------------------- 1 | 2 | Scalastyle configuration for Chisel3 unit tests 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | No lines ending with a ; 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | |\|\||&&|:=|<>|<=|>=|!=|===|<<|>>|##|unary_(~|\-%?|!))$]]> 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /src/main/scala/VPQC/CommonParameters.scala: -------------------------------------------------------------------------------- 1 | 2 | package VPQC 3 | 4 | trait HasCommonParameters { 5 | val ML = 32 // number of lanes: vector operands are Vec(ML, UInt(DataWidth.W)) 6 | val DataWidth = 16 // bit width of each vector element 7 | } 8 | // disabled CSR parameter trait, kept commented out for reference 9 | //trait HasPQCCSR{ 10 | // val CSRWIDTH : Int = 64 11 | // val CSRLENGTH : Int = 6 // change the number of CSR registers here 12 | // 13 | // // list each CSR's name and index below; names must start with CSR_ 14 | // val csrBarretu = Input(UInt((DataWidth + 2).W)) 15 | // val csrBound = Input(UInt((DataWidth * 2).W)) 16 | // val csrBinomialk = Input(UInt(3.W)) 17 | // val csrModulusq = Input(UInt(DataWidth.W)) 18 | // val csrModulusLen = Input(UInt(5.W)) 19 | //} 20 | 
21 | trait HasPQCInstructions { 22 | val INSTR_QUANTITY = 12 // number of custom instructions 23 | 24 | /* instruction encoding style */ 25 | // funct7, rs2, rs1, reserved, rd, opcode 0001011 (RISC-V custom-0); funct7 carries one of the INSTR_* codes below 26 | 27 | // PQC 28 | val INSTR_FETCHRN = 0 29 | val INSTR_SAMPLEBINOMIAL = 1 30 | val INSTR_SAMPLEREJECTION = 2 31 | val INSTR_BUTTERFLY = 3 32 | val INSTR_IBUTTERFLY = 4 33 | val INSTR_CSRRW = 5 34 | val INSTR_CSRRWI = 6 35 | val INSTR_VLD = 7 36 | val INSTR_VST = 8 37 | 38 | val INSTR_VADD = 9 39 | val INSTR_VSUB = 10 40 | val INSTR_VMUL = 11 41 | // alternative load/store numbering (commented out, not used) 42 | // // LOAD STORE 43 | // val INSTR_VLD = 7 44 | // val INSTR_VLDS = 8 45 | // val INSTR_VLDX = 9 46 | // val INSTR_VST = 10 47 | // val INSTR_VSTS = 11 48 | // val INSTR_VSTX = 12 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/VPQC/KeccakCore.scala: -------------------------------------------------------------------------------- 1 | 2 | package VPQC 3 | 4 | import chisel3._ 5 | import chisel3.util._ 6 | import utility.SyncFifo 7 | import fsm._ 8 | 9 | /** 10 | * 11 | * Description : implements the Keccak-p[1600,24] permutation, one round per clock cycle, 12 | * and wrappers that buffer its output as pseudo-random numbers; see FIPS PUB 202 13 | * 14 | **/ 15 | 16 | /** 17 | * 18 | * state index = (y)(x)(z): s(y)(x) is the 64-bit lane in row y, column x; z is the bit within the lane 19 | * 20 | **/ 21 | class stateArray extends Bundle { 22 | val s = Vec(5, Vec(5, (UInt(64.W)))) 23 | } 24 | 25 | 26 | class ProcIO extends Bundle 27 | with HasKeccakParameters { 28 | val in = Input(new stateArray()) 29 | val out = Output(new stateArray()) 30 | } 31 | 32 | class Proc1 extends Module 33 | with HasKeccakParameters { // theta step 34 | val io = IO(new ProcIO) 35 | 36 | val state = io.in.s 37 | val columnsXorSqueez = WireInit(VecInit(Seq.fill(5)(VecInit(Seq.fill(64)(false.B))))) 38 | // step 1: C[x][z] = XOR over all rows y of A[y][x][z] 39 | for (x <- 0 until 5) { 40 | for (z <- 0 until 64) { 41 | columnsXorSqueez(x)(z) := (0 until 5).map(y => state(y)(x)(z)).reduce(_ ^ _) 42 | } 43 | } 44 | val columnsXorInter = WireInit(VecInit(Seq.fill(5)(VecInit(Seq.fill(64)(false.B))))) 45 | // 
step 2: D[x][z] = C[x-1][z] ^ C[x+1][z-1], x indices mod 5 and z indices mod 64 46 | for (x <- 0 until 5) { 47 | for (z <- 0 until 64) { 48 | columnsXorInter(x)(z) := columnsXorSqueez(mod(x - 1, 5))(z) ^ columnsXorSqueez(mod(x + 1, 5))(mod(z - 1, 64)) 49 | } 50 | } 51 | // step 3: A[y][x] ^= D[x] for every lane 52 | for (y <- 0 until 5) { 53 | for (x <- 0 until 5) { 54 | io.out.s(y)(x) := state(y)(x) ^ columnsXorInter(x).asUInt() 55 | } 56 | } 57 | } 58 | 59 | class Proc2 extends Module 60 | with HasKeccakParameters { // rho step: rotate lane (y)(x) left by proc2param(y)(x) mod 64 61 | val io = IO(new ProcIO()) 62 | 63 | val state = io.in.s 64 | for (y <- 0 until 5) { 65 | for (x <- 0 until 5) { 66 | val shamt = mod(proc2param(y)(x), 64) 67 | if (shamt == 0) { // a zero rotation would need an empty bit slice, so pass the lane through 68 | io.out.s(y)(x) := state(y)(x) 69 | } else { 70 | io.out.s(y)(x) := Cat(state(y)(x)(63 - shamt, 0), state(y)(x)(63, 64 - shamt)) 71 | } 72 | } 73 | } 74 | } 75 | 76 | class Proc3 extends Module 77 | with HasKeccakParameters { // pi step: A'[y][x] = A[x][(x + 3y) mod 5] 78 | val io = IO(new ProcIO()) 79 | 80 | val state = io.in.s 81 | for (y <- 0 until 5) { 82 | for (x <- 0 until 5) { 83 | io.out.s(y)(x) := state(x)(mod(x + 3 * y, 5)) 84 | } 85 | } 86 | } 87 | 88 | class Proc4 extends Module 89 | with HasKeccakParameters { // chi step: A'[y][x] = A[y][x] ^ (~A[y][x+1] & A[y][x+2]) 90 | val io = IO(new ProcIO()) 91 | 92 | val state = io.in.s 93 | for (y <- 0 until 5) { 94 | for (x <- 0 until 5) { 95 | io.out.s(y)(x) := (~state(y)(mod(x + 1, 5)) & state(y)(mod(x + 2, 5))) ^ state(y)(x) 96 | } 97 | } 98 | } 99 | 100 | class Proc5(ir: Int) extends Module 101 | with HasKeccakParameters { // iota step for round index ir: XOR the round constant into lane (0)(0) 102 | val io = IO(new ProcIO()) 103 | 104 | val state = io.in.s 105 | val RC = WireInit(VecInit(Seq.fill(64)(false.B))) 106 | for (j <- 0 to 6) { // only bits 2^j - 1 of RC can be non-zero 107 | RC((1 << j) - 1) := rc(j + ir * 7).asUInt() 108 | } 109 | val nextState = Wire(new stateArray()) 110 | nextState.s := state 111 | nextState.s(0)(0) := state(0)(0) ^ RC.asUInt() 112 | io.out.s := nextState.s 113 | } 114 | 115 | class KeccakCoreIO extends Bundle 116 | with HasKeccakParameters { // see the KeccakCore comment below for the timing of these signals 117 | val valid = Input(Bool()) 118 | val seed = Input(new stateArray()) 119 | val seedWrite = Input(Bool()) 120 | val prngNumber = Output(new
stateArray()) 121 | val done = Output(Bool()) 122 | val initialized = Output(Bool()) 123 | } 124 | 125 | /** 126 | * Iterative Keccak-p[1600,24]: one round per clock cycle, round index selected by cnt. 127 | * cnt == 0 : idle; io.valid starts a run (cnt becomes 1 on the next cycle) 128 | * cnt == i (1 to 24) : round i - 1 (Proc5(i - 1)) is applied to stateReg 129 | * cnt == 24 : completeFlag is set and cnt wraps back to 0 130 | * io.done : completeFlag delayed by one cycle, when stateReg holds the result 131 | * io.seedWrite : loads io.seed into stateReg and sets initialized (priority over a round) 132 | * Each run permutes the current stateReg contents, so successive runs keep 133 | * permuting the previous output; io.prngNumber always shows stateReg. 134 | * 135 | */ 136 | class KeccakCore extends Module 137 | with HasKeccakParameters { 138 | val io = IO(new KeccakCoreIO()) 139 | 140 | // count control 141 | val cnt = RegInit(0.U(5.W)) 142 | val busy = cnt =/= 0.U 143 | val countBegin = io.valid || busy 144 | val state = Wire(Vec(5, new stateArray())) 145 | val stateReg = Reg((new stateArray())) 146 | val completeFlag = cnt === RoundsNum.asUInt() 147 | val initialized = RegInit(false.B) 148 | 149 | when(countBegin) { 150 | cnt := Mux(completeFlag, 0.U, cnt + 1.U) 151 | } 152 | 153 | when(io.seedWrite) { 154 | stateReg := io.seed 155 | initialized := true.B 156 | } .elsewhen(cnt =/= 0.U) { 157 | stateReg := state(4) 158 | } 159 | 160 | // module instantiation: theta -> rho -> pi -> chi -> iota (round cnt - 1) 161 | val proc1 = Module(new Proc1()) 162 | proc1.io.in := Mux(cnt === 0.U, 0.U.asTypeOf(new stateArray()), stateReg) // idle: feed zeros; the result is not latched 163 | state(0) := proc1.io.out 164 | val proc2 = Module(new Proc2()) 165 | proc2.io.in := state(0) 166 | state(1) := proc2.io.out 167 | val proc3 = Module(new Proc3()) 168 | proc3.io.in := state(1) 169 | state(2) := proc3.io.out 170 | val proc4 = Module(new Proc4()) 171 | proc4.io.in := state(2) 172 | state(3) := proc4.io.out 173 | val proc5s = (0 until RoundsNum).map(i => Module(new Proc5(i)).io) // one iota stage per round; cnt selects the active one 174 | for (i <- 0 until RoundsNum) { 175 | proc5s(i).in := state(3) 176 | } 177 | state(4) := MuxLookup(cnt, proc5s(0).out, (2 to RoundsNum).map(i => (i.asUInt() -> proc5s(i-1).out))) // cnt 1 (and idle 0) use the default, proc5s(0) 178 | 179 | io.prngNumber := stateReg 180 | io.done := RegNext(completeFlag) 181 | io.initialized := initialized 182 | } 183 | // KeccakCore plus an internal 32-entry SyncFifo; io.prn exposes one FIFO word as ML words of DataWidth bits 184 | class KeccakWithFifo extends Module 185 | with HasCommonParameters 186 | with HasKeccakParameters { 187 | val io = IO(new Bundle { 188 | val valid = Input(Bool()) 189 | val seed =
Input(new stateArray()) 190 | val seedWrite = Input(Bool()) 191 | val prn = Output(Vec(ML, UInt(DataWidth.W))) 192 | val done = Output(Bool()) 193 | val busy = Output(Bool()) 194 | val wb = Output(Bool()) 195 | }) 196 | 197 | val prng = Module(new KeccakCore) 198 | prng.io.seed := io.seed 199 | prng.io.seedWrite := io.seedWrite 200 | val fifo = Module(new SyncFifo(dep = 32, dataType = UInt((ML * DataWidth).W))) 201 | 202 | // fill the buffer if it is not full 203 | prng.io.valid := !fifo.io.writeFull && prng.io.initialized 204 | // write prn to buffer 205 | fifo.io.writeData := prng.io.prngNumber.s.asUInt() // NOTE(review): presumably writeData has the FIFO dataType width (ML * DataWidth bits), so only the low bits of the 1600-bit state are stored; confirm 206 | fifo.io.writeEnable := prng.io.done 207 | fifo.io.readEnable := io.valid 208 | 209 | // get random number: word i is bits [DataWidth * i + DataWidth - 1, DataWidth * i] of the FIFO output 210 | for (i <- 0 until ML) { 211 | io.prn(i) := fifo.io.readData(DataWidth * i + DataWidth - 1, DataWidth * i) 212 | } 213 | io.wb := true.B 214 | io.done := false.B // default; the FSM state actions below assign it 215 | io.busy := false.B 216 | 217 | val fsm = InstanciateFSM(new FSM{ // on io.valid, Idle goes to Wait (FIFO empty) or Read; done is driven high in Read, busy in Wait 218 | entryState(stateName = "Idle") 219 | .act{ 220 | io.done := false.B 221 | io.busy := false.B 222 | } 223 | .when(io.valid && fifo.io.readEmpty).transferTo(destName = "Wait") 224 | .when(io.valid && !fifo.io.readEmpty).transferTo(destName = "Read") 225 | 226 | state(stateName = "Wait") 227 | .act{ 228 | io.done := false.B 229 | io.busy := true.B 230 | } 231 | .when(!fifo.io.readEmpty).transferTo(destName = "Read") 232 | 233 | state(stateName = "Read") 234 | .act{ 235 | io.done := true.B 236 | io.busy := false.B 237 | } 238 | .when(io.valid && fifo.io.readEmpty).transferTo(destName = "Wait") 239 | .when(io.valid && !fifo.io.readEmpty).transferTo(destName = "Read") 240 | .otherwise.transferToEnd 241 | }) 242 | 243 | } 244 | 245 | // for synthesis: same logic as KeccakWithFifo, but the FIFO is external and reached through io 246 | class KeccakNoFifo extends Module 247 | with HasCommonParameters 248 | with HasKeccakParameters { 249 | val io = IO(new Bundle { 250 | val valid = Input(Bool()) 251 | val seed = Input(new stateArray()) 252 | val seedWrite = Input(Bool()) 253 | val prn = Output(Vec(ML,
UInt(DataWidth.W))) 254 | val done = Output(Bool()) 255 | val busy = Output(Bool()) 256 | val wb = Output(Bool()) 257 | // from/to syncfifo 258 | val writeData = Output(UInt((ML * DataWidth).W)) 259 | val writeEnable = Output(Bool()) 260 | val readEnable = Output(Bool()) 261 | val readData = Input(UInt((ML * DataWidth).W)) 262 | val readEmpty = Input(Bool()) 263 | val writeFull = Input(Bool()) 264 | }) 265 | val prng = Module(new KeccakCore) 266 | prng.io.seed := io.seed 267 | prng.io.seedWrite := io.seedWrite 268 | prng.io.valid := !io.writeFull && prng.io.initialized 269 | 270 | // write prn to buffer 271 | io.writeData := prng.io.prngNumber.s.asUInt() // writeData is ML * DataWidth bits wide, so the connect keeps only the low bits of the 1600-bit state 272 | io.writeEnable := prng.io.done 273 | io.readEnable := io.valid 274 | 275 | // get random number 276 | for (i <- 0 until ML) { 277 | io.prn(i) := io.readData(DataWidth * i + DataWidth - 1, DataWidth * i) 278 | } 279 | io.wb := true.B 280 | io.done := false.B 281 | io.busy := false.B 282 | 283 | val fsm = InstanciateFSM(new FSM{ 284 | entryState(stateName = "Idle") 285 | .act{ 286 | io.done := false.B 287 | io.busy := false.B 288 | } 289 | .when(io.valid && io.readEmpty).transferTo(destName = "Wait") 290 | .when(io.valid && !io.readEmpty).transferTo(destName = "Read") 291 | 292 | state(stateName = "Wait") 293 | .act{ 294 | io.done := false.B 295 | io.busy := true.B 296 | } 297 | .when(!io.readEmpty).transferTo(destName = "Read") 298 | 299 | state(stateName = "Read") 300 | .act{ 301 | io.done := true.B 302 | io.busy := false.B 303 | } 304 | .when(io.valid && io.readEmpty).transferTo(destName = "Wait") 305 | .when(io.valid && !io.readEmpty).transferTo(destName = "Read") 306 | .otherwise.transferToEnd 307 | }) 308 | 309 | // val s = prng.io.prngNumber.s.asUInt() 310 | // for (i <- 0 until ML) { 311 | // io.prn(i) := s(DataWidth * i + DataWidth - 1, DataWidth * i) 312 | // } 313 | // io.wb := true.B 314 | // io.done := RegNext(io.valid) 315 | // io.busy := false.B 316 | } 317 | 318 | object elaborateKeccak extends
App { 319 | chisel3.Driver.execute(args, () => new KeccakNoFifo) 320 | } 321 | -------------------------------------------------------------------------------- /src/main/scala/VPQC/KeccakParameters.scala: -------------------------------------------------------------------------------- 1 | 2 | package VPQC 3 | 4 | import chisel3._ 5 | 6 | trait HasKeccakParameters { 7 | val ArrayLength = 1600 // state width in bits (5 x 5 lanes of 64 bits) 8 | val RoundsNum = 24 // number of rounds of Keccak-p[1600,24] 9 | // mod(a, q): remainder of a modulo q in [0, q), valid for any a (including a < -q) when q > 0 10 | def mod(a: Int, q: Int): Int = { 11 | // Scala's % keeps the sign of a, so a negative remainder is shifted up by q 12 | val r = a % q 13 | if (r < 0) { 14 | r + q 15 | } else { 16 | r 17 | } 18 | } 19 | def range(a: Int, upBound: Int, downBound: Int) : Int = { // bits [upBound:downBound] of a, right-aligned 20 | assert(upBound < 32) 21 | assert(downBound >= 0) 22 | (a >> downBound) & (0xffffffff >>> (31-upBound+downBound)) 23 | } 24 | // rho rotation offsets indexed (y)(x); Proc2 reduces them mod 64 25 | val proc2param = Array( Array(0, 1, 190, 28, 91), 26 | Array(36, 300, 6, 55, 276), 27 | Array(3, 10, 171, 153, 231), 28 | Array(105, 45, 15, 21, 136), 29 | Array(210, 66, 253, 120, 78) 30 | ) 31 | def rc(t: Int): Int = { // round-constant bit rc(t): the 8-bit LFSR of FIPS 202 Algorithm 5 32 | var res = 0 33 | var R: Int = 1 34 | var R8: Int = 0 35 | if (t % 255 == 0) { 36 | res = 1 37 | } 38 | else { 39 | for (_ <- 1 to t % 255) { 40 | R = R << 1 41 | R8 = range(R, 8, 8) 42 | R = R ^ (R8 | (R8 << 4) | (R8 << 5) | (R8 << 6)) 43 | R = R & 0xff 44 | } 45 | res = R & 1 46 | } 47 | res 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/scala/VPQC/ModularArithmetic.scala: -------------------------------------------------------------------------------- 1 | 2 | package VPQC 3 | import chisel3._ 4 | import chisel3.util._ 5 | 6 | class ModIO extends Bundle 7 | with HasCommonParameters { 8 | // two operands 9 | val a = Input(UInt(DataWidth.W)) 10 | val b = Input(UInt(DataWidth.W)) 11 | // modulus 12 | val m = Input(UInt(DataWidth.W)) 13 | // reduced value 14 | val c = Output(UInt(DataWidth.W)) 15 | } 16 | 17 | class ModMulIO extends ModIO { 18 | // precomputed Barrett constant u (see ModMul) 19 | val u = Input(UInt((DataWidth + 2).W)) 20 | val n =
Input(UInt(5.W)) // presumably the bit length of m (VectorArith drives it from csrModulusLen) 21 | } 22 | 23 | class ModAdd extends Module 24 | with HasCommonParameters { // c = (a + b) mod m, assuming a, b < m 25 | val io = IO(new ModIO) 26 | 27 | val ctmp1 = Wire(UInt((DataWidth + 1).W)) 28 | ctmp1 := io.a +& io.b 29 | val ctmp2 = Wire(UInt((DataWidth + 1).W)) 30 | ctmp2 := ctmp1 -& io.m 31 | val flag = ctmp2(DataWidth) // borrow: set when a + b < m, so a + b is already reduced 32 | io.c := Mux(flag, ctmp1(DataWidth - 1, 0), ctmp2(DataWidth - 1, 0)) 33 | } 34 | 35 | object ModAdd { 36 | def apply(a: UInt, b: UInt, m: UInt): UInt = { 37 | val inst = Module(new ModAdd()) 38 | inst.io.a := a 39 | inst.io.b := b 40 | inst.io.m := m 41 | inst.io.c 42 | } 43 | } 44 | 45 | class ModSub extends Module 46 | with HasCommonParameters { // c = (a - b) mod m, assuming a, b < m 47 | val io = IO(new ModIO) 48 | 49 | val ctmp1 = Wire(UInt((DataWidth + 1).W)) 50 | ctmp1 := io.a -& io.b 51 | val ctmp2 = Wire(UInt((DataWidth + 1).W)) 52 | ctmp2 := ctmp1 +& io.m 53 | val flag = ctmp1(DataWidth) // borrow: set when a < b, so m is added back 54 | io.c := Mux(flag, ctmp2(DataWidth - 1, 0), ctmp1(DataWidth - 1, 0)) 55 | } 56 | 57 | object ModSub { 58 | def apply(a: UInt, b: UInt, m: UInt): UInt = { 59 | val inst = Module(new ModSub()) 60 | inst.io.a := a 61 | inst.io.b := b 62 | inst.io.m := m 63 | inst.io.c 64 | } 65 | } 66 | 67 | class ModMul extends Module 68 | with HasCommonParameters { // c = (a * b) mod m using a Barrett-style reduction 69 | val io = IO(new ModMulIO) 70 | 71 | val a = io.a 72 | val b = io.b 73 | val m = io.m 74 | val u = io.u 75 | val n = io.n 76 | 77 | // process 1: full product c = a * b, pre-shifted right by n - 2 78 | val c = a * b 79 | val shift = c >> (n - 2.U) 80 | 81 | // process 2: multiply by the precomputed constant u 82 | val mul1 = u * shift.asUInt() 83 | 84 | // process 3: quotient estimate; total right shift is 2n + 1 (presumably u = floor(2^(2n+1) / m), confirm) 85 | val qGuess = mul1 >> (n + 3.U) 86 | val mul2 = io.m * qGuess.asUInt() 87 | 88 | // process 4: remainder with a single conditional subtraction (assumes z < 2m) 89 | val z = c - mul2 90 | io.c := Mux(z < io.m, z, z - io.m) 91 | } 92 | 93 | object ModMul { 94 | def apply(a: UInt, b: UInt, m: UInt, u: UInt, n: UInt): UInt = { 95 | val inst = Module(new ModMul()) 96 | inst.io.a := a 97 | inst.io.b := b 98 | inst.io.m := m 99 | inst.io.u := u 100 | inst.io.n := n 101 | inst.io.c 102 | } 103 | } 104 | 105 | // Note: This Module can be eliminated by
sharing the 106 | // hardware resources in NTT Module 107 | class VectorArithIO extends Bundle 108 | with HasCommonParameters { 109 | val valid = Input(Bool()) 110 | // config 111 | // val m = Input(UInt(DataWidth.W)) 112 | // val u = Input(UInt((DataWidth + 2).W)) 113 | // val n = Input(UInt(5.W)) 114 | val csrs = Input(new CSRIO) 115 | // input 116 | val addA = Input(Vec(ML, UInt(DataWidth.W))) 117 | val addB = Input(Vec(ML, UInt(DataWidth.W))) 118 | val subA = Input(Vec(ML, UInt(DataWidth.W))) 119 | val subB = Input(Vec(ML, UInt(DataWidth.W))) 120 | val mulA = Input(Vec(ML, UInt(DataWidth.W))) 121 | val mulB = Input(Vec(ML, UInt(DataWidth.W))) 122 | // output 123 | val addRes = Output(Vec(ML, UInt(DataWidth.W))) 124 | val subRes = Output(Vec(ML, UInt(DataWidth.W))) 125 | val mulRes = Output(Vec(ML, UInt(DataWidth.W))) 126 | 127 | val done = Output(Bool()) 128 | val busy = Output(Bool()) 129 | val wb = Output(Bool()) 130 | } 131 | 132 | class VectorArith extends Module 133 | with HasCommonParameters { // ML lanes of ModAdd/ModSub/ModMul sharing the CSR modulus settings 134 | val io = IO(new VectorArithIO()) 135 | 136 | 137 | val adds = VecInit(Seq.fill(ML)(Module(new ModAdd()).io)) 138 | val subs = VecInit(Seq.fill(ML)(Module(new ModSub()).io)) 139 | val muls = VecInit(Seq.fill(ML)(Module(new ModMul()).io)) 140 | for (i <- 0 until ML) { 141 | adds(i).a := io.addA(i) 142 | adds(i).b := io.addB(i) 143 | adds(i).m := io.csrs.csrModulusq 144 | io.addRes(i) := adds(i).c 145 | subs(i).a := io.subA(i) 146 | subs(i).b := io.subB(i) 147 | subs(i).m := io.csrs.csrModulusq 148 | io.subRes(i) := subs(i).c 149 | muls(i).a := io.mulA(i) 150 | muls(i).b := io.mulB(i) 151 | muls(i).m := io.csrs.csrModulusq 152 | muls(i).u := io.csrs.csrBarretu 153 | muls(i).n := io.csrs.csrModulusLen 154 | io.mulRes(i) := muls(i).c 155 | } 156 | 157 | io.done := RegNext(io.valid) // results are combinational; done follows valid by one cycle 158 | io.busy := false.B 159 | io.wb := true.B 160 | } 161 | 162 | object elaborateMod extends App { 163 | chisel3.Driver.execute(args, () => new VectorArith) 164 | }
--------------------------------------------------------------------------------
/src/main/scala/VPQC/NTT.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3._
import chisel3.util._
import scala.math._
import utility._
import fsm._

/*
class MRFixIO extends Bundle
  with HasNTTCommonParameters {
  val a = Input(UInt((DataWidth * 2).W))
  val ar = Output(UInt(DataWidth.W))
}

class MRFix extends Module
  with HasNTTCommonParameters
  with HasMRParameters {
  val io = IO(new MRFixIO())

  // 16bit
  val tmp1 = io.a >> 12.U
  // 20bit
  val tmp2 = (tmp1 << 1.U) +& (tmp1 << 3.U)
  // 32bit
  val mul = Wire(UInt(32.W))
  mul := (tmp2 << 12.U) + (tmp2 << 8.U) + (tmp2 << 4.U) + (tmp1 << 3.U) - tmp1
  // 15bit
  val qGuess = mul >> 17.U
  // 28bit
  val qM = Wire(UInt((DataWidth * 2).W))
  qM := qGuess + (qGuess << 12.U) + (qGuess << 13.U)

  val z = io.a - qM
  io.ar := Mux(z < MRq.asUInt(), z, z-MRq.asUInt())
}

object MRFix {
  def apply(a: UInt, ar: UInt): Module = {
    val inst = Module(new MRFix())
    inst.io.a := a
    ar := inst.io.ar
    inst
  }
  def apply(a: UInt): UInt = {
    val inst = Module(new MRFix())
    inst.io.a := a
    inst.io.ar
  }
}
*/

// add configurability

/**
 * Ports of [[MR]], the run-time configurable modular reducer.
 *  - a  : value to reduce, 2 * DataWidth bits wide
 *  - n  : bit length of the modulus
 *  - m  : the modulus
 *  - u  : precomputed reduction constant for m, DataWidth + 2 bits wide
 *  - ar : a reduced modulo m, DataWidth bits wide
 */
class MRIO extends Bundle
  with HasCommonParameters {
  val a = Input(UInt((DataWidth * 2).W))
  val n = Input(UInt(5.W))
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((DataWidth + 2).W))
  val ar = Output(UInt(DataWidth.W))
}

/**
 * Barrett-style reduction of `a` modulo `m`:
 *   qGuess = ((a >> (n - 2)) * u) >> (n + 3)
 *   z      = a - qGuess * m
 *   ar     = if (z < m) z else z - m
 * Only one conditional subtraction follows, so the result is fully reduced
 * only when the quotient estimate is at most one too small.
 * NOTE(review): that bound depends on `u` being derived for these exact
 * shift amounts (n - 2 and n + 3); confirm against the CSR setup code.
 */
class MR extends Module
  with HasCommonParameters {
  val io = IO(new MRIO)

  val shift1 = io.a >> (io.n - 2.U)
  val mul1 = io.u * shift1.asUInt
  val qGuess = mul1 >> (io.n + 3.U)
  // TODO: check if pipeline is needed
  val mul2 = qGuess * io.m
  val z = io.a - mul2

  io.ar := Mux(z < io.m, z, z - io.m)
}

/** Instantiates an [[MR]], wires its inputs and returns its `ar` output. */
object MR {
  def apply(a: UInt, n: UInt, m: UInt, u: UInt): UInt = {
    val inst = Module(new MR())
    inst.io.a := a
    inst.io.n := n
    inst.io.m := m
    inst.io.u := u
    inst.io.ar
  }
}

object elaborateMR extends App {
  chisel3.Driver.execute(args, () => new MR)
}

// no resource sharing version
/** Ports of [[Butterfly]]. */
class ButterflyIO extends Bundle
  with HasCommonParameters {
  // input two data
  val a = Input(UInt(DataWidth.W))
  val b = Input(UInt(DataWidth.W))
  // twiddle factor
  val wn = Input(UInt(DataWidth.W))
  // control whether DIT or DIF (0: DIT, 1: DIF)
  val mode = Input(Bool())
  // config for ModMul
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((DataWidth + 2).W))
  val n = Input(UInt(5.W))
  // output two data
  val aout = Output(UInt(DataWidth.W))
  val bout = Output(UInt(DataWidth.W))
}

/**
 * DIT && DIF butterfly, all arithmetic modulo m.
 *  - mode = 0 (DIT): aout = a + b*wn,  bout = a - b*wn
 *  - mode = 1 (DIF): aout = a + b,     bout = (a - b)*wn
 */
class Butterfly extends Module
  with HasCommonParameters {
  val io = IO(new ButterflyIO)

  val a = io.a
  val b = io.b
  val wn = io.wn

  // pre-process (only used by DIF)
  val aout1 = ModAdd(a, b, io.m)
  val bout1 = ModSub(a, b, io.m)
  val amux = Mux(io.mode, aout1, a)
  val bmux = Mux(io.mode, bout1, b)

  // post-process: twiddle multiplication, then the DIT add/sub
  val amul = amux
  val bmul = ModMul(a = bmux, b = wn, m = io.m, u = io.u, n = io.n)

  val aout2 = ModAdd(amul, bmul, io.m)
  val bout2 = ModSub(amul, bmul, io.m)

  io.aout := Mux(io.mode, amul, aout2)
  io.bout := Mux(io.mode, bmul, bout2)
}

object elaborateButterfly extends App {
  chisel3.Driver.execute(args, () => new Butterfly)
}

// resource sharing version
/** Ports of [[ButterflyShare]]; `vecValid` selects an element-wise vector operation. */
class ButterflyShareIO extends Bundle
  with HasCommonParameters {
  // input two data
  val a = Input(UInt(DataWidth.W))
  val b = Input(UInt(DataWidth.W))
  // twiddle factor
  val wn = Input(UInt(DataWidth.W))
  // control whether DIT or DIF (0: DIT, 1: DIF)
  val mode = Input(Bool())
  // vector-operation select: (0) add, (1) sub, (2) mul
  val vecValid = Input(Vec(3, Bool()))
  // config for ModMul
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((DataWidth + 2).W))
  val n = Input(UInt(5.W))
  // output two data
  val aout = Output(UInt(DataWidth.W))
  val bout = Output(UInt(DataWidth.W))
}

/**
 * [[Butterfly]] whose modular adder, subtractor and multiplier can be
 * borrowed for element-wise vector arithmetic (all modulo m):
 *  - vecValid(0): aout = a + b
 *  - vecValid(1): aout = a - b
 *  - vecValid(2): aout = a * b (the multiplier is fed a, b instead of bmux, wn)
 * Priority is 0 > 1 > 2; with no bit set, aout is the butterfly output.
 * bout is always the butterfly output for `mode`, but it is only meaningful
 * when vecValid(2) is clear, because the multiplier is then shared.
 */
class ButterflyShare extends Module
  with HasCommonParameters {
  val io = IO(new ButterflyShareIO)

  val a = io.a
  val b = io.b
  val wn = io.wn

  // pre-process
  val aout1 = ModAdd(a, b, io.m)
  val bout1 = ModSub(a, b, io.m)
  val amux = Mux(io.mode, aout1, a)
  val bmux = Mux(io.mode, bout1, b)

  // post-process (multiplier operands switch to a, b for vector multiply)
  val mul1 = Mux(io.vecValid(2), a, bmux)
  val mul2 = Mux(io.vecValid(2), b, wn)
  val amul = amux
  val bmul = ModMul(a = mul1, b = mul2, m = io.m, u = io.u, n = io.n)

  val aout2 = ModAdd(amul, bmul, io.m)
  val bout2 = ModSub(amul, bmul, io.m)

  io.aout := Mux(io.vecValid(0), aout1,
    Mux(io.vecValid(1), bout1,
      Mux(io.vecValid(2), bmul,
        Mux(io.mode, amul, aout2))))
  io.bout := Mux(io.mode, bmul, bout2)
}

// permutation network
/** Ports shared by [[PermNetIn]] and [[PermNetOut]]; `config` is used as a one-hot select. */
class PermNetIO extends Bundle
  with HasCommonParameters
  with HasNTTParameters {
  val in = Input(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
  val config = Input(Vec(log2Ceil(ButterflyNum), Bool()))
  val out = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
}

/**
 * Input permutation: config(k) selects group size split = 2^(k+2). Within
 * each group of `split` values, out(2t) = in(t) and out(2t+1) = in(split/2 + t),
 * i.e. the two halves are interleaved. With no config bit set, in passes
 * through unchanged.
 */
class PermNetIn extends Module
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new PermNetIO)

  def perm(split: Int): Vec[UInt] = {
    val perm = Wire(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    for (i <- 0 until ButterflyNum * 2 / split) {
      for (j <- 0 until split) {
        if (j % 2 == 0) {
          perm(i * split + j) := io.in(i * split + j / 2)
        } else {
          perm(i * split + j) := io.in(i * split + split / 2 + (j - 1) / 2)
        }
      }
    }
    perm
  }

  val multiPerm = Wire(Vec(log2Ceil(ButterflyNum), Vec(ButterflyNum * 2, UInt(DataWidth.W))))
  for (i <- 0 until log2Ceil(ButterflyNum)) {
    multiPerm(i) := perm(pow(2, i + 2).toInt)
  }

  val out = Mux1H(io.config, multiPerm)
  io.out := Mux(!(io.config.reduce(_ || _)), io.in, out)
}
object elaboratePermNet extends App {
  chisel3.Driver.execute(args, () => new PermNetOut)
}

/**
 * Output permutation, the inverse of [[PermNetIn]] for the same config:
 * within each group of split = 2^(k+2) values, even-indexed inputs fill the
 * first half and odd-indexed inputs the second half.
 */
class PermNetOut extends Module
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new PermNetIO)

  def perm(split: Int): Vec[UInt] = {
    val perm = Wire(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    for (i <- 0 until ButterflyNum * 2 / split) {
      for (j <- 0 until split) {
        if (j < split / 2) {
          perm(i * split + j) := io.in(i * split + j * 2)
        } else {
          perm(i * split + j) := io.in(i * split + 1 + (j - split / 2) * 2)
        }
      }
    }
    perm
  }

  val multiPerm = Wire(Vec(log2Ceil(ButterflyNum), Vec(ButterflyNum * 2, UInt(DataWidth.W))))
  for (i <- 0 until log2Ceil(ButterflyNum)) {
    multiPerm(i) := perm(pow(2, i + 2).toInt)
  }

  val out = Mux1H(io.config, multiPerm)
  io.out := Mux(!(io.config.reduce(_ || _)), io.in, out)
}

/** Ports of [[NTT]]. */
class NTTIO extends Bundle
  with HasCommonParameters
  with HasNTTParameters {
  val valid = Input(Bool())
  val mode = Input(Bool())

  // twiddle ram write interface
  val wa = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val di = Input(UInt((DataWidth * ButterflyNum).W))
  val we = Input(Bool())

  // from register
  val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

  // csr interface
  val csrs = Input(new CSRIO)

  // output
  val dataOut = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
  // val wpos = Output(Bool())
}

/**
 * Stage-dependent control logic shared by [[NTT]], [[NTTWithoutRam]] and
 * [[NTTWithoutRamShare]].
 *
 * `stage` is the CSR field csrButterflyCtrl.stageCfg. With
 * L = log2Ceil(ButterflyNum), stages 0 .. L-1 take their twiddle factors
 * from row 0 of the twiddle memory; a stage s >= L reads row 1 << (s - L),
 * plus the iteration index when s > L.
 *
 * The helpers are written for any power-of-two ButterflyNum. For
 * ButterflyNum in {4, 8, 16, 32, 64} they generate logic equivalent to the
 * per-size hand-unrolled chains they replace. Other sizes were previously
 * left unconfigured: no permutation was selected and default twiddle lanes
 * were used.
 */
trait HasNTTStageControl {
  this: HasCommonParameters with HasNTTParameters =>

  // number of stages whose twiddles all live in memory row 0
  private def nttLocalStages: Int = log2Ceil(ButterflyNum)

  /** Feeds lanes 0 .. ButterflyNum-1 of `first` then of `second` into `in`. */
  def nttDriveOperands(in: Vec[UInt], first: Vec[UInt], second: Vec[UInt]): Unit = {
    for (i <- 0 until ButterflyNum) {
      in(i) := first(i)
      in(i + ButterflyNum) := second(i)
    }
  }

  /**
   * Drives the one-hot permutation select for `stage`:
   *  - stage 0 selects nothing (pass-through)
   *  - stage k in [1, L-1] sets config(k-1)
   *  - every stage >= L sets the last bit, config(L-1)
   */
  def nttDrivePermConfig(config: Vec[Bool], stage: UInt): Unit = {
    val last = config.length - 1
    for (k <- 0 to last) {
      config(k) := (if (k == last) stage >= (k + 1).U else stage === (k + 1).U)
    }
  }

  /** Twiddle-memory row to read for `stage` / `iter` (truncated to the row-address width). */
  def nttTwiddleAddr(stage: UInt, iter: UInt): UInt = {
    val addrWidth = log2Ceil(Dimension / ButterflyNum)
    val local = nttLocalStages.U
    val wnBaseAddr1 = Wire(UInt(addrWidth.W))
    val wnBaseAddr2 = Wire(UInt(addrWidth.W))
    // contain stage info
    wnBaseAddr1 := Mux(stage < local, 0.U, 1.U << (stage - local))
    // contain iter info
    wnBaseAddr2 := Mux(stage > local, iter + wnBaseAddr1, wnBaseAddr1)
    wnBaseAddr2
  }

  /**
   * Twiddle factor for butterfly `lane`, sliced from one memory row `words`
   * holding ButterflyNum packed DataWidth-bit words (word j occupies bits
   * [DataWidth*j + DataWidth-1 : DataWidth*j]).
   *  - stage 0 uses word 0
   *  - stage k in [1, L-1] uses word 2^k + lane % 2^k
   *  - stages >= L use word `lane`
   */
  def nttTwiddleWord(words: UInt, stage: UInt, lane: Int): UInt = {
    def word(j: Int): UInt = words(DataWidth * j + DataWidth - 1, DataWidth * j)
    val wn = Wire(UInt(DataWidth.W))
    wn := word(lane)
    for (k <- 0 until nttLocalStages) {
      val idx = if (k == 0) 0 else (1 << k) + lane % (1 << k)
      when(stage === k.U) {
        wn := word(idx)
      }
    }
    wn
  }
}

/**
 * One NTT/INTT stage pass over 2 * ButterflyNum coefficients, with its own
 * twiddle RAM (written through wa / di / we).
 * Dataflow: PermNetIn -> ButterflyNum x [[Butterfly]] -> PermNetOut.
 * This NTT Module does not apply resource sharing with vector arithmetic units.
 */
class NTT extends Module
  with HasCommonParameters
  with HasNTTParameters
  with HasNTTStageControl {
  val io = IO(new NTTIO)

  io.wb := true.B
  io.done := RegNext(io.valid)
  io.busy := false.B

  private val stage = io.csrs.csrButterflyCtrl.stageCfg

  // permutation in
  val permNet = Module(new PermNetIn)
  nttDriveOperands(permNet.io.in, io.vectorReg1, io.vectorReg2)
  // to support NTT and INTT : stage i -> stage n-1-i
  nttDrivePermConfig(permNet.io.config, stage)

  // wn addr prepare (stage and iteration info)
  val wnBaseAddr2 = nttTwiddleAddr(stage, io.csrs.csrButterflyCtrl.iterCfg)

  // wn consts ram
  /**
   *
   * (1) {1 1} {1 1/4} {1 1/8 ..} {1 1/16 ... 7/16}
   * (2) 1 1/32 ... 15/32
   * (3) ...
   *
   * 32 line 1 1/1024 ... 511/1024
   *
   */
  val twiddleRam = Module(new SyncRam(dep = Dimension / ButterflyNum, dw = DataWidth * ButterflyNum))
  twiddleRam.io.re := io.valid
  twiddleRam.io.we := io.we
  twiddleRam.io.ra := wnBaseAddr2
  twiddleRam.io.wa := io.wa
  twiddleRam.io.di := io.di

  // butterfly PEs
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new Butterfly()).io))
  for (i <- 0 until ButterflyNum) {
    PEs(i).a := permNet.io.out(2 * i)
    PEs(i).b := permNet.io.out(2 * i + 1)
    PEs(i).wn := nttTwiddleWord(twiddleRam.io.dout, stage, i)
    PEs(i).mode := io.mode

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  for (i <- 0 until ButterflyNum) {
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := permNetOut.io.out
}

/**
 * Same datapath as [[NTT]], but the twiddle memory is external: this module
 * drives the read address `ra` / enable `re` and consumes the row on `di`.
 */
class NTTWithoutRam extends Module
  with HasCommonParameters
  with HasNTTParameters
  with HasNTTStageControl {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val mode = Input(Bool())

    // ram from outside
    val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
    val re = Output(Bool())
    val di = Input(UInt((ButterflyNum * DataWidth).W))

    // from register
    val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
    val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

    // csr interface
    val csrs = Input(new CSRIO)

    // output
    val dataOut = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
  })

  io.wb := true.B
  io.done := RegNext(io.valid)
  io.busy := false.B

  private val stage = io.csrs.csrButterflyCtrl.stageCfg

  // permutation in
  val permNet = Module(new PermNetIn)
  nttDriveOperands(permNet.io.in, io.vectorReg1, io.vectorReg2)
  // to support NTT and INTT : stage i -> stage n-1-i
  nttDrivePermConfig(permNet.io.config, stage)

  // wn addr prepare (stage and iteration info)
  val wnBaseAddr2 = nttTwiddleAddr(stage, io.csrs.csrButterflyCtrl.iterCfg)

  io.re := io.valid
  io.ra := wnBaseAddr2

  // butterfly PEs
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new Butterfly()).io))
  for (i <- 0 until ButterflyNum) {
    PEs(i).a := permNet.io.out(2 * i)
    PEs(i).b := permNet.io.out(2 * i + 1)
    // wn assign
    PEs(i).wn := nttTwiddleWord(io.di, stage, i)
    PEs(i).mode := io.mode

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  for (i <- 0 until ButterflyNum) {
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := permNetOut.io.out
}

/**
 * [[NTTWithoutRam]] built from [[ButterflyShare]] PEs, so the butterflies
 * double as a vector ALU.
 *  - valid = 1: normal NTT stage, outputs come through PermNetOut.
 *  - valid = 0: PE i receives vectorReg1(i) / vectorReg2(i) unpermuted, and
 *    vecValid selects the operation. dataOut(i) = PE i's aout for
 *    i < ButterflyNum; the upper half of dataOut is 0.
 */
class NTTWithoutRamShare extends Module
  with HasCommonParameters
  with HasNTTParameters
  with HasNTTStageControl {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val mode = Input(Bool())
    val vecValid = Input(Vec(3, Bool()))

    // ram from outside
    val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
    val re = Output(Bool())
    val di = Input(UInt((ButterflyNum * DataWidth).W))

    // from register
    val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
    val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

    // csr interface
    val csrs = Input(new CSRIO)

    // output
    val dataOut = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
  })

  io.wb := true.B
  io.done := RegNext(io.valid)
  io.busy := false.B

  private val stage = io.csrs.csrButterflyCtrl.stageCfg

  // permutation in
  val permNet = Module(new PermNetIn)
  nttDriveOperands(permNet.io.in, io.vectorReg1, io.vectorReg2)
  // to support NTT and INTT : stage i -> stage n-1-i
  nttDrivePermConfig(permNet.io.config, stage)

  // wn addr prepare (stage and iteration info)
  val wnBaseAddr2 = nttTwiddleAddr(stage, io.csrs.csrButterflyCtrl.iterCfg)

  io.re := io.valid
  io.ra := wnBaseAddr2

  // butterfly PEs (bypass the permutation when used as a vector ALU)
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new ButterflyShare()).io))
  for (i <- 0 until ButterflyNum) {
    PEs(i).a := Mux(io.valid, permNet.io.out(2 * i), io.vectorReg1(i))
    PEs(i).b := Mux(io.valid, permNet.io.out(2 * i + 1), io.vectorReg2(i))
    // wn assign
    PEs(i).wn := nttTwiddleWord(io.di, stage, i)
    PEs(i).mode := io.mode
    PEs(i).vecValid := io.vecValid

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  val vecArithOut = Wire(Vec(2 * ButterflyNum, UInt(DataWidth.W)))
  for (i <- 0 until ButterflyNum) {
    vecArithOut(i) := PEs(i).aout
    vecArithOut(i + ButterflyNum) := 0.U
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := Mux(io.valid, permNetOut.io.out, vecArithOut)
}

object elaborateNTTWithoutRamShare extends App {
  chisel3.Driver.execute(args, () => new NTTWithoutRamShare)
}

object elaborateNTTWithoutRam extends App {
  chisel3.Driver.execute(args, () => new NTTWithoutRam)
}

object elaborateNTT extends App {
  chisel3.Driver.execute(args, () => new NTT)
}

//class SwitchIO extends Bundle
//  with HasCommonParameters {
//  val a = Input(UInt(DataWidth.W))
//  val b = Input(UInt(DataWidth.W))
//  val sel = Input(Bool())
//  val aout = Output(UInt(DataWidth.W))
//  val bout = Output(UInt(DataWidth.W))
//}
//
//class Switch extends Module {
//  val io = IO(new SwitchIO())
//  io.aout := Mux(io.sel, io.b, io.a)
//  io.bout := Mux(io.sel, io.a, io.b)
//}
//object Switch {
//  def apply(a: UInt, b: UInt, sel: Bool, aout: UInt, bout: UInt): Module = {
//    val inst = Module(new Switch)
//    inst.io.a := a
//    inst.io.b := b
//    inst.io.sel := sel
//    aout := inst.io.aout
//    bout := inst.io.bout
//    inst
//  }
//}

//class NTTR2MDCIO extends Bundle
//  with HasNTTCommonParameters {
//  val dIn = Input(UInt(DataWidth.W))
//  val dInValid = Input(Bool())
//  val dOut1 = Output(UInt(DataWidth.W))
//  // val addrOut1 = Output(UInt(AddrWidth.W))
//  val dOut2 = Output(UInt(DataWidth.W))
//  // val addrOut2 = Output(UInt(AddrWidth.W))
//  val dOutValid = Output(Bool())
//}
//
//class NTTR2MDC extends Module
//  with HasMRParameters
//  with HasNTTCommonParameters
//  with HasNTTParameters {
//  val io = IO(new NTTR2MDCIO())
//
//  val stages: Int = AddrWidth
//
//  val wn = RegInit(VecInit(Seq.fill(stages)(1.U(DataWidth.W))))
//  val cnt = RegInit(0.U((stages).W))
//  when(io.dInValid){
//    cnt := cnt + 1.U
//    for (i <- 0 until stages - 1) {
//      val res = MRFix(wn(i) * WN(i).asUInt())
//      wn(i) := Mux(res === (MRq.asUInt() - 1.U), 1.U, res)
//    }
//  }
//
//  val out1 = VecInit(Seq.fill(stages)(0.U(DataWidth.W)))
//  val out2 = VecInit(Seq.fill(stages)(0.U(DataWidth.W)))
//  /***
//    *
//    * pre modular reduction
//    *
//    *
//    */
//
//  val dIn = Mux(io.dIn < MRq.asUInt(), io.dIn, io.dIn - MRq.asUInt())
//
//  for (i <- 0 until stages){
//    if (i == 0) {
//      val regIn = ShiftRegister(dIn, VectorLength / 2)
//      val BFOut1 = Wire(UInt(DataWidth.W))
//      val BFOut2 = Wire(UInt(DataWidth.W))
//      BF2(regIn, dIn, BFOut1, BFOut2)
//      /**
//        * can add pipeline if necessary
//        *
//        **/
//      val mulRes = MRFix(BFOut2 * wn(i))
//      val switchIn1 = BFOut1
//      val switchIn2 = ShiftRegister(mulRes, VectorLength/4)
//      val swCtrl = cnt(stages-2)
//      Switch(switchIn1, switchIn2, swCtrl, out1(0), out2(0))
//    }
//    else if (i < stages - 1){
//      val regIn = ShiftRegister(out1(i-1), (VectorLength/pow(2, i + 1)).toInt)
//      val BFOut1 = Wire(UInt(14.W))
//      val BFOut2 = Wire(UInt(14.W))
//      BF2(regIn, out2(i-1), BFOut1, BFOut2)
//      /**
//        * can add pipeline if necessary
//        *
//        **/
//      val mulRes = MRFix(BFOut2 * wn(i))
//      val switchIn1 = BFOut1
//      val switchIn2 = ShiftRegister(mulRes, (VectorLength/pow(2, i + 2)).toInt)
//      val swCtrl = cnt(stages-2-i)
//      Switch(switchIn1, switchIn2, swCtrl, out1(i), out2(i))
//    }
//    else {
//      val regIn = ShiftRegister(out1(i-1), (VectorLength/pow(2, i + 1)).toInt)
//      BF2(regIn, out2(i-1), out1(i), out2(i))
//    }
//  }
//  io.dOut1 := RegNext(out1(stages - 1))
//  io.dOut2 := RegNext(out2(stages - 1))
//  io.dOutValid := ShiftRegister(io.dInValid, VectorLength)
//
//}

// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/NTTParameters.scala
// -----------------------------------------------------------------------------

package VPQC

import chisel3.util._
import scala.math._

// using Barrett reduction
//trait HasMRParameters {
//  val MRalpha = 16 + 1
//  val MRbeta = -2
//  val MRq = 12289
//  val MRu = (2 << (14 + MRalpha)) / MRq
//}

//trait HasNTTCommonParameters {
//  val DataWidth = 16
//  val VectorLength = 512
//  // 9 bits
//  val AddrWidth = log2Ceil(VectorLength)
//}

/**
  * Static NTT configuration shared by the coprocessor modules and bundles.
  *
  * For example in NEWHOPE512 parameters
  * r = 10968, w = 3, w-1 = 8193, r-1 = 3656, n-1 = 12265
  * w is root of unity, 3 ^ 512 = 1, 3 ^ 256 = 12289 - 1, 3 ^ 128 = 1479
  */
trait HasNTTParameters {
  /** Number of butterfly units operating in parallel. */
  val ButterflyNum = 32
  /** NTT dimension (number of polynomial coefficients). */
  val Dimension = 1024

  /**
    * Repeated-squaring table of the base w = 3 modulo 12289:
    * WN(0) = 3 and WN(k) = WN(k - 1)^2 mod 12289 for k = 1..7, i.e. WN(k) = 3^(2^k) mod 12289.
    * NOTE(review): the table has 9 slots but only indices 0..7 are filled; WN(8) stays 0.
    * The header comment suggests 3^256 = 12288 would belong there — confirm whether any
    * caller reads WN(8).
    */
  val WN: Array[Int] = {
    val squares = new Array[Int](9)
    squares(0) = 3
    var k = 1
    while (k < 8) {
      // operands are < 12289, so the square (< 1.6e8) fits comfortably in an Int
      squares(k) = (squares(k - 1) * squares(k - 1)) % 12289
      k += 1
    }
    squares
  }
}

// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/PQCCoprocessor.scala
// -----------------------------------------------------------------------------

package VPQC

import chisel3._
import chisel3.util._
import utility._

// standalone coprocessor

/** Port list of [[PQCCoprocessor]], the variant with on-chip vector and constant RAMs. */
class PQCCoprocessorIO extends Bundle
  with HasCommonParameters
  with HasNTTParameters
  with HasPQCInstructions {
  // instruction issue / scalar register interface (forwarded to PQCDecode)
  val instr = Input(new PQCInstruction)
  val rs1 = Input(UInt(32.W))
  val rs2 = Input(UInt(32.W))
  val rd = Output(UInt(32.W))
  val rdw = Output(Bool())
  val in_fire = Input(Bool())
  val rdy = Output(Bool())
  val busy = Output(Bool())

  // seed loading for the Keccak core
  val seed = Input(new stateArray())
  val seedWrite = Input(Bool())

  // twiddle-factor / constant RAM loading, driven from outside
  val twiddleData = Input(UInt((DataWidth * ButterflyNum).W))
  val twiddleAddr = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val twiddleWrite = Input(Bool())

  val constReadEnable = Input(Bool())
  val constReadAddr = Input(UInt(log2Ceil(2 * Dimension / ButterflyNum).W))
  val constWrite = Input(Bool())

  // vector load/store interface
  val vectorOut = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteAddr = Output(UInt(32.W))
  val vectorWriteEnable = Output(Bool())

  val vectorIn = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReadaddr = Output(UInt(32.W))
  val vectorReadEnable = Output(Bool())
}

/** Snapshot of the coprocessor control/status registers handed to the execution units. */
class CSRIO extends Bundle
  with HasCommonParameters {
  val csrBarretu = UInt((DataWidth + 2).W)
  val csrBound = UInt((DataWidth * 2).W)
  val csrBinomialk = UInt(3.W)
  val csrModulusq = UInt(DataWidth.W)
  val csrModulusLen = UInt(5.W)
  /**
    * Butterfly control, packed in the 10-bit CSR as:
    *
    *   9 ... 6     5 ... 0
    *  stage_cfg   iter_cfg
    */
  val csrButterflyCtrl = new Bundle {
    val stageCfg = UInt(4.W)
    val iterCfg = UInt(6.W)
  }
  val csrLsBaseAddr = UInt(32.W)
}

/**
  * Standalone PQC coprocessor: decoder, CSR file, two-port vector register file,
  * constant RAM and the PQCExu execution unit.
  */
class PQCCoprocessor extends Module
  with HasCommonParameters
  with HasPQCInstructions {
  val io = IO(new PQCCoprocessorIO)

  // ---- decode ----
  val decoder = Module(new PQCDecode)
  decoder.io.instr := io.instr
  decoder.io.rs1 := io.rs1
  decoder.io.rs2 := io.rs2
  decoder.io.in_fire := io.in_fire
  val doinstr = decoder.io.result

  // ---- CSRs ----
  val csrBarretu = RegInit(0.U((DataWidth + 2).W))
  val csrBound = RegInit(0.U((DataWidth * 2).W))
  val csrBinomialk = RegInit(0.U(3.W))
  val csrModulusq = RegInit(0.U(DataWidth.W))
  val csrModulusLen = RegInit(0.U(5.W))
  val csrButterflyCtrl = RegInit(0.U(10.W))
  val csrLsBaseAddr = RegInit(0.U(32.W))

  /** CSRs in CSR-number order: entry k is selected when the rs2 field equals k. */
  private def csrTable: Seq[UInt] =
    Seq(csrBarretu, csrBound, csrBinomialk, csrModulusq, csrModulusLen, csrButterflyCtrl, csrLsBaseAddr)

  // ---- vector register file ----
  val vecSRAM = Module(new VectorRegister2Port)
  vecSRAM.io.vectorReadAddr1 := decoder.io.vrs1idx
  vecSRAM.io.vectorReadAddr2 := decoder.io.vrs2idx
  val vecop1 = vecSRAM.io.vectorReadValue1
  val vecop2 = vecSRAM.io.vectorReadValue2

  // ---- constant RAM ----
  // NOTE(review): the RAM is 64 entries deep but is written through io.twiddleAddr,
  // which is log2Ceil(Dimension / ButterflyNum) = 5 bits wide with the HasNTTParameters
  // values, so only entries 0..31 are writable — confirm this is intended.
  val constRAM = Module(new SyncRam(dep = 64, dw = DataWidth * ML))
  constRAM.io.re := io.constReadEnable
  constRAM.io.we := io.constWrite
  constRAM.io.ra := io.constReadAddr
  constRAM.io.wa := io.twiddleAddr
  constRAM.io.di := io.twiddleData
  // slice the flat RAM word into ML lanes of DataWidth bits, lane 0 in the LSBs
  val constVal = Wire(Vec(ML, UInt(DataWidth.W)))
  for (lane <- 0 until ML) {
    val lo = lane * DataWidth
    constVal(lane) := constRAM.io.dout(lo + DataWidth - 1, lo)
  }

  // ======================================
  // EXU Stage
  // ======================================
  val exunit = Module(new PQCExu)

  exunit.io.seed := io.seed
  exunit.io.seedWrite := io.seedWrite
  exunit.io.twiddleData := io.twiddleData
  exunit.io.twiddleAddr := io.twiddleAddr
  exunit.io.twiddleWrite := io.twiddleWrite
  // one valid strobe per instruction: its decode bit qualified by the decoder's fire
  for (op <- 0 until INSTR_QUANTITY) {
    exunit.io.valid(op) := decoder.io.fire && doinstr(op)
  }

  exunit.io.vrs1 := vecop1
  // operand 2 comes from the constant RAM while a constant read is requested
  exunit.io.vrs2 := Mux(io.constReadEnable, constVal, vecop2)
  exunit.io.csrs.csrBarretu := csrBarretu
  exunit.io.csrs.csrBound := csrBound
  exunit.io.csrs.csrBinomialk := csrBinomialk
  exunit.io.csrs.csrModulusq := csrModulusq
  exunit.io.csrs.csrModulusLen := csrModulusLen
  exunit.io.csrs.csrButterflyCtrl.stageCfg := csrButterflyCtrl(9, 6)
  exunit.io.csrs.csrButterflyCtrl.iterCfg := csrButterflyCtrl(5, 0)
  exunit.io.csrs.csrLsBaseAddr := csrLsBaseAddr

  // ---- write-back: register indices are the decoder's, delayed by one cycle ----
  val exuWriteback = exunit.io.done && exunit.io.wb
  vecSRAM.io.vectorWriteEnable1 := exuWriteback
  vecSRAM.io.vectorWriteEnable2 := exuWriteback && exunit.io.wpos
  vecSRAM.io.vectorWriteData1 := exunit.io.vres1
  vecSRAM.io.vectorWriteData2 := exunit.io.vres2
  // (inverse) butterflies write port 1 back over vrs1; every other instruction targets vrd
  val wasButterfly = RegNext((doinstr(INSTR_BUTTERFLY) || doinstr(INSTR_IBUTTERFLY)) && decoder.io.fire)
  vecSRAM.io.vectorWriteAddr1 := Mux(wasButterfly, RegNext(decoder.io.vrs1idx), RegNext(decoder.io.vrdidx))
  vecSRAM.io.vectorWriteAddr2 := RegNext(decoder.io.vrs2idx)

  // ---- CSR access (CSR number taken from the rs2 field) ----
  io.rd := 0.U
  io.rdw := false.B
  // CSRRW: return the old value on rd and load the scalar rs1 value
  when(decoder.io.fire && doinstr(INSTR_CSRRW)) {
    io.rdw := true.B
    for ((csr, k) <- csrTable.zipWithIndex) {
      when(decoder.io.vrs2idx === k.U) {
        io.rd := csr
        csr := decoder.io.srs1
      }
    }
  }

  // CSRRWI: load the 5-bit rs1 field itself as the new value; rd is not written
  when(decoder.io.fire && doinstr(INSTR_CSRRWI)) {
    for ((csr, k) <- csrTable.zipWithIndex) {
      when(decoder.io.vrs2idx === k.U) {
        csr := decoder.io.vrs1idx
      }
    }
  }

  // ---- status ----
  decoder.io.busy := exunit.io.busy
  io.rdy := exunit.io.done && !exunit.io.busy
  io.busy := decoder.io.busy

  // ---- vector load/store ----
  io.vectorReadaddr := csrLsBaseAddr
  io.vectorWriteAddr := csrLsBaseAddr
  io.vectorOut := vecop1

  val isLoad = decoder.io.fire && doinstr(INSTR_VLD)
  // a vector load overrides the EXU write-back on port 1 (Chisel last-connect wins)
  when(isLoad) {
    vecSRAM.io.vectorWriteData1 := io.vectorIn
    vecSRAM.io.vectorWriteEnable1 := true.B
  }

  io.vectorReadEnable := isLoad
  io.vectorWriteEnable := decoder.io.fire && doinstr(INSTR_VST)
}

/** Emits FIRRTL/Verilog for [[PQCCoprocessor]]. */
object elaboratePQCCoprocessor extends App {
  chisel3.Driver.execute(args, () => new PQCCoprocessor())
}

/**
  *
  * ----------------------------------------------------------------------
*
 **/
// for synthesis

/**
  * Port list of [[PQCCoprocessorNoMem]]: same instruction interface as the standalone
  * coprocessor, but the vector register file, constant RAM and Keccak prefetch FIFO are
  * all external and reached through the ports below.
  */
class PQCCoprocessorNoMemIO extends Bundle
  with HasCommonParameters
  with HasNTTParameters
  with HasPQCInstructions {
  // instruction issue / scalar register interface (forwarded to PQCDecode)
  val instr = Input(new PQCInstruction)
  val rs1 = Input(UInt(32.W))
  val rs2 = Input(UInt(32.W))
  val rd = Output(UInt(32.W))
  val rdw = Output(Bool())
  val in_fire = Input(Bool())
  val rdy = Output(Bool())
  val busy = Output(Bool())

  // SHA3 seed and external prefetch FIFO
  val seed = Input(new stateArray())
  val seedWrite = Input(Bool())
  val writeData = Output(UInt((ML * DataWidth).W))
  val writeEnable = Output(Bool())
  val readEnable = Output(Bool())
  val readData = Input(UInt((ML * DataWidth).W))
  val readEmpty = Input(Bool())
  val writeFull = Input(Bool())

  // external vector register file
  val vectorReadAddr1 = Output(UInt(5.W))
  val vectorReadAddr2 = Output(UInt(5.W))
  val vectorWriteData1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteAddr1 = Output(UInt(32.W))
  val vectorWriteEnable1 = Output(Bool())
  val vectorWriteData2 = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteAddr2 = Output(UInt(32.W))
  val vectorWriteEnable2 = Output(Bool())
  val vectorReadValue1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReadValue2 = Input(Vec(ML, UInt(DataWidth.W)))

  // external constant RAM (read port only)
  val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val re = Output(Bool())
  val di = Input(UInt((ButterflyNum * DataWidth).W))

  // vector load/store interface
  val vectorOut = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteAddr = Output(UInt(32.W))
  val vectorWriteEnable = Output(Bool())

  val vectorIn = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReadaddr = Output(UInt(32.W))
  val vectorReadEnable = Output(Bool())
}

/**
  * Synthesis variant of the coprocessor: decoder, CSR file and PQCExuNoMem, with all
  * memories outside the module.
  */
class PQCCoprocessorNoMem extends Module
  with HasCommonParameters
  with HasPQCInstructions {
  val io = IO(new PQCCoprocessorNoMemIO)

  // ---- decode ----
  val decoder = Module(new PQCDecode)
  decoder.io.instr := io.instr
  decoder.io.rs1 := io.rs1
  decoder.io.rs2 := io.rs2
  decoder.io.in_fire := io.in_fire
  val doinstr = decoder.io.result

  // ---- CSRs ----
  val csrBarretu = RegInit(0.U((DataWidth + 2).W))
  val csrBound = RegInit(0.U((DataWidth * 2).W))
  val csrBinomialk = RegInit(0.U(3.W))
  val csrModulusq = RegInit(0.U(DataWidth.W))
  val csrModulusLen = RegInit(0.U(5.W))
  val csrButterflyCtrl = RegInit(0.U(10.W))
  val csrLsBaseAddr = RegInit(0.U(32.W))

  /** CSRs in CSR-number order: entry k is selected when the rs2 field equals k. */
  private def csrTable: Seq[UInt] =
    Seq(csrBarretu, csrBound, csrBinomialk, csrModulusq, csrModulusLen, csrButterflyCtrl, csrLsBaseAddr)

  // ---- external vector register file ----
  io.vectorReadAddr1 := decoder.io.vrs1idx
  io.vectorReadAddr2 := decoder.io.vrs2idx
  val vecop1 = io.vectorReadValue1
  val vecop2 = io.vectorReadValue2

  // ======================================
  // EXU Stage
  // ======================================
  val exunit = Module(new PQCExuNoMem)

  exunit.io.seed := io.seed
  exunit.io.seedWrite := io.seedWrite
  exunit.io.readEmpty := io.readEmpty
  exunit.io.readData := io.readData
  exunit.io.writeFull := io.writeFull
  io.writeData := exunit.io.writeData
  io.writeEnable := exunit.io.writeEnable
  io.readEnable := exunit.io.readEnable

  exunit.io.di := io.di
  io.re := exunit.io.re
  io.ra := exunit.io.ra
  // one valid strobe per instruction: its decode bit qualified by the decoder's fire
  for (i <- 0 until INSTR_QUANTITY) {
    exunit.io.valid(i) := doinstr(i) && decoder.io.fire
  }

  exunit.io.vrs1 := vecop1
  exunit.io.vrs2 := vecop2
  exunit.io.csrs.csrBarretu := csrBarretu
  exunit.io.csrs.csrBound := csrBound
  exunit.io.csrs.csrBinomialk := csrBinomialk
  exunit.io.csrs.csrModulusq := csrModulusq
  exunit.io.csrs.csrModulusLen := csrModulusLen
  // csrButterflyCtrl packs stage_cfg in bits 9..6 and iter_cfg in bits 5..0
  exunit.io.csrs.csrButterflyCtrl.stageCfg := csrButterflyCtrl(9, 6)
  exunit.io.csrs.csrButterflyCtrl.iterCfg := csrButterflyCtrl(5, 0)
  exunit.io.csrs.csrLsBaseAddr := csrLsBaseAddr

  // ---- write-back: register indices are the decoder's, delayed by one cycle ----
  io.vectorWriteEnable1 := exunit.io.done && exunit.io.wb
  io.vectorWriteEnable2 := exunit.io.done && exunit.io.wb && exunit.io.wpos
  io.vectorWriteData1 := exunit.io.vres1
  io.vectorWriteData2 := exunit.io.vres2
  // (inverse) butterflies write port 1 back over vrs1; every other instruction targets vrd
  io.vectorWriteAddr1 := Mux(!RegNext((doinstr(INSTR_BUTTERFLY) || doinstr(INSTR_IBUTTERFLY))
    && decoder.io.fire), RegNext(decoder.io.vrdidx), RegNext(decoder.io.vrs1idx))
  io.vectorWriteAddr2 := RegNext(decoder.io.vrs2idx)

  // ---- CSR access (CSR number taken from the rs2 field) ----
  io.rd := 0.U
  io.rdw := false.B
  // CSRRW: return the old value on rd and load the scalar rs1 value
  when(decoder.io.fire && doinstr(INSTR_CSRRW)) {
    io.rdw := true.B
    for ((csr, k) <- csrTable.zipWithIndex) {
      when(decoder.io.vrs2idx === k.U) {
        io.rd := csr
        csr := decoder.io.srs1
      }
    }
  }

  // CSRRWI: load the 5-bit rs1 field itself as the new value; rd is not written
  when(decoder.io.fire && doinstr(INSTR_CSRRWI)) {
    for ((csr, k) <- csrTable.zipWithIndex) {
      when(decoder.io.vrs2idx === k.U) {
        csr := decoder.io.vrs1idx
      }
    }
  }

  // ---- status ----
  decoder.io.busy := exunit.io.busy
  io.rdy := exunit.io.done && !exunit.io.busy
  // pqcBusy is the decoder's output copy of its busy input (io.pqcBusy := io.busy)
  io.busy := decoder.io.pqcBusy

  // ---- vector load/store ----
  io.vectorReadaddr := csrLsBaseAddr
  io.vectorWriteAddr := csrLsBaseAddr
  io.vectorOut := vecop1

  // a vector load overrides the EXU write-back on port 1 (Chisel last-connect wins)
  when(decoder.io.fire && doinstr(INSTR_VLD)) {
    io.vectorWriteData1 := io.vectorIn
    io.vectorWriteEnable1 := true.B
  }

  io.vectorReadEnable := decoder.io.fire && doinstr(INSTR_VLD)
  io.vectorWriteEnable := decoder.io.fire && doinstr(INSTR_VST)
}

/** Emits FIFRTL/Verilog for [[PQCCoprocessorNoMem]]. */
object elaboratePQCCoprocessorNoMem extends App {
  chisel3.Driver.execute(args, () => new PQCCoprocessorNoMem())
}

// NOTE: a commented-out, verbatim copy of PQCCoprocessorNoMemIO and PQCCoprocessorNoMem
// used to follow here. It duplicated the live definitions in this section exactly and
// was removed as dead code.
// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/PQCDecode.scala
// -----------------------------------------------------------------------------

package VPQC

import chisel3._

/** Field view of one PQC instruction word; declaration order (funct first) is significant. */
class PQCInstruction extends Bundle {
  val funct = UInt(7.W)
  val rs2 = UInt(5.W)
  val rs1 = UInt(5.W)
  val xd = Bool()
  val xs1 = Bool()
  val xs2 = Bool()
  val rd = UInt(5.W)
  val opcode = UInt(7.W)
}

/**
  * Purely combinational decoder.
  * `result(k)` is high when `funct` equals instruction number k. Register indices,
  * scalar operands, fire and busy are forwarded unchanged.
  */
class PQCDecode extends Module
  with HasPQCInstructions
  with HasCommonParameters {
  val io = IO(new Bundle {
    val instr = Input(new PQCInstruction)
    val rs1 = Input(UInt(32.W))
    val rs2 = Input(UInt(32.W))
    val in_fire = Input(Bool())
    val result = Output(Vec(INSTR_QUANTITY, Bool()))
    val busy = Input(Bool())
    val pqcBusy = Output(Bool())
    val srs1 = Output(UInt(32.W))
    val srs2 = Output(UInt(32.W))
    val vrs1idx = Output(UInt(5.W))
    val vrs2idx = Output(UInt(5.W))
    val vrdidx = Output(UInt(5.W))
    val fire = Output(Bool())
  })

  // one-hot decode of the funct field
  io.result.zipWithIndex.foreach { case (hit, opNum) =>
    hit := io.instr.funct === opNum.U
  }

  // pass-throughs
  io.fire := io.in_fire
  io.pqcBusy := io.busy
  io.srs1 := io.rs1
  io.srs2 := io.rs2
  io.vrs1idx := io.instr.rs1
  io.vrs2idx := io.instr.rs2
  io.vrdidx := io.instr.rd
}

// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/PQCExu.scala
// -----------------------------------------------------------------------------

package VPQC

import chisel3._
import chisel3.util._

// Current simplified version does not apply
// resource sharing between ntt module and vector arithmetic module

/** Port list of [[PQCExu]]. */
class PQCExuIO extends Bundle
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  // SHA3 seed, from outside
  val seed = Input(new stateArray())
  val seedWrite = Input(Bool())
  // twiddle RAM write port inside the NTT, from outside
  val twiddleData = Input(UInt((DataWidth * ButterflyNum).W))
  val twiddleAddr = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val twiddleWrite = Input(Bool())
  val wpos = Output(Bool())

  // pipeline
  val valid = Input(Vec(INSTR_QUANTITY, Bool()))
  // val rs1 = Input(UInt(64.W))
  // val rs2 = Input(UInt(64.W))
  val vrs1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vrs2 = Input(Vec(ML, UInt(DataWidth.W)))
  val csrs = Input(new CSRIO)
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
  val vres1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vres2 = Output(Vec(ML, UInt(DataWidth.W)))
}

/**
  * Execution unit with a FIFO-backed Keccak core, samplers, a separate vector ALU
  * and an NTT that owns its twiddle RAM.
  * Results are selected by whichever sub-unit reports done; the Keccak output is the default.
  */
class PQCExu extends Module
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new PQCExuIO())

  val vrs1 = io.vrs1
  val vrs2 = io.vrs2

  // Keccak / SHA3 random-number source (FETCHRN)
  val SHA3Core = Module(new KeccakWithFifo)
  SHA3Core.io.seed := io.seed
  SHA3Core.io.seedWrite := io.seedWrite
  SHA3Core.io.valid := io.valid(INSTR_FETCHRN)

  // binomial / rejection samplers
  val Samplers = Module(new Samplers)
  Samplers.io.vectorReg1 := vrs1
  Samplers.io.vectorReg2 := vrs2
  Samplers.io.csrs := io.csrs
  Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
  // NOTE(review): this equals !SAMPLEREJECTION except when both strobes are high — confirm intent
  Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)

  // element-wise vector add / sub / mul: every operation sees vrs1 (A) and vrs2 (B)
  val VecArith = Module(new VectorArith)
  VecArith.io.addA := vrs1
  VecArith.io.subA := vrs1
  VecArith.io.mulA := vrs1
  VecArith.io.addB := vrs2
  VecArith.io.subB := vrs2
  VecArith.io.mulB := vrs2
  VecArith.io.csrs := io.csrs
  VecArith.io.valid := io.valid(INSTR_VADD) || io.valid(INSTR_VSUB) || io.valid(INSTR_VMUL)
  // result of the issued operation (first match wins), zeros when none is issued
  val vecData = MuxCase(VecInit(Seq.fill(ML)(0.U(DataWidth.W))), Seq(
    io.valid(INSTR_VADD) -> VecArith.io.addRes,
    io.valid(INSTR_VSUB) -> VecArith.io.subRes,
    io.valid(INSTR_VMUL) -> VecArith.io.mulRes
  ))

  // NTT butterflies; mode is high only for the inverse butterfly
  val NTT = Module(new NTT)
  NTT.io.wa := io.twiddleAddr
  NTT.io.di := io.twiddleData
  NTT.io.we := io.twiddleWrite
  NTT.io.vectorReg1 := vrs1
  NTT.io.vectorReg2 := vrs2
  NTT.io.csrs := io.csrs
  NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
  NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY)

  val done = Seq(SHA3Core.io.done, Samplers.io.done, NTT.io.done, VecArith.io.done).reduce(_ || _)
  // NOTE(review): the vector ALU contributes its `done`, not a busy flag, to `busy` — confirm intent
  val busy = Seq(SHA3Core.io.busy, Samplers.io.busy, NTT.io.busy, VecArith.io.done).reduce(_ || _)

  // NTT lanes [0, ButterflyNum) feed vres1, lanes [ButterflyNum, 2 * ButterflyNum) feed vres2
  val NTTData1 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  val NTTData2 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  for (lane <- 0 until ButterflyNum) {
    NTTData1(lane) := NTT.io.dataOut(lane)
    NTTData2(lane) := NTT.io.dataOut(ButterflyNum + lane)
  }

  // result selection; priority follows the order below, Keccak output by default
  val vres1 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    VecArith.io.done -> vecData,
    NTT.io.done -> NTTData1
  ))
  val vres2 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    VecArith.io.done -> vecData,
    NTT.io.done -> NTTData2
  ))
  val wb = MuxCase(false.B, Seq(
    SHA3Core.io.done -> SHA3Core.io.wb,
    Samplers.io.done -> Samplers.io.wb,
    VecArith.io.done -> VecArith.io.wb,
    NTT.io.done -> NTT.io.wb
  ))

  io.done := done
  io.busy := busy
  io.wb := wb
  io.vres1 := vres1
  io.vres2 := vres2
  // the second write port is only used by NTT results
  io.wpos := NTT.io.done
}

/** Port list of [[PQCExuNoMem]]: Keccak prefetch FIFO and twiddle RAM are external. */
class PQCExuNoMemIO extends Bundle
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  // SHA3 seed, from outside
  val seed = Input(new stateArray())
  val seedWrite = Input(Bool())
  // prefetch FIFO IO
  val writeData = Output(UInt((ML * DataWidth).W))
  val writeEnable = Output(Bool())
  val readEnable = Output(Bool())
  val readData = Input(UInt((ML * DataWidth).W))
  val readEmpty = Input(Bool())
  val writeFull = Input(Bool())

  // external constant / twiddle RAM read port
  val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val re = Output(Bool())
  val di = Input(UInt((ButterflyNum * DataWidth).W))

  val wpos = Output(Bool())

  // pipeline
  val valid = Input(Vec(INSTR_QUANTITY, Bool()))
  // val rs1 = Input(UInt(64.W))
  // val rs2 = Input(UInt(64.W))
  val vrs1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vrs2 = Input(Vec(ML, UInt(DataWidth.W)))
  val csrs = Input(new CSRIO)
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
  val vres1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vres2 = Output(Vec(ML, UInt(DataWidth.W)))
}

// resource sharing
/** Execution unit without vector ALU and without on-chip memories (Keccak, samplers, NTT). */
class PQCExuNoMem extends Module
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new PQCExuNoMemIO())

  val vrs1 = io.vrs1
  val vrs2 = io.vrs2

  // Keccak core whose prefetch FIFO is reached through io
  val SHA3Core = Module(new KeccakNoFifo)
  SHA3Core.io.seed := io.seed
  SHA3Core.io.seedWrite := io.seedWrite
  SHA3Core.io.valid := io.valid(INSTR_FETCHRN)
  SHA3Core.io.readEmpty := io.readEmpty
  SHA3Core.io.readData := io.readData
  SHA3Core.io.writeFull := io.writeFull
  io.writeData := SHA3Core.io.writeData
  io.writeEnable := SHA3Core.io.writeEnable
  io.readEnable := SHA3Core.io.readEnable

  val Samplers = Module(new Samplers)
  Samplers.io.vectorReg1 := vrs1
  Samplers.io.vectorReg2 := vrs2
  Samplers.io.csrs := io.csrs
  Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
  // NOTE(review): this equals !SAMPLEREJECTION except when both strobes are high — confirm intent
  Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)

  // NTT reading its twiddle factors from the external RAM
  val NTT = Module(new NTTWithoutRam)
  NTT.io.di := io.di
  io.ra := NTT.io.ra
  io.re := NTT.io.re

  NTT.io.vectorReg1 := vrs1
  NTT.io.vectorReg2 := vrs2
  NTT.io.csrs := io.csrs
  NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
  NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY)

  val done = Seq(SHA3Core.io.done, Samplers.io.done, NTT.io.done).reduce(_ || _)
  val busy = Seq(SHA3Core.io.busy, Samplers.io.busy, NTT.io.busy).reduce(_ || _)

  // NTT lanes [0, ButterflyNum) feed vres1, lanes [ButterflyNum, 2 * ButterflyNum) feed vres2
  val NTTData1 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  val NTTData2 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  for (lane <- 0 until ButterflyNum) {
    NTTData1(lane) := NTT.io.dataOut(lane)
    NTTData2(lane) := NTT.io.dataOut(ButterflyNum + lane)
  }

  val vres1 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    NTT.io.done -> NTTData1
  ))
  val vres2 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    NTT.io.done -> NTTData2
  ))
  val wb = MuxCase(false.B, Seq(
    SHA3Core.io.done -> SHA3Core.io.wb,
    Samplers.io.done -> Samplers.io.wb,
    NTT.io.done -> NTT.io.wb
  ))

  io.done := done
  io.busy := busy
  io.wb := wb
  io.vres1 := vres1
  io.vres2 := vres2
  io.wpos := NTT.io.done
}

/** Port list of [[PQCExuNoMem2]]; identical in shape to PQCExuNoMemIO. */
class PQCExuNoMemIO2 extends Bundle
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  // SHA3 seed, from outside
  val seed = Input(new stateArray())
  val seedWrite = Input(Bool())
  // prefetch FIFO IO
  val writeData = Output(UInt((ML * DataWidth).W))
  val writeEnable = Output(Bool())
  val readEnable = Output(Bool())
  val readData = Input(UInt((ML * DataWidth).W))
  val readEmpty = Input(Bool())
  val writeFull = Input(Bool())

  // external constant / twiddle RAM read port
  val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val re = Output(Bool())
  val di = Input(UInt((ButterflyNum * DataWidth).W))

  val wpos = Output(Bool())

  // pipeline
  val valid = Input(Vec(INSTR_QUANTITY, Bool()))
  // val rs1 = Input(UInt(64.W))
  // val rs2 = Input(UInt(64.W))
  val vrs1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vrs2 = Input(Vec(ML, UInt(DataWidth.W)))
  val csrs = Input(new CSRIO)
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
  val vres1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vres2 = Output(Vec(ML, UInt(DataWidth.W)))
}

// resource sharing
/**
  * Like PQCExuNoMem, but uses NTTWithoutRamShare, which also receives the
  * VADD/VSUB/VMUL strobes through `vecValid` (indices 0/1/2).
  */
class PQCExuNoMem2 extends Module
  with HasPQCInstructions
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new PQCExuNoMemIO2())

  val vrs1 = io.vrs1
  val vrs2 = io.vrs2

  // Keccak core whose prefetch FIFO is reached through io
  val SHA3Core = Module(new KeccakNoFifo)
  SHA3Core.io.seed := io.seed
  SHA3Core.io.seedWrite := io.seedWrite
  SHA3Core.io.valid := io.valid(INSTR_FETCHRN)
  SHA3Core.io.readEmpty := io.readEmpty
  SHA3Core.io.readData := io.readData
  SHA3Core.io.writeFull := io.writeFull
  io.writeData := SHA3Core.io.writeData
  io.writeEnable := SHA3Core.io.writeEnable
  io.readEnable := SHA3Core.io.readEnable

  val Samplers = Module(new Samplers)
  Samplers.io.vectorReg1 := vrs1
  Samplers.io.vectorReg2 := vrs2
  Samplers.io.csrs := io.csrs
  Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
  // NOTE(review): this equals !SAMPLEREJECTION except when both strobes are high — confirm intent
  Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)

  val NTT = Module(new NTTWithoutRamShare)
  NTT.io.di := io.di
  io.ra := NTT.io.ra
  io.re := NTT.io.re

  NTT.io.vectorReg1 := vrs1
  NTT.io.vectorReg2 := vrs2
  NTT.io.csrs := io.csrs
  NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
  // vector add / sub / mul are routed into the shared NTT datapath
  Seq(INSTR_VADD, INSTR_VSUB, INSTR_VMUL).zipWithIndex.foreach { case (op, slot) =>
    NTT.io.vecValid(slot) := io.valid(op)
  }
  NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY)

  val done = Seq(SHA3Core.io.done, Samplers.io.done, NTT.io.done).reduce(_ || _)
  val busy = Seq(SHA3Core.io.busy, Samplers.io.busy, NTT.io.busy).reduce(_ || _)

  // NTT lanes [0, ButterflyNum) feed vres1, lanes [ButterflyNum, 2 * ButterflyNum) feed vres2
  val NTTData1 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  val NTTData2 = Wire(Vec(ButterflyNum, UInt(DataWidth.W)))
  for (lane <- 0 until ButterflyNum) {
    NTTData1(lane) := NTT.io.dataOut(lane)
    NTTData2(lane) := NTT.io.dataOut(ButterflyNum + lane)
  }

  val vres1 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    NTT.io.done -> NTTData1
  ))
  val vres2 = MuxCase(SHA3Core.io.prn, Seq(
    Samplers.io.done -> Samplers.io.sampledData,
    NTT.io.done -> NTTData2
  ))
  val wb = MuxCase(false.B, Seq(
    SHA3Core.io.done -> SHA3Core.io.wb,
    Samplers.io.done -> Samplers.io.wb,
    NTT.io.done -> NTT.io.wb
  ))

  io.done := done
  io.busy := busy
  io.wb := wb
  io.vres1 := vres1
  io.vres2 := vres2
  // second write port only for butterfly results (NTT done while its valid is high)
  io.wpos := NTT.io.done && NTT.io.valid
}

// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/SamplerParameters.scala
// -----------------------------------------------------------------------------

package VPQC

/** Sizing of the sampler array. */
trait HasSamplerParameters {
  // datawidth for binomial
  // support k = 1, 2, 4, 8, 16

  // number of Sampler
  val SamplerNum = 32
}

// -----------------------------------------------------------------------------
// /src/main/scala/VPQC/Samplers.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3._
import chisel3.util._

class BinomialSamplerIO extends Bundle
    with HasSamplerParameters
    with HasCommonParameters {
  val data1 = Input(UInt(DataWidth.W))
  val data2 = Input(UInt(DataWidth.W))
  val k = Input(UInt(3.W))
  val q = Input(UInt(DataWidth.W))
  val dataOut = Output(UInt(DataWidth.W))
}

// Centered binomial sampler.
// For the parameter selected by io.k (0 -> k = 1, 1 -> 2, 2 -> 4, 3 -> 8,
// 4 -> 16) it counts the ones in the low k bits of data1 and of data2 and
// outputs popcount(data1) - popcount(data2), lifted into [0, q) by adding q
// when the difference is negative.
// The popcounts are built as an adder tree over 2-, 4-, 8- and 16-bit groups;
// only group 0 of each level feeds the output.
class BinomialSampler extends Module
    with HasSamplerParameters
    with HasCommonParameters {
  val io = IO(new BinomialSamplerIO())

  // Every level uses the width-expanding `+&`. Chisel's plain `+` keeps only
  // max(width(a), width(b)) bits and drops the carry (1 + 1 = 0 for two single
  // bits), which corrupted every popcount that reached its maximum.
  val data1pop2 = Wire(Vec(DataWidth / 2, UInt(2.W)))
  val data2pop2 = Wire(Vec(DataWidth / 2, UInt(2.W)))
  for (i <- 0 until DataWidth / 2) {
    data1pop2(i) := io.data1(2 * i + 1) +& io.data1(2 * i)
    data2pop2(i) := io.data2(2 * i + 1) +& io.data2(2 * i)
  }

  val data1pop4 = Wire(Vec(DataWidth / 4, UInt(3.W)))
  val data2pop4 = Wire(Vec(DataWidth / 4, UInt(3.W)))
  for (i <- 0 until DataWidth / 4) {
    data1pop4(i) := data1pop2(2 * i + 1) +& data1pop2(2 * i)
    data2pop4(i) := data2pop2(2 * i + 1) +& data2pop2(2 * i)
  }

  val data1pop8 = Wire(Vec(DataWidth / 8, UInt(4.W)))
  val data2pop8 = Wire(Vec(DataWidth / 8, UInt(4.W)))
  for (i <- 0 until DataWidth / 8) {
    data1pop8(i) := data1pop4(2 * i + 1) +& data1pop4(2 * i)
    data2pop8(i) := data2pop4(2 * i + 1) +& data2pop4(2 * i)
  }

  val data1pop16 = Wire(UInt(5.W))
  val data2pop16 = Wire(UInt(5.W))
  if (DataWidth == 16) {
    data1pop16 := data1pop8(1) +& data1pop8(0)
    data2pop16 := data2pop8(1) +& data2pop8(0)
  } else {
    data1pop16 := data1pop8(0)
    data2pop16 := data2pop8(0)
  }

  // two minus operands
  val op1 = WireInit(0.U(5.W))
  val op2 = WireInit(0.U(5.W))

  switch(io.k) {
    // k = 1
    is(0.U) {
      op1 := io.data1(0)
      op2 := io.data2(0)
    }
    // k = 2
    is(1.U) {
      op1 := data1pop2(0)
      op2 := data2pop2(0)
    }
    // k = 4
    is(2.U) {
      op1 := data1pop4(0)
      op2 := data2pop4(0)
    }
    // k = 8
    is(3.U) {
      op1 := data1pop8(0)
      op2 := data2pop8(0)
    }
    // k = 16
    is(4.U) {
      // data*pop16 are plain UInts, not Vecs: applying (0) to them would select
      // bit 0 instead of the whole popcount.
      op1 := data1pop16
      op2 := data2pop16
    }
  }

  // A negative difference is represented as q - (op2 - op1).
  val diff = op1 < op2
  val out = Mux(diff, op1 + io.q - op2, op1 - op2)
  io.dataOut := out
}

class RejectionSamplerIO extends Bundle
    with HasSamplerParameters
    with HasCommonParameters {
  val dataIn = Input(UInt((DataWidth * 2).W))
  val n = Input(UInt(5.W))
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((DataWidth + 2).W))
  val bound = Input(UInt((DataWidth * 2).W))
  val dataOut = Output(UInt(DataWidth.W))
  val valid = Output(Bool())
}

// small rejection rate
// Accepts dataIn when it is below `bound` and then outputs MR(dataIn, n, m, u);
// rejected inputs produce 0 with valid low.
// NOTE(review): MR is defined in ModularArithmetic (not visible here); it is
// presumably a Barrett reduction of dataIn modulo m — confirm there.
class RejectionSampler extends Module
    with HasSamplerParameters {
  val io = IO(new RejectionSamplerIO())

  val failed = !(io.dataIn < io.bound)
  io.dataOut := Mux(failed, 0.U, MR(io.dataIn, io.n, io.m, io.u))
  io.valid := !failed
}

class SamplersIO extends Bundle
    with HasSamplerParameters
    with HasKeccakParameters
    with HasCommonParameters {
  // 0: rejection sample 1: binomial sample
  val valid = Input(Bool())
  val mode = Input(Bool())

  // from register
  val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

  // csr interface
  val csrs = Input(new CSRIO)

  // output
  val sampledData = Output(Vec(SamplerNum, UInt(DataWidth.W)))
  val sampledDataValid = Output(Bool())
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
}

// SamplerNum lanes of rejection and binomial samplers fed from the two vector
// operands; io.mode selects which bank drives sampledData. The data path is
// combinational, while done follows valid by one cycle.
class Samplers extends Module
    with HasSamplerParameters
    with HasCommonParameters {
  val io = IO(new SamplersIO())

  // Rejection lane i samples the 32-bit word {vectorReg2(i), vectorReg1(i)}.
  val rejectionSamplers = VecInit(Seq.fill(SamplerNum)(Module(new RejectionSampler()).io))
  val rejectionData = Wire(Vec(SamplerNum, UInt(DataWidth.W)))
  for (i <- 0 until SamplerNum) {
    rejectionSamplers(i).dataIn := Cat(io.vectorReg2(i), io.vectorReg1(i))
    rejectionSamplers(i).bound := io.csrs.csrBound
    rejectionSamplers(i).n := io.csrs.csrModulusLen
    rejectionSamplers(i).m := io.csrs.csrModulusq
    rejectionSamplers(i).u := io.csrs.csrBarretu
    rejectionData(i) := rejectionSamplers(i).dataOut
  }

  val binomialSamplers = VecInit(Seq.fill(SamplerNum)(Module(new BinomialSampler()).io))
  val binomialData = Wire(Vec(SamplerNum, UInt(DataWidth.W)))
  for (i <- 0 until SamplerNum) {
    binomialSamplers(i).data1 := io.vectorReg1(i)
    binomialSamplers(i).data2 := io.vectorReg2(i)
    binomialSamplers(i).k := io.csrs.csrBinomialk
    binomialSamplers(i).q := io.csrs.csrModulusq
    binomialData(i) := binomialSamplers(i).dataOut
  }

  io.sampledData := Mux(io.mode, binomialData, rejectionData)
  // The rejection result is only valid when every lane accepted its input.
  val rejectionValid = (0 until SamplerNum).map(i => rejectionSamplers(i).valid).reduce(_ && _)
  io.sampledDataValid := Mux(io.mode, true.B, rejectionValid)
  // Reset to false so no spurious `done` (and thus write-back) is reported in
  // the first cycle after reset.
  io.done := RegNext(io.valid, false.B)
  io.busy := false.B
  io.wb := true.B
}

object elaborateSamplers extends App {
  chisel3.Driver.execute(args, () => new Samplers())
}

--------------------------------------------------------------------------------
/src/main/scala/VPQC/VectorRegister.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3._

// behavior model
// use SRAM to implement it
class VectorRegisterIO extends Bundle
    with HasCommonParameters {
  val vectorReadAddr1 = Input(UInt(5.W))
  val vectorReadAddr2 = Input(UInt(5.W))
  val vectorWriteAddr = Input(UInt(5.W))
  val vectorWriteData = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteEnable = Input(Bool())
  val vectorReadValue1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorReadValue2 = Output(Vec(ML, UInt(DataWidth.W)))
}

// 32-entry vector register file, each entry ML * DataWidth bits, with one
// write port and two synchronous read ports (read data appears one cycle after
// the address).
class VectorRegister extends Module
    with HasCommonParameters {
  val io = IO(new VectorRegisterIO)

  val vectorSRAM = SyncReadMem(32, UInt((ML * DataWidth).W))

  // Create one write port and two read port.
  when (io.vectorWriteEnable) {
    vectorSRAM.write(io.vectorWriteAddr, io.vectorWriteData.asUInt())
  }

  // Every call to `read` instantiates a separate memory read port, so each
  // address is read exactly once and the word is sliced afterwards (reading
  // per lane created 2 * ML identical read ports). Lane i is bits
  // [DataWidth*(i+1)-1 : DataWidth*i], the same layout asUInt() writes.
  val readWord1 = vectorSRAM.read(io.vectorReadAddr1, true.B)
  val readWord2 = vectorSRAM.read(io.vectorReadAddr2, true.B)
  for (i <- 0 until ML) {
    io.vectorReadValue1(i) := readWord1(DataWidth * i + DataWidth - 1, DataWidth * i)
    io.vectorReadValue2(i) := readWord2(DataWidth * i + DataWidth - 1, DataWidth * i)
  }
}

class VectorRegister2PortIO extends Bundle
    with HasCommonParameters {
  val vectorReadAddr1 = Input(UInt(5.W))
  val vectorReadAddr2 = Input(UInt(5.W))
  val vectorWriteAddr1 = Input(UInt(5.W))
  val vectorWriteData1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteAddr2 = Input(UInt(5.W))
  val vectorWriteData2 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorWriteEnable1 = Input(Bool())
  val vectorWriteEnable2 = Input(Bool())
  val vectorReadValue1 = Output(Vec(ML, UInt(DataWidth.W)))
  val vectorReadValue2 = Output(Vec(ML, UInt(DataWidth.W)))
}

// Same register file with two write ports.
// NOTE(review): simultaneous writes to the same address are not arbitrated
// here; the result of such a collision is left to the memory model.
class VectorRegister2Port extends Module
    with HasCommonParameters {
  val io = IO(new VectorRegister2PortIO)

  val vectorSRAM = SyncReadMem(32, UInt((ML * DataWidth).W))

  // Create two write port and two read port.
  when (io.vectorWriteEnable1) {
    vectorSRAM.write(io.vectorWriteAddr1, io.vectorWriteData1.asUInt())
  }
  when (io.vectorWriteEnable2) {
    vectorSRAM.write(io.vectorWriteAddr2, io.vectorWriteData2.asUInt())
  }

  // One read port per address (see VectorRegister), sliced into lanes.
  val readWord1 = vectorSRAM.read(io.vectorReadAddr1, true.B)
  val readWord2 = vectorSRAM.read(io.vectorReadAddr2, true.B)
  for (i <- 0 until ML) {
    io.vectorReadValue1(i) := readWord1(DataWidth * i + DataWidth - 1, DataWidth * i)
    io.vectorReadValue2(i) := readWord2(DataWidth * i + DataWidth - 1, DataWidth * i)
  }
}

object elaborateVectorRegister extends App {
  chisel3.Driver.execute(args, () => new VectorRegister())
}
--------------------------------------------------------------------------------
/src/main/scala/utility/ShiftRegs.scala:
--------------------------------------------------------------------------------

package utility

import chisel3._

// paramized shift register
// n-stage delay line of w-bit registers, all reset to 0; out is `in` delayed
// by n cycles.
class ShiftRegs(val n: Int, val w: Int) extends Module {
  // delays(n - 1) below is meaningless without at least one stage.
  require(n >= 1, s"ShiftRegs needs at least one stage, got n = $n")

  val io = IO(new Bundle {
    val in = Input(UInt(w.W))
    val out = Output(UInt(w.W))
  })

  val initValues = Seq.fill(n) { 0.U(w.W) }

  val delays = RegInit(VecInit(initValues))
  delays(0) := io.in
  for (i <- 1 until n) {
    delays(i) := delays(i - 1)
  }
  io.out := delays(n - 1)
}

// object vecshift using apply method
object ShiftRegs {
  def apply (n : Int, w : Int) (in: UInt) : UInt = {
    val inst = Module(new ShiftRegs (n, w))
    inst.io.in := in
    inst.io.out
  }
}

--------------------------------------------------------------------------------
/src/main/scala/utility/SyncFifo.scala:
--------------------------------------------------------------------------------

package utility

import chisel3._
import chisel3.util._

// Synchronous FIFO of `dep` entries backed by a SyncReadMem.
// Pointers carry one extra wrap bit: equal pointers mean empty, equal
// addresses with differing wrap bits mean full. readData is the synchronous
// read of readAddr, so it reflects the address presented one cycle earlier.
class SyncFifo[T <: Data](dep: Int, dataType: T) extends Module {
  // The wrap-bit scheme needs the pointers to wrap exactly at 2 * dep, which
  // only happens for power-of-two depths; dep = 1 would also need a
  // zero-width address slice.
  require(dep >= 2 && isPow2(dep), s"SyncFifo depth must be a power of two >= 2, got $dep")

  val io = IO(new Bundle {
    val writeData = Input(dataType)
    val writeEnable = Input(Bool())
    val readEnable = Input(Bool())
    val readData = Output(dataType)
    val readEmpty = Output(Bool())
    val writeFull = Output(Bool())
  })

  val fifo = SyncReadMem(dep, dataType)
  val readPtr = RegInit(0.U((log2Ceil(dep) + 1).W))
  val writePtr = RegInit(0.U((log2Ceil(dep) + 1).W))
  val readAddr = readPtr(log2Ceil(dep)-1, 0)
  val writeAddr = writePtr(log2Ceil(dep)-1, 0)

  io.readEmpty := (readPtr === writePtr)
  io.writeFull := (readAddr === writeAddr) && (readPtr(log2Ceil(dep)) ^ writePtr(log2Ceil(dep)))

  io.readData := fifo.read(readAddr)
  when(io.readEnable && !io.readEmpty) {
    readPtr := readPtr + 1.U
  }
  when(io.writeEnable && !io.writeFull) {
    fifo.write(writeAddr, io.writeData)
    writePtr := writePtr + 1.U
  }
}

object elaborateSyncFifo extends App {
  chisel3.Driver.execute(args, () => new SyncFifo(dep = 32, dataType = UInt(512.W)))
}

--------------------------------------------------------------------------------
/src/main/scala/utility/SyncRam.scala:
--------------------------------------------------------------------------------

package utility

import chisel3._
import chisel3.util._

// this is a synchronous-read, synchronous-write memory

class SyncRam(dep: Int, dw: Int) extends Module{

  val io = IO(new Bundle {
    //control signal
    val re = Input(Bool())
    val we = Input(Bool())

    //data signal
    val ra = Input(UInt(log2Ceil(dep).W))
    val wa = Input(UInt(log2Ceil(dep).W))
    val di = Input(UInt(dw.W))
    val dout = Output(UInt(dw.W))
  })

  // Create a synchronous-read, synchronous-write memory (like in FPGAs).
  val mem = SyncReadMem(dep, UInt(dw.W))
  // Create one write port and one read port.
  when (io.we) {
    mem.write(io.wa, io.di)
  }

  io.dout := mem.read(io.ra, io.re)
}

--------------------------------------------------------------------------------
/src/test/scala/VPQC/KeccakTest.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3.iotesters
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}

// test value from: SHAKE128
// https://csrc.nist.gov/projects/cryptographic-standards-and-guidelines/example-values
class KeccakTest(c: KeccakCore) extends PeekPokeTester(c) {
  var message = new Array[Long](1600 / 64)
  for (i <- 0 until 1600 / 64) {
    if (i < 21) {
      message(i) = 0xA3A3A3A3A3A3A3A3L
    }
    else {
      message(i) = 0L
    }
  }
  for (y <- 0 until 5) {
    for (x <- 0 until 5) {
      poke(c.io.seed.s(y)(x), message(y * 5 + x))
    }
  }
  poke(c.io.seedWrite, 1)
  step(1)
  poke(c.io.seedWrite, 0)
  poke(c.io.valid, 1)
  step(1)
  poke(c.io.valid, 0)
  step(24)

  expect(c.io.done, 1)
  var res = new Array[BigInt](1600 / 64)
  for (y <- 0 until 5) {
    for (x <- 0 until 5) {
      res(y * 5 + x) = peek(c.io.prngNumber.s(y)(x))
      printf("res[%d] = %x\n", y * 5 + x, res(y * 5 + x))
    }
  }
}
object KeccakTestMain extends App {
  iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new KeccakCore) {
    c => new KeccakTest(c)
  }
}


--------------------------------------------------------------------------------
/src/test/scala/VPQC/NTTTest.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3.iotesters
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
import chisel3.util._

import scala.util._


//class MRTest(c: MR) extends PeekPokeTester(c) {
//  var a = new
scala.util.Random
//  poke(c.io.n, 14)
//  for (i <- 0 until 1000) {
//    var aIn = a.nextInt(1 << 28)
//    poke(c.io.a, aIn)
//    var m = (1 << 13) + a.nextInt(1 << 13)
//    poke(c.io.m, m)
//    poke(c.io.u, scala.math.pow(2, 29).toLong / m)
//    step(1)
//    expect(c.io.ar, aIn % m)
//  }
//}
//
//object MRTestMain extends App {
//  iotesters.Driver.execute(Array(), () => new MR) {
//    c => new MRTest(c)
//  }
//}

// Randomised check of the Butterfly unit. For 1000 random moduli m in
// [2^13, 2^14) and operands a, b, wn < m, it expects one cycle later
//   aout = (a + b*wn) mod m   and   bout = (a - b*wn) mod m.
// u is the Barrett constant 2^29 / m for n = 14.
class BTest(c: Butterfly) extends PeekPokeTester(c) {
  val rng = new scala.util.Random
  poke(c.io.n, 14)

  for (_ <- 0 until 1000) {
    val m = (1 << 13) + rng.nextInt(1 << 13)
    poke(c.io.m, m)
    poke(c.io.u, scala.math.pow(2, 29).toLong / m)

    val a = rng.nextInt(m)
    val b = rng.nextInt(m)
    val wn = rng.nextInt(m)
    poke(c.io.a, a)
    poke(c.io.b, b)
    poke(c.io.wn, wn)

    step(1)
    val product = (b * wn) % m
    val expectedB = if (a < product) a + m - product else a - product
    expect(c.io.aout, (a + product) % m)
    expect(c.io.bout, expectedB)
  }
}

object BTestMain extends App {
  iotesters.Driver.execute(Array(), () => new Butterfly) {
    c => new BTest(c)
  }
}

// Prints the permutation applied by PermNetIn: for each of the four config
// bits it drives random inputs, sets only that config bit and shows in -> out.
class PermTest(c: PermNetIn) extends PeekPokeTester(c)
    with HasNTTParameters {
  val rng = scala.util.Random
  val laneCount = ButterflyNum * 2
  for (cfg <- 0 until 4) {
    for (k <- 0 until laneCount) {
      poke(c.io.in(k), rng.nextInt(1 << 16))
    }
    (0 until 4).foreach(k => poke(c.io.config(k), 0))
    poke(c.io.config(cfg), 1)
    print("Config:\t" + peek(c.io.config) + "\n")
    for (k <- 0 until laneCount) {
      print(k + ":\t" + peek(c.io.in(k)) + "->" + peek(c.io.out(k)) + "\n")
    }
  }
}

object PermTestMain extends App {
  iotesters.Driver.execute(Array(), () => new PermNetIn) {
    c => new PermTest(c)
  }
}
// Exercises the NTT unit end to end and prints (does not `expect`) the result.
// Steps:
//  1. configure the CSRs for q = 12289 (n = 14, Barrett u = 2^29 / q);
//  2. build a twiddle table w256(i) = 9^i mod 12289 and write it, packed
//     ButterflyNum entries per word, into the unit's constant RAM via we/wa/di;
//  3. run 8 butterfly stages over 256 random 16-bit coefficients held as eight
//     32-lane vectors, feeding pairs of vectors through vectorReg1/vectorReg2;
//  4. print the hardware result beside two software references (an in-place
//     iterative butterfly model `sit` and a direct O(n^2) transform `rawRes` of
//     the bit-reversed input), and report the cycle count.
// NOTE(review): several loops hard-code 32 lanes and dataBuf row 0 fills
// indices 0..31, so ButterflyNum is assumed to be 32 — confirm in
// HasNTTParameters.
class NTTTest(c: NTT) extends PeekPokeTester(c)
    with HasNTTParameters {
  // test butterfly

  var r = Random
  // configure csrs
  poke(c.io.mode, 0)
  poke(c.io.valid, 0)

  poke(c.io.csrs.csrBinomialk, log2Ceil(8))
  poke(c.io.csrs.csrModulusq, 12289)
  poke(c.io.csrs.csrBarretu, scala.math.pow(2, 29).toLong / 12289)
  poke(c.io.csrs.csrModulusLen, 14)
  // largest multiple of q that fits below 2^32
  poke(c.io.csrs.csrBound, scala.math.pow(2, 32).toLong / 12289 * 12289)
  poke(c.io.csrs.csrButterflyCtrl.stageCfg, 0)
  poke(c.io.csrs.csrButterflyCtrl.iterCfg, 0)

  // 1: DIF
  // 0: DIT
  // 256 random 16-bit input coefficients (not reduced mod q)
  var s = new Array[Int](256)
  for (i <- 0 until 8) {
    for (j <- 0 until 32) {
      s(32*i + j) = r.nextInt() & 0xffff
    }
  }

  // write const ram
  // w256(i) = 9^i mod 12289 (presumably a primitive 256th root of unity — the
  // reference ntt below indexes it as w^(i*j mod 256))
  var w256 = new Array[Int](256)
  for (i <- 0 until 256) {
    if (i == 0) {
      w256(i) = 1
    } else {
      w256(i) = (w256(i-1) * 9) % 12289
    }
  }
  // Constant-RAM image, one row per RAM word. Row 0 packs the twiddles of the
  // small stages (2-, 4-, 8-, 16- and 32-point groups); later rows hold
  // w256(256 * idx / div) with div starting at 64 and doubling every time
  // idx reaches div / 2.
  var dataBuf = Array.ofDim[BigInt](256 / ButterflyNum, ButterflyNum)
  var idx = 0
  var div = 64
  for (i <- 0 until 256 / ButterflyNum) {
    for (j <- 0 until ButterflyNum) {
      if (i == 0) {
        // (re)writes the whole of row 0 on every j; idx is not advanced here
        dataBuf(i)(0) = w256(0)
        dataBuf(i)(1) = w256(0)
        dataBuf(i)(2) = w256(0)
        dataBuf(i)(3) = w256(256/4)
        dataBuf(i)(4) = w256(0)
        dataBuf(i)(5) = w256(256/8)
        dataBuf(i)(6) = w256(256*2/8)
        dataBuf(i)(7) = w256(256*3/8)
        dataBuf(i)(8) = w256(0)
        dataBuf(i)(9) = w256(256/16)
        dataBuf(i)(10) = w256(256*2/16)
        dataBuf(i)(11) = w256(256*3/16)
        dataBuf(i)(12) = w256(256*4/16)
        dataBuf(i)(13) = w256(256*5/16)
        dataBuf(i)(14) = w256(256*6/16)
        dataBuf(i)(15) = w256(256*7/16)
        dataBuf(i)(16) = w256(0)
        dataBuf(i)(17) = w256(256/32)
        dataBuf(i)(18) = w256(256*2/32)
        dataBuf(i)(19) = w256(256*3/32)
        dataBuf(i)(20) = w256(256*4/32)
        dataBuf(i)(21) = w256(256*5/32)
        dataBuf(i)(22) = w256(256*6/32)
        dataBuf(i)(23) = w256(256*7/32)
        dataBuf(i)(24) = w256(256*8/32)
        dataBuf(i)(25) = w256(256*9/32)
        dataBuf(i)(26) = w256(256*10/32)
        dataBuf(i)(27) = w256(256*11/32)
        dataBuf(i)(28) = w256(256*12/32)
        dataBuf(i)(29) = w256(256*13/32)
        dataBuf(i)(30) = w256(256*14/32)
        dataBuf(i)(31) = w256(256*15/32)
      } else {
        dataBuf(i)(j) = w256(256 * idx / div)
        idx = idx + 1
      }
    }
    if (idx == (div / 2)) {
      idx = 0
      div = div * 2
    }
  }

  // Pack each row into one wide word (lane j at bits [16j+15 : 16j]) and write
  // it to constant-RAM address i, one word per cycle.
  var a:BigInt = 0
  var fill_time = 0
  for (i <- 0 until 256 / ButterflyNum) {
    poke(c.io.wa, i)
    a = 0
    for (j <- 0 until ButterflyNum) {
      a = a | ((dataBuf(i)(j) & 0xffff) << (16 * j))
    }
    poke(c.io.di, a)
    poke(c.io.we, true)
    step(1)
    fill_time += 1
  }
  poke(c.io.we, false)
  step(1)
  fill_time += 1

  // raw ntt
  // Direct transform res(i) = sum_j x(j) * w256(i*j mod 256) mod 12289.
  // Products stay below 2^31: x(j) < 2^16 and w256 entries are < 12289.
  var len = 256
  def ntt(x: Array[Int]): Array[Int] = {
    var res = new Array[Int](len)
    for (i <- 0 until len) {
      res(i) = 0
    }
    for (i <- 0 until len) {
      for (j <- 0 until len) {
        res(i) += (x(j) * w256((i*j) % len)) % 12289
      }
      res(i) = res(i) % 12289
    }
    res
  }
  // bits [upBound : downBound] of a
  def range(a: Int, upBound: Int, downBound: Int) : Int = {
    assert(upBound < 32)
    assert(downBound >= 0)
    return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound))
  }
  // reverse the low `len` bits of a
  def reverse(a: Int, len: Int): Int = {
    var res: Int = 0
    for(i <- 0 until len) {
      res = res | range(a, i, i) << (len-1-i)
    }
    res
  }
  // The hardware consumes bit-reversed input, so the direct reference is run
  // on the coefficients put back into natural order.
  var sInOrder = new Array[Int](256)
  for (i <- 0 until 256) {
    sInOrder(reverse(i, log2Ceil(len))) = s(i)
  }
  val rawRes = ntt(sInOrder)

  // iterative ntt (assume bit reverse)
  // In-place radix-2 model on a copy of s; stage ii works on groups of
  // i = 2^ii elements. Groups that fit in 2 * ButterflyNum lanes are walked
  // vector by vector; larger groups are walked ButterflyNum butterflies at a
  // time, mirroring the hardware schedule below.
  var sit = new Array[Int](256)
  for (i <- 0 until 256) {
    sit(i) = s(i)
  }
  for(ii <- 1 to 8) {
    var i = 1 << ii
    if(ButterflyNum >= i/2) {
      for(j <- 0 to 255 by 2*ButterflyNum) {
        for(k1 <- 0 to 2*ButterflyNum-1 by i) {
          for(k2 <- 0 to i/2 - 1) {
            var aIn = sit(k1 + k2 + j)
            var bIn = sit(k1 + k2 + j + i / 2)
            var mul = (bIn * w256(256 * k2 / i)) % 12289
            // a - b*w, kept non-negative
            if (aIn < mul) {
              sit(k1 + k2 + j + i / 2) = aIn + 12289 - mul
            } else {
              sit(k1 + k2 + j + i / 2) = aIn - mul
            }
            sit(k1 + k2 + j) = (aIn + mul) % 12289
          }
        }
      }
    }
    else {
      for(k1 <- 0 to i/2 - 1 by ButterflyNum) {
        for (j <- 0 to 255 by i) {
          for (k2 <- 0 to ButterflyNum - 1) {
            var aIn = sit(k1 + k2 + j)
            var bIn = sit(k1 + k2 + j + i / 2)
            var mul = (bIn * w256(256 * (k1+k2) / i)) % 12289
            if (aIn < mul) {
              sit(k1 + k2 + j + i / 2) = aIn + 12289 - mul
            } else {
              sit(k1 + k2 + j + i / 2) = aIn - mul
            }
            sit(k1 + k2 + j) = (aIn + mul) % 12289
          }
        }
      }
    }
  }

  // Hardware run. Before each configuration change valid is dropped for one
  // cycle while stageCfg / iterCfg are updated. Each step then feeds one pair
  // of 32-lane vectors and writes dataOut(0..31) / dataOut(32..63) back into s.
  var ntt_time = 0
  poke(c.io.mode, 0)
  poke(c.io.valid, 1)
  for (stage <- 0 until 8) {
    var i = 2 << stage
    if(stage < 6) {
      // Stages 0..5: iterCfg stays 0 and the four vector pairs
      // (0,1), (2,3), (4,5), (6,7) are processed in order.
      poke(c.io.valid, 0)
      poke(c.io.csrs.csrButterflyCtrl.stageCfg, stage)
      poke(c.io.csrs.csrButterflyCtrl.iterCfg, 0)
      step(1)
      ntt_time += 1
      poke(c.io.valid, 1)
      for (iter <- 0 until 4) {
        for (j <- 0 until 32) {
          poke(c.io.vectorReg1(j), s(iter * 2 * 32 + j))
        }
        for (j <- 0 until 32) {
          poke(c.io.vectorReg2(j), s((iter * 2 + 1) * 32 + j))
        }
        step(1)
        ntt_time += 1
        var res = peek(c.io.dataOut)
        for (j <- 0 until 32) {
          s(iter * 2 * 32 + j) = res(j).toInt & 0xffff
        }
        for (j <- 0 until 32) {
          s((iter * 2 + 1) * 32 + j) = res(32 + j).toInt & 0xffff
        }
      }
    }
    else {
      // Stages 6, 7: each iterCfg value selects which vector slice of every
      // group is paired with its partner i/2 elements away.
      for (iter <- 0 until i / (2 * ButterflyNum)) {
        poke(c.io.valid, 0)
        poke(c.io.csrs.csrButterflyCtrl.stageCfg, stage)
        poke(c.io.csrs.csrButterflyCtrl.iterCfg, iter)
        step(1)
        ntt_time += 1
        poke(c.io.valid, 1)
        for (k <- 0 until 256 / i) {
          for (j <- 0 until 32) {
            poke(c.io.vectorReg1(j), s((2 * k * i / (2 * ButterflyNum) + iter) * 32 + j))
          }
          for (j <- 0 until 32) {
            poke(c.io.vectorReg2(j), s(((2 * k + 1) * i / (2 * ButterflyNum) + iter) * 32 + j ))
          }
          step(1)
          ntt_time += 1
          var res = peek(c.io.dataOut)
          for (j <- 0 until 32) {
            s((2 * k * i / (2 * ButterflyNum) + iter) * 32 + j) = res(j).toInt & 0xffff
          }
          for (j <- 0 until 32) {
            s(((2 * k + 1) * i / (2 * ButterflyNum) + iter) * 32 + j ) = res(32+j).toInt & 0xffff
          }
        }
      }
    }
  }
  poke(c.io.valid, 0)
  step(1)
  ntt_time += 1

  // expect
  // Only prints the three results side by side; no automatic comparison.
  for (i <- 0 until 256) {
    // iterative | recursive | harware
    printf("%d: 0x%x\t%d: 0x%x\t%d: 0x%x\n", i, sit(i), i, rawRes(i), i, s(i))
  }

  // performance
  // ntt_time counts the butterfly cycles only (fill_time is not included).
  printf("NTT time(Dimension = 256): %d cycles\n", ntt_time)
}

// Runs NTTTest on the default iotesters backend.
object NTTTestSimple extends App {
  iotesters.Driver.execute(Array(), () => new NTT) {
    c => new NTTTest(c)
  }
}

// Runs NTTTest on the Verilator backend.
object NTTTestMain extends App {
  iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new NTT) {
    c => new NTTTest(c)
  }
}

//class NTTTest(c:NTTR2MDC) extends PeekPokeTester(c) {
//
//  var len = 512
//  def ntt(x: Array[Int]): Array[Int] = {
//    var wn = new Array[Int](len)
//    for (i <- 0 until len) {
//      if (i == 0){
//        wn(i) = 1
//      } else {
//        wn(i) = (wn(i-1) * 3) % 12289
//      }
//    }
//    var res = new Array[Int](len)
//    for (i <- 0 until len) {
//      res(i) = 0
//    }
//    for (i <- 0 until len) {
//      for (j <- 0 until len) {
//        res(i) += (x(j) * wn((i*j) % len)) % 12289
//      }
//      res(i) =
res(i) % 12289
//    }
//    res
//  }
//
//  def range(a: Int, upBound: Int, downBound: Int) : Int = {
//    assert(upBound < 32)
//    assert(downBound >= 0)
//    return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound))
//  }
//
//  def reverse(a: Int, len: Int): Int = {
//    var res: Int = 0
//    for(i <- 0 until len) {
//      res = res | range(a, i, i) << (len-1-i)
//    }
//    res
//  }
//
//  var l = 14
//  val r = new scala.util.Random
//  var bound: Double = math.pow(2.0, l)
//  var iterNum: Int = 100
//
//  for (t <- 0 until iterNum) {
//    var a = new Array[Int](len)
//    for (i <- 0 until len) {
//      a(i) = r.nextInt(bound.toInt)
//      poke(c.io.dIn, a(i) & 0x3fff)
//      poke(c.io.dInValid, 1)
//      step(1)
//    }
//    var ref = ntt(a)
//
//    for (i <- 0 until len / 2) {
//      var ref1 = ref(reverse(i * 2, log2Ceil(len)))
//      expect(c.io.dOut1, ref1)
//
//      var ref2 = ref(reverse(i * 2 + 1, log2Ceil(len)))
//      expect(c.io.dOut2, ref2)
//
//      expect(c.io.dOutValid, 1)
//
//      a(reverse(i * 2, log2Ceil(len))) = peek(c.io.dOut1).toInt
//      a(reverse(i * 2 + 1, log2Ceil(len))) = peek(c.io.dOut2).toInt
//      step(1)
//    }
////    for (i <- 0 until len) {
////      print(ref(i) + "\n")
////    }
////    for (i <- 0 until len) {
////      print(a(i) + "\n")
////    }
//  }
//}
//
//object NTTTestMain extends App {
//  iotesters.Driver.execute(Array(), () => new NTTR2MDC) {
//    c => new NTTTest(c)
//  }
//}
--------------------------------------------------------------------------------
/src/test/scala/VPQC/SamplersTest.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3.iotesters
import chisel3.iotesters.PeekPokeTester
import chisel3.util._
import scala.util._

// Drives the Samplers unit with random 16-bit operands in both modes.
// CSRs: q = 12289, n = 14, Barrett u = 2^29 / q, rejection bound = the largest
// multiple of q below 2^32, binomial k = 8 (csrBinomialk = log2Ceil(8)).
//  * rejection mode: sampledDataValid must be high exactly when every lane's
//    32-bit candidate {vectorReg2, vectorReg1} is below the bound;
//  * binomial mode: lane i must equal
//    (popcount(v1[7:0]) - popcount(v2[7:0])) mod q;
//  * in both modes valid is held high, so done must be high after each step.
class SamplersTest(c: Samplers) extends PeekPokeTester(c){

  // random operand source
  val r = Random

  val q = 12289
  val bound = scala.math.pow(2, 32).toLong / q * q

  poke(c.io.csrs.csrBinomialk, log2Ceil(8))
  poke(c.io.csrs.csrModulusq, q)
  poke(c.io.csrs.csrBarretu, scala.math.pow(2, 29).toLong / q)
  poke(c.io.csrs.csrModulusLen, 14)
  poke(c.io.csrs.csrBound, bound)

  // rejection sample test
  val iterNum = 16
  for (_ <- 0 until iterNum) {
    poke(c.io.mode, 0)
    poke(c.io.valid, 1)
    var allBelowBound = true
    for (lane <- 0 until 32) {
      val v1 = r.nextInt() & 0xffff
      val v2 = r.nextInt() & 0xffff
      poke(c.io.vectorReg1(lane), v1)
      poke(c.io.vectorReg2(lane), v2)
      // the hardware compares Cat(vectorReg2, vectorReg1) against the bound
      allBelowBound = allBelowBound && (((v2.toLong << 16) | v1) < bound)
    }
    expect(c.io.sampledDataValid, if (allBelowBound) 1 else 0)
    step(1)
    expect(c.io.done, 1)
  }

  // binomial(gauss) sample test
  for (_ <- 0 until iterNum) {
    poke(c.io.mode, 1)
    poke(c.io.valid, 1)
    for (lane <- 0 until 32) {
      val v1 = r.nextInt() & 0xffff
      val v2 = r.nextInt() & 0xffff
      poke(c.io.vectorReg1(lane), v1)
      poke(c.io.vectorReg2(lane), v2)
      // k = 8: difference of the popcounts of the low 8 bits, lifted into [0, q)
      val diff = Integer.bitCount(v1 & 0xff) - Integer.bitCount(v2 & 0xff)
      expect(c.io.sampledData(lane), if (diff < 0) diff + q else diff)
    }
    expect(c.io.sampledDataValid, 1)
    step(1)
    expect(c.io.done, 1)
  }
  poke(c.io.valid, 0)
}

object SamplersTestSimple extends App {
  iotesters.Driver.execute(Array(), () => new Samplers) {
    c => new SamplersTest(c)
  }
}

object SamplersTestMain extends App {
  iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new Samplers) {
    c => new SamplersTest(c)
  }
}

--------------------------------------------------------------------------------
/src/test/scala/VPQC/procTest.scala:
--------------------------------------------------------------------------------

package VPQC

import chisel3.iotesters
import chisel3.iotesters.PeekPokeTester
import chisel3.util._
import scala.util._

class ProcTest(c: PQCCoprocessor) extends PeekPokeTester(c)
    with HasNTTParameters
    with HasCommonParameters
    with HasPQCInstructions {

  /* instruction encoding style */
  // func7 rs2 rs1 reserved rd 00010 11
  def instr(f: Int, rs1: Int, rs2: Int, rd: Int):
Int= { 17 | (f << 25) | (rs2 << 20) | (rs1 << 15) | (rd << 7) | 11 18 | } 19 | 20 | // init 21 | var r = Random 22 | poke(c.io.instr.rs1, 0) 23 | poke(c.io.instr.rs2, 0) 24 | poke(c.io.instr.rd, 0) 25 | poke(c.io.instr.funct, INSTR_FETCHRN) // nop 26 | poke(c.io.in_fire, 0) 27 | 28 | poke(c.io.twiddleData, 0) 29 | poke(c.io.twiddleAddr, 0) 30 | poke(c.io.twiddleWrite, false) 31 | step(1) 32 | 33 | // configure csrs 34 | poke(c.io.in_fire, 1) 35 | poke(c.io.instr.funct, INSTR_CSRRW) 36 | poke(c.io.instr.rs2, 0) 37 | poke(c.io.rs1, scala.math.pow(2, 29).toLong / 12289) 38 | step(1) 39 | poke(c.io.instr.rs2, 1) 40 | poke(c.io.rs1, scala.math.pow(2, 32).toLong / 12289 * 12289) 41 | step(1) 42 | poke(c.io.instr.rs2, 2) 43 | poke(c.io.rs1, log2Ceil(8)) 44 | step(1) 45 | poke(c.io.instr.rs2, 3) 46 | poke(c.io.rs1, 12289) 47 | step(1) 48 | poke(c.io.instr.rs2, 4) 49 | poke(c.io.rs1, 14) 50 | step(1) 51 | poke(c.io.instr.rs2, 5) 52 | poke(c.io.rs1, 0 << 6 | 0) 53 | step(1) 54 | poke(c.io.in_fire, 0) 55 | 56 | // write seed 57 | for (y <- 0 until 5) { 58 | for (x <- 0 until 5) { 59 | poke(c.io.seed.s(y)(x), r.nextLong()) 60 | } 61 | } 62 | poke(c.io.seedWrite, 1) 63 | step(1) 64 | poke(c.io.seedWrite, 0) 65 | 66 | 67 | // waite for the buffer-filling 68 | var sample_time = 0 69 | step(24 * 16 + 1) 70 | sample_time += 24*16 + 1 71 | 72 | // test fetchrn 73 | poke(c.io.in_fire, 1) 74 | poke(c.io.instr.funct, INSTR_FETCHRN) 75 | poke(c.io.instr.rs1, 0) 76 | poke(c.io.instr.rs2, 0) 77 | poke(c.io.instr.rd, 0) 78 | 79 | for (i <- 0 until 16) { 80 | poke(c.io.instr.rd, i) 81 | step(1) 82 | sample_time += 1 83 | } 84 | poke(c.io.in_fire, 0) 85 | step(1) 86 | sample_time += 1 87 | 88 | // test sample 89 | poke(c.io.in_fire, 1) 90 | poke(c.io.instr.funct, INSTR_SAMPLEBINOMIAL) 91 | 92 | for (i <- 0 until 8) { 93 | poke(c.io.instr.rs1, 2 * i) 94 | poke(c.io.instr.rs2, 2 * i + 1) 95 | poke(c.io.instr.rd, i) 96 | step(1) 97 | sample_time += 1 98 | } 99 | poke(c.io.in_fire, 0) 100 | 
step(1) 101 | sample_time += 1 102 | 103 | // peek sampled value 104 | var s = new Array[Int](256) 105 | poke(c.io.in_fire, 1) 106 | for (i <- 0 until 8) { 107 | poke(c.io.instr.funct, INSTR_VST) 108 | poke(c.io.instr.rs1, i) 109 | step(1) 110 | for (j <- 0 until 32) { 111 | s(32*i + j) = peek(c.io.vectorOut(j)).toInt 112 | } 113 | } 114 | poke(c.io.in_fire, 0) 115 | 116 | // write const ram 117 | var w256 = new Array[Int](256) 118 | for (i <- 0 until 256) { 119 | if (i == 0) { 120 | w256(i) = 1 121 | } else { 122 | w256(i) = (w256(i-1) * 9) % 12289 123 | } 124 | } 125 | 126 | var dataBuf = Array.ofDim[BigInt](256 / ButterflyNum, ButterflyNum) 127 | var idx = 0 128 | var div = 64 129 | for (i <- 0 until 256 / ButterflyNum) { 130 | for (j <- 0 until ButterflyNum) { 131 | if (i == 0) { 132 | dataBuf(i)(0) = w256(0) 133 | dataBuf(i)(1) = w256(0) 134 | dataBuf(i)(2) = w256(0) 135 | dataBuf(i)(3) = w256(256/4) 136 | dataBuf(i)(4) = w256(0) 137 | dataBuf(i)(5) = w256(256/8) 138 | dataBuf(i)(6) = w256(256*2/8) 139 | dataBuf(i)(7) = w256(256*3/8) 140 | dataBuf(i)(8) = w256(0) 141 | dataBuf(i)(9) = w256(256/16) 142 | dataBuf(i)(10) = w256(256*2/16) 143 | dataBuf(i)(11) = w256(256*3/16) 144 | dataBuf(i)(12) = w256(256*4/16) 145 | dataBuf(i)(13) = w256(256*5/16) 146 | dataBuf(i)(14) = w256(256*6/16) 147 | dataBuf(i)(15) = w256(256*7/16) 148 | dataBuf(i)(16) = w256(0) 149 | dataBuf(i)(17) = w256(256/32) 150 | dataBuf(i)(18) = w256(256*2/32) 151 | dataBuf(i)(19) = w256(256*3/32) 152 | dataBuf(i)(20) = w256(256*4/32) 153 | dataBuf(i)(21) = w256(256*5/32) 154 | dataBuf(i)(22) = w256(256*6/32) 155 | dataBuf(i)(23) = w256(256*7/32) 156 | dataBuf(i)(24) = w256(256*8/32) 157 | dataBuf(i)(25) = w256(256*9/32) 158 | dataBuf(i)(26) = w256(256*10/32) 159 | dataBuf(i)(27) = w256(256*11/32) 160 | dataBuf(i)(28) = w256(256*12/32) 161 | dataBuf(i)(29) = w256(256*13/32) 162 | dataBuf(i)(30) = w256(256*14/32) 163 | dataBuf(i)(31) = w256(256*15/32) 164 | } else { 165 | dataBuf(i)(j) = w256(256 
* idx / div) 166 | idx = idx + 1 167 | } 168 | } 169 | if (idx == (div / 2)) { 170 | idx = 0 171 | div = div * 2 172 | } 173 | } 174 | 175 | var a:BigInt = 0 176 | for (i <- 0 until 256 / ButterflyNum) { 177 | poke(c.io.twiddleAddr, i) 178 | a = 0 179 | for (j <- 0 until ButterflyNum) { 180 | a = a | ((dataBuf(i)(j) & 0xffff) << (16 * j)) 181 | } 182 | poke(c.io.twiddleData, a) 183 | poke(c.io.twiddleWrite, true) 184 | step(1) 185 | } 186 | poke(c.io.twiddleWrite, false) 187 | step(1) 188 | 189 | // raw ntt 190 | var len = 256 191 | def ntt(x: Array[Int]): Array[Int] = { 192 | var res = new Array[Int](len) 193 | for (i <- 0 until len) { 194 | res(i) = 0 195 | } 196 | for (i <- 0 until len) { 197 | for (j <- 0 until len) { 198 | res(i) += (x(j) * w256((i*j) % len)) % 12289 199 | } 200 | res(i) = res(i) % 12289 201 | } 202 | res 203 | } 204 | def range(a: Int, upBound: Int, downBound: Int) : Int = { 205 | assert(upBound < 32) 206 | assert(downBound >= 0) 207 | return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound)) 208 | } 209 | def reverse(a: Int, len: Int): Int = { 210 | var res: Int = 0 211 | for(i <- 0 until len) { 212 | res = res | range(a, i, i) << (len-1-i) 213 | } 214 | res 215 | } 216 | var sInOrder = new Array[Int](256) 217 | for (i <- 0 until 256) { 218 | sInOrder(reverse(i, log2Ceil(len))) = s(i) 219 | } 220 | val rawRes = ntt(sInOrder) 221 | 222 | // iterative ntt (assume bit reverse) 223 | for(ii <- 1 to 8) { 224 | var i = 1 << ii 225 | if(ButterflyNum >= i/2) { 226 | for(j <- 0 to 255 by 2*ButterflyNum) { 227 | for(k1 <- 0 to 2*ButterflyNum-1 by i) { 228 | for(k2 <- 0 to i/2 - 1) { 229 | var aIn = s(k1 + k2 + j) 230 | var bIn = s(k1 + k2 + j + i / 2) 231 | var mul = (bIn * w256(256 * k2 / i)) % 12289 232 | if (aIn < mul) { 233 | s(k1 + k2 + j + i / 2) = aIn + 12289 - mul 234 | } else { 235 | s(k1 + k2 + j + i / 2) = aIn - mul 236 | } 237 | s(k1 + k2 + j) = (aIn + mul) % 12289 238 | } 239 | } 240 | } 241 | } 242 | else { 243 | for(k1 <- 0 to 
i/2 - 1 by ButterflyNum) { 244 | for (j <- 0 to 255 by i) { 245 | for (k2 <- 0 to ButterflyNum - 1) { 246 | var aIn = s(k1 + k2 + j) 247 | var bIn = s(k1 + k2 + j + i / 2) 248 | var mul = (bIn * w256(256 * (k1+k2) / i)) % 12289 249 | if (aIn < mul) { 250 | s(k1 + k2 + j + i / 2) = aIn + 12289 - mul 251 | } else { 252 | s(k1 + k2 + j + i / 2) = aIn - mul 253 | } 254 | s(k1 + k2 + j) = (aIn + mul) % 12289 255 | } 256 | } 257 | } 258 | } 259 | } 260 | 261 | // test butterfly 262 | var ntt_time = 0 263 | poke(c.io.in_fire, 1) 264 | for (stage <- 0 until 8) { 265 | var i = 2 << stage 266 | if(stage < 6) { 267 | poke(c.io.instr.funct, INSTR_CSRRW) 268 | poke(c.io.instr.rs2, 5) 269 | poke(c.io.rs1, stage << 6 | 0) 270 | step(1) 271 | ntt_time += 1 272 | poke(c.io.instr.funct, INSTR_BUTTERFLY) 273 | for (iter <- 0 until 4) { 274 | poke(c.io.instr.rs1, iter * 2) 275 | poke(c.io.instr.rs2, iter * 2 + 1) 276 | step(1) 277 | ntt_time += 1 278 | } 279 | } 280 | else { 281 | for (iter <- 0 until i / (2 * ButterflyNum)) { 282 | poke(c.io.instr.funct, INSTR_CSRRW) 283 | poke(c.io.instr.rs2, 5) 284 | poke(c.io.rs1, stage << 6 | iter) 285 | // simulate CSRRW 286 | step(1) 287 | ntt_time += 1 288 | poke(c.io.instr.funct, INSTR_BUTTERFLY) 289 | for (k <- 0 until 256 / i) { 290 | poke(c.io.instr.rs1, 2 * k * i / (2 * ButterflyNum) + iter) 291 | poke(c.io.instr.rs2, (2 * k + 1) * i / (2 * ButterflyNum) + iter) 292 | step(1) 293 | ntt_time += 1 294 | } 295 | } 296 | } 297 | } 298 | poke(c.io.in_fire, 0) 299 | step(1) 300 | ntt_time += 1 301 | 302 | // peek ntt value 303 | var nttRes = new Array[Int](256) 304 | poke(c.io.in_fire, 1) 305 | for (i <- 0 until 8) { 306 | poke(c.io.instr.funct, INSTR_VST) 307 | poke(c.io.instr.rs1, i) 308 | step(1) 309 | for (j <- 0 until 32) { 310 | expect(c.io.vectorOut(j), s(32 * i + j)) 311 | nttRes(32 * i + j) = peek(c.io.vectorOut(j)).toInt 312 | } 313 | } 314 | poke(c.io.in_fire, 0) 315 | 316 | // expect 317 | for (i <- 0 until 256) { 318 | // iterative 
| recursive | harware 319 | printf("%d: 0x%x\t%d: 0x%x\t%d: 0x%x\n", i, s(i), i, rawRes(i), i, nttRes(i)) 320 | } 321 | // performance 322 | printf("Binomial Sampling time(Dimension = 256): %d cycles\n", sample_time) 323 | printf("NTT time(Dimension = 256): %d cycles\n", ntt_time) 324 | } 325 | 326 | object TestTopTestSimple extends App { 327 | iotesters.Driver.execute(Array(), () => new PQCCoprocessor) { 328 | c => new ProcTest(c) 329 | } 330 | } 331 | 332 | object TestTopTestMain extends App { 333 | iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new PQCCoprocessor) { 334 | c => new ProcTest(c) 335 | } 336 | } 337 | --------------------------------------------------------------------------------