├── .gitignore
├── README.md
├── build.sbt
├── project
├── build.properties
└── plugins.sbt
├── scalastyle-config.xml
├── scalastyle-test-config.xml
└── src
├── main
└── scala
│ ├── VPQC
│ ├── CommonParameters.scala
│ ├── KeccakCore.scala
│ ├── KeccakParameters.scala
│ ├── ModularArithmetic.scala
│ ├── NTT.scala
│ ├── NTTParameters.scala
│ ├── PQCCoprocessor.scala
│ ├── PQCDecode.scala
│ ├── PQCExu.scala
│ ├── SamplerParameters.scala
│ ├── Samplers.scala
│ └── VectorRegister.scala
│ └── utility
│ ├── ShiftRegs.scala
│ ├── SyncFifo.scala
│ └── SyncRam.scala
└── test
└── scala
└── VPQC
├── KeccakTest.scala
├── NTTTest.scala
├── SamplersTest.scala
└── procTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Project Specific stuff
2 | test_run_dir/*
3 | *.fir
4 | *.anno.json
5 | *.v
6 | ### XilinxISE template
7 | # intermediate build files
8 | *.bgn
9 | *.bit
10 | *.bld
11 | *.cmd_log
12 | *.drc
13 | *.ll
14 | *.lso
15 | *.msd
16 | *.msk
17 | *.ncd
18 | *.ngc
19 | *.ngd
20 | *.ngr
21 | *.pad
22 | *.par
23 | *.pcf
24 | *.prj
25 | *.ptwx
26 | *.rbb
27 | *.rbd
28 | *.stx
29 | *.syr
30 | *.twr
31 | *.twx
32 | *.unroutes
33 | *.ut
34 | *.xpi
35 | *.xst
36 | *_bitgen.xwbt
37 | *_envsettings.html
38 | *_map.map
39 | *_map.mrp
40 | *_map.ngm
41 | *_map.xrpt
42 | *_ngdbuild.xrpt
43 | *_pad.csv
44 | *_pad.txt
45 | *_par.xrpt
46 | *_summary.html
47 | *_summary.xml
48 | *_usage.xml
49 | *_xst.xrpt
50 |
51 | # project-wide generated files
52 | *.gise
53 | par_usage_statistics.html
54 | usage_statistics_webtalk.html
55 | webtalk.log
56 | webtalk_pn.xml
57 |
58 | # generated folders
59 | iseconfig/
60 | xlnx_auto_0_xdb/
61 | xst/
62 | _ngo/
63 | _xmsgs/
64 | ### Eclipse template
65 | *.pydevproject
66 | .metadata
67 | .gradle
68 | bin/
69 | tmp/
70 | *.tmp
71 | *.bak
72 | *.swp
73 | *~.nib
74 | local.properties
75 | .settings/
76 | .loadpath
77 |
78 | # Eclipse Core
79 | .project
80 |
81 | # External tool builders
82 | .externalToolBuilders/
83 |
84 | # Locally stored "Eclipse launch configurations"
85 | *.launch
86 |
87 | # CDT-specific
88 | .cproject
89 |
90 | # JDT-specific (Eclipse Java Development Tools)
91 | .classpath
92 |
93 | # Java annotation processor (APT)
94 | .factorypath
95 |
96 | # PDT-specific
97 | .buildpath
98 |
99 | # sbteclipse plugin
100 | .target
101 |
102 | # TeXlipse plugin
103 | .texlipse
104 | ### C template
105 | # Object files
106 | *.o
107 | *.ko
108 | *.obj
109 | *.elf
110 |
111 | # Precompiled Headers
112 | *.gch
113 | *.pch
114 |
115 | # Libraries
116 | *.lib
117 | *.a
118 | *.la
119 | *.lo
120 |
121 | # Shared objects (inc. Windows DLLs)
122 | *.dll
123 | *.so
124 | *.so.*
125 | *.dylib
126 |
127 | # Executables
128 | *.exe
129 | *.out
130 | *.app
131 | *.i*86
132 | *.x86_64
133 | *.hex
134 |
135 | # Debug files
136 | *.dSYM/
137 | ### SBT template
138 | # Simple Build Tool
139 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control
140 |
141 | target/
142 | lib_managed/
143 | src_managed/
144 | project/boot/
145 | .history
146 | .cache
147 | ### Emacs template
148 | # -*- mode: gitignore; -*-
149 | *~
150 | \#*\#
151 | /.emacs.desktop
152 | /.emacs.desktop.lock
153 | *.elc
154 | auto-save-list
155 | tramp
156 | .\#*
157 |
158 | # Org-mode
159 | .org-id-locations
160 | *_archive
161 |
162 | # flymake-mode
163 | *_flymake.*
164 |
165 | # eshell files
166 | /eshell/history
167 | /eshell/lastdir
168 |
169 | # elpa packages
170 | /elpa/
171 |
172 | # reftex files
173 | *.rel
174 |
175 | # AUCTeX auto folder
176 | /auto/
177 |
178 | # cask packages
179 | .cask/
180 | ### Vim template
181 | [._]*.s[a-w][a-z]
182 | [._]s[a-w][a-z]
183 | *.un~
184 | Session.vim
185 | .netrwhist
186 | *~
187 | ### JetBrains template
188 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio
189 |
190 | *.iml
191 |
192 | ## Directory-based project format:
193 | .idea/
194 | # if you remove the above rule, at least ignore the following:
195 |
196 | # User-specific stuff:
197 | # .idea/workspace.xml
198 | # .idea/tasks.xml
199 | # .idea/dictionaries
200 |
201 | # Sensitive or high-churn files:
202 | # .idea/dataSources.ids
203 | # .idea/dataSources.xml
204 | # .idea/sqlDataSources.xml
205 | # .idea/dynamic.xml
206 | # .idea/uiDesigner.xml
207 |
208 | # Gradle:
209 | # .idea/gradle.xml
210 | # .idea/libraries
211 |
212 | # Mongo Explorer plugin:
213 | # .idea/mongoSettings.xml
214 |
215 | ## File-based project format:
216 | *.ipr
217 | *.iws
218 |
219 | ## Plugin-specific files:
220 |
221 | # IntelliJ
222 | /out/
223 |
224 | # mpeltonen/sbt-idea plugin
225 | .idea_modules/
226 |
227 | # JIRA plugin
228 | atlassian-ide-plugin.xml
229 |
230 | # Crashlytics plugin (for Android Studio and IntelliJ)
231 | com_crashlytics_export_strings.xml
232 | crashlytics.properties
233 | crashlytics-build.properties
234 | ### C++ template
235 | # Compiled Object files
236 | *.slo
237 | *.lo
238 | *.o
239 | *.obj
240 |
241 | # Precompiled Headers
242 | *.gch
243 | *.pch
244 |
245 | # Compiled Dynamic libraries
246 | *.so
247 | *.dylib
248 | *.dll
249 |
250 | # Fortran module files
251 | *.mod
252 |
253 | # Compiled Static libraries
254 | *.lai
255 | *.la
256 | *.a
257 | *.lib
258 |
259 | # Executables
260 | *.exe
261 | *.out
262 | *.app
263 | ### OSX template
264 | .DS_Store
265 | .AppleDouble
266 | .LSOverride
267 |
268 | # Icon must end with two \r
269 | Icon
270 |
271 | # Thumbnails
272 | ._*
273 |
274 | # Files that might appear in the root of a volume
275 | .DocumentRevisions-V100
276 | .fseventsd
277 | .Spotlight-V100
278 | .TemporaryItems
279 | .Trashes
280 | .VolumeIcon.icns
281 |
282 | # Directories potentially created on remote AFP share
283 | .AppleDB
284 | .AppleDesktop
285 | Network Trash Folder
286 | Temporary Items
287 | .apdisk
288 | ### Xcode template
289 | # Xcode
290 | #
291 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
292 |
293 | ## Build generated
294 | build/
295 | DerivedData
296 |
297 | ## Various settings
298 | *.pbxuser
299 | !default.pbxuser
300 | *.mode1v3
301 | !default.mode1v3
302 | *.mode2v3
303 | !default.mode2v3
304 | *.perspectivev3
305 | !default.perspectivev3
306 | xcuserdata
307 |
308 | ## Other
309 | *.xccheckout
310 | *.moved-aside
311 | *.xcuserstate
312 | ### Scala template
313 | *.class
314 | *.log
315 |
316 | # sbt specific
317 | .cache
318 | .history
319 | .lib/
320 | dist/*
321 | target/
322 | lib_managed/
323 | src_managed/
324 | project/boot/
325 | project/plugins/project/
326 |
327 | # Scala-IDE specific
328 | .scala_dependencies
329 | .worksheet
330 | ### Java template
331 | *.class
332 |
333 | # Mobile Tools for Java (J2ME)
334 | .mtj.tmp/
335 |
336 | # Package Files #
337 | *.jar
338 | *.war
339 | *.ear
340 |
341 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
342 | hs_err_pid*
343 |
344 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### VPQC
2 | This work presents a vector processor for Ring-LWE and Module-LWE schemes in post-quantum cryptography.
3 | It is implemented in Chisel, an agile hardware construction language developed at UC Berkeley.
4 | The instruction set of the processor is a custom extension of RISC-V, so the processor can
5 | serve as a coprocessor for RISC-V cores. The custom instructions are defined as follows.
6 |
7 | | **Instr \ Bit field** | **31:25** | **24:20** | **19:15** | **14:12** | **11:7** | **6:0** | **Description** |
8 | | ----------------------- | :-----: | :-----: | :-----: | :-----: | :----: | :-----: | :-------------------------------------: |
9 | | fetchData | 0000000 | - | - | - | vd | 0001011 | v[vd] <- random data from prefetch FIFO |
10 | | binomial sampling | 0000001 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector binomial sample (v[vs1],v[vs2]) |
11 | | rejection sampling | 0000010 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector rejection sample (v[vs1],v[vs2]) |
12 | | DIT butterfly | 0000011 | vs2 | vs1 | - | - | 0001011 | (v[vs1], v[vs2]) <- vector DIT butterfly (v[vs1], v[vs2]) |
13 | | DIF butterfly | 0000100 | vs2 | vs1 | - | - | 0001011 | (v[vs1], v[vs2]) <- vector DIF butterfly (v[vs1], v[vs2]) |
14 | | csrrw | 0000101 | csridx | vs1 | - | - | 0001011 | swap value (csr[csridx], r[vs1]) |
15 | | csrrwi | 0000110 | csridx | imm[9:5] | - | imm[4:0] | 0001011 | csr[csridx] <- imm |
16 | | vld | 0000111 | - | - | - | vd | 0001011 | v[vd] <- memory (addr) |
17 | | vst | 0001000 | - | vs1 | - | - | 0001011 | memory (addr) <- v[vs1] |
18 | | vadd | 0001001 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector addition (v[vs1], v[vs2]) |
19 | | vsub | 0001010 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector subtraction (v[vs1], v[vs2]) |
20 | | vmul | 0001011 | vs2 | vs1 | - | vd | 0001011 | v[vd] <- vector multiplication (v[vs1], v[vs2]) |
21 |
22 | This project follows the standard directory layout of *sbt* (the Scala build tool); see
23 | https://www.scala-sbt.org/ for more information about *sbt*.
24 |
25 | The source code is located in */src/main/scala/* and the test code is located in */src/test/scala/VPQC/*.
26 |
27 | #### performance of NTT and sampling process
28 | [IntelliJ IDEA](https://www.jetbrains.com/idea/) is a powerful IDE for scala, and it is quite convenient to use IntelliJ IDEA to build this project.
29 |
30 | Run the object *TestTopTestSimple* in */src/test/scala/VPQC/procTest*, and the result shows:
31 |
32 | When the dimension = 256, binomial/rejection sampling can be finished in 411 cycles and
33 | NTT can be finished in 45 cycles.
34 |
35 | Or you could run the object *TestTopTestMain* to obtain more information about the waveform generated by Verilator.
36 | #### For more information about Chisel, see below:
37 | https://www.chisel-lang.org/
38 | #### For project template for Chisel, please refer to:
39 | https://github.com/freechipsproject/chisel-template
40 | #### Our source code also uses a third-party Chisel library for finite state machines, please refer to:
41 | https://github.com/dai-pch/FSM
42 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | // See README.md for license details.
2 |
/** Scala compiler flags that depend on the Scala version in use.
  *
  * Returns no extra flags for Scala 2.x with a minor version below 12, and
  * `-Xsource:2.11` for everything else. The flag is the compile option switch
  * that supports anonymous Bundle definitions, see
  * https://github.com/scala/bug/issues/10047
  */
def scalacOptionsVersion(scalaVersion: String): Seq[String] =
  CrossVersion.partialVersion(scalaVersion) match {
    case Some((2, minor: Long)) if minor < 12 => Seq.empty[String]
    case _                                    => Seq("-Xsource:2.11")
  }
14 |
/** Java compiler flags that depend on the Scala version in use.
  *
  * Scala 2.12 requires Java 8. For Scala 2.x with a minor version below 12 we
  * keep generating Java 7 compatible code for compatibility with old clients.
  */
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
  val javaLevel = CrossVersion.partialVersion(scalaVersion) match {
    case Some((2, minor: Long)) if minor < 12 => "1.7"
    case _                                    => "1.8"
  }
  Seq("-source", javaLevel, "-target", javaLevel)
}
28 |
29 | name := "chisel-module-template"
30 |
31 | version := "3.1.0"
32 |
33 | scalaVersion := "2.11.12"
34 |
35 | crossScalaVersions := Seq("2.11.12", "2.12.4")
36 |
37 | resolvers ++= Seq(
38 | Resolver.sonatypeRepo("snapshots"),
39 | Resolver.sonatypeRepo("releases")
40 | )
41 |
42 | // Provide a managed dependency on X if -DXVersion="" is supplied on the command line.
43 | val defaultVersions = Map(
44 | "chisel3" -> "3.1.+",
45 | "chisel-iotesters" -> "1.2.+"
46 | )
47 |
48 | libraryDependencies ++= (Seq("chisel3","chisel-iotesters").map {
49 | dep: String => "edu.berkeley.cs" %% dep % sys.props.getOrElse(dep + "Version", defaultVersions(dep)) })
50 |
51 | libraryDependencies += "org.daipch" %% "fsm" % "0.3.+"
52 |
53 | scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
54 |
55 | javacOptions ++= javacOptionsVersion(scalaVersion.value)
56 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.3.2
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
--------------------------------------------------------------------------------
/scalastyle-config.xml:
--------------------------------------------------------------------------------
1 |
2 | Scalastyle standard configuration
3 |
4 |
5 |
6 |
7 |
8 |
9 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | No lines ending with a ;
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 | |\|\||&&|:=|<>|<=|>=|!=|===|<<|>>|##|unary_(~|\-%?|!))$]]>
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
/scalastyle-test-config.xml:
--------------------------------------------------------------------------------
1 |
2 | Scalastyle configuration for Chisel3 unit tests
3 |
4 |
5 |
6 |
7 |
8 |
9 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 | No lines ending with a ;
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | |\|\||&&|:=|<>|<=|>=|!=|===|<<|>>|##|unary_(~|\-%?|!))$]]>
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/CommonParameters.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
/** Sizing constants shared by the VPQC modules. */
trait HasCommonParameters {
  /** Number of elements in the vector-shaped ports (e.g. `Vec(ML, UInt(DataWidth.W))`). */
  val ML = 32

  /** Bit width of one data element. */
  val DataWidth = 16
}
8 |
9 | //trait HasPQCCSR{
10 | // val CSRWIDTH : Int = 64
11 | // val CSRLENGTH : Int = 6 // change the number of CSR registers here
12 | //
13 | // // list the name and index of every CSR below; each name must start with CSR_
14 | // val csrBarretu = Input(UInt((DataWidth + 2).W))
15 | // val csrBound = Input(UInt((DataWidth * 2).W))
16 | // val csrBinomialk = Input(UInt(3.W))
17 | // val csrModulusq = Input(UInt(DataWidth.W))
18 | // val csrModulusLen = Input(UInt(5.W))
19 | //}
20 |
/** Numeric identifiers of the custom PQC instructions.
  *
  * Instruction encoding style:
  *   func7 rs2 rs1 reserved rd 00010 11
  */
trait HasPQCInstructions {
  /** Number of custom instructions. */
  val INSTR_QUANTITY = 12

  // PQC: random data, sampling and butterflies
  val INSTR_FETCHRN         = 0
  val INSTR_SAMPLEBINOMIAL  = 1
  val INSTR_SAMPLEREJECTION = 2
  val INSTR_BUTTERFLY       = 3
  val INSTR_IBUTTERFLY      = 4

  // CSR access
  val INSTR_CSRRW  = 5
  val INSTR_CSRRWI = 6

  // vector load / store
  val INSTR_VLD = 7
  val INSTR_VST = 8

  // vector arithmetic
  val INSTR_VADD = 9
  val INSTR_VSUB = 10
  val INSTR_VMUL = 11

  // Alternative load/store numbering, kept disabled:
  //
  // // LOAD STORE
  // val INSTR_VLD = 7
  // val INSTR_VLDS = 8
  // val INSTR_VLDX = 9
  // val INSTR_VST = 10
  // val INSTR_VSTS = 11
  // val INSTR_VSTX = 12
}
51 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/KeccakCore.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 | import chisel3.util._
6 | import utility.SyncFifo
7 | import fsm._
8 |
9 | /**
10 | *
11 | * Description : implements the permutation used by KECCAK[c], i.e. Keccak-p[1600, 24];
12 | * see FIPS PUB 202 for more information
13 | *
14 | **/
15 |
16 | /**
17 | *
18 | * index = (y,x,z)
19 | *
20 | **/
class stateArray extends Bundle {
  // 5 x 5 lanes of 64 bits each; the outer index is y, the inner index is x
  // and the bit position inside a lane is z.
  val s = Vec(5, Vec(5, UInt(64.W)))
}
24 |
25 |
/** Port bundle shared by the Proc1..Proc5 stages: one state array in, one state array out. */
class ProcIO extends Bundle with HasKeccakParameters {
  val in  = Input(new stateArray())
  val out = Output(new stateArray())
}
31 |
/** First step of a round (column-parity mixing), purely combinational.
  *
  * Every lane is XORed with a per-column value built from the parity of the
  * column to the left (x - 1) and the parity of the column to the right
  * (x + 1) taken one bit position lower (z - 1).
  */
class Proc1 extends Module with HasKeccakParameters {
  val io = IO(new ProcIO)

  val state = io.in.s

  // step 1: parity of each column, i.e. XOR over the five rows y
  val columnsXorSqueez = WireInit(VecInit(Seq.fill(5)(VecInit(Seq.fill(64)(false.B)))))
  for (x <- 0 until 5; z <- 0 until 64) {
    columnsXorSqueez(x)(z) := Seq.tabulate(5)(y => state(y)(x)(z)).reduce(_ ^ _)
  }

  // step 2: combine the parities of the two neighbouring columns
  val columnsXorInter = WireInit(VecInit(Seq.fill(5)(VecInit(Seq.fill(64)(false.B)))))
  for (x <- 0 until 5; z <- 0 until 64) {
    val left  = columnsXorSqueez(mod(x - 1, 5))(z)
    val right = columnsXorSqueez(mod(x + 1, 5))(mod(z - 1, 64))
    columnsXorInter(x)(z) := left ^ right
  }

  // step 3: fold the per-column value into every lane of that column
  for (y <- 0 until 5; x <- 0 until 5) {
    io.out.s(y)(x) := state(y)(x) ^ columnsXorInter(x).asUInt()
  }
}
58 |
/** Second step of a round: rotates every lane left by a fixed, lane-specific amount.
  *
  * The rotation amount is `proc2param(y)(x)` reduced modulo 64 at elaboration
  * time, so each lane becomes a fixed rewiring of its 64 bits.
  */
class Proc2 extends Module with HasKeccakParameters {
  val io = IO(new ProcIO())

  val state = io.in.s
  for (y <- 0 until 5; x <- 0 until 5) {
    val lane = state(y)(x)
    mod(proc2param(y)(x), 64) match {
      case 0 =>
        // no rotation: pass the lane through unchanged
        io.out.s(y)(x) := lane
      case shamt =>
        // rotate left: the low (64 - shamt) bits move up, the top shamt bits wrap around
        io.out.s(y)(x) := Cat(lane(63 - shamt, 0), lane(63, 64 - shamt))
    }
  }
}
75 |
/** Third step of a round: permutes whole lanes.
  *
  * Output lane (y, x) is taken from input lane (x, (x + 3y) mod 5).
  */
class Proc3 extends Module with HasKeccakParameters {
  val io = IO(new ProcIO())

  val state = io.in.s
  for (y <- 0 until 5; x <- 0 until 5) {
    val srcX = mod(x + 3 * y, 5)
    io.out.s(y)(x) := state(x)(srcX)
  }
}
87 |
/** Fourth step of a round: the non-linear mixing within each row.
  *
  * Each lane is XORed with (NOT next lane) AND (lane after next), where
  * "next" is along x within the same row y, wrapping modulo 5.
  */
class Proc4 extends Module with HasKeccakParameters {
  val io = IO(new ProcIO())

  val state = io.in.s
  for (y <- 0 until 5) {
    val row = state(y)
    for (x <- 0 until 5) {
      val next      = row(mod(x + 1, 5))
      val afterNext = row(mod(x + 2, 5))
      io.out.s(y)(x) := (~next & afterNext) ^ row(x)
    }
  }
}
99 |
/** Fifth step of a round: XORs a round constant into lane (0, 0).
  *
  * @param ir round index; selects which seven `rc` bits form the constant
  */
class Proc5(ir: Int) extends Module with HasKeccakParameters {
  val io = IO(new ProcIO())

  val state = io.in.s

  // The round constant is all zero except at bit positions 2^j - 1 (j = 0..6),
  // which take the bits rc(j + 7 * ir) computed at elaboration time.
  val RC = WireInit(VecInit(Seq.fill(64)(false.B)))
  (0 to 6).foreach { j =>
    RC((1 << j) - 1) := rc(j + ir * 7).asUInt()
  }

  // Copy the whole state, then override lane (0, 0) (the later connection wins).
  val nextState = Wire(new stateArray())
  nextState.s := state
  nextState.s(0)(0) := state(0)(0) ^ RC.asUInt()
  io.out.s := nextState.s
}
114 |
/** Ports of [[KeccakCore]]. */
class KeccakCoreIO extends Bundle with HasKeccakParameters {
  // request one run of the rounds
  val valid = Input(Bool())
  // seed value and its write strobe
  val seed      = Input(new stateArray())
  val seedWrite = Input(Bool())
  // current contents of the state register
  val prngNumber = Output(new stateArray())
  // registered copy of the "last round" flag
  val done = Output(Bool())
  // set once a seed has been written
  val initialized = Output(Bool())
}
124 |
125 | /**
126 | * cnt valid state done
127 | * 0 0 0
128 | * 0 1
129 | * 1 done care
130 | * .
131 | * .
132 | * .
133 | * 24 end point write to final value
134 | * 0 1
135 | */
/** Iterative permutation core: one round per cycle, `RoundsNum` rounds per run.
  *
  * `cnt` is 0 when idle and counts 1..RoundsNum while a run is in progress.
  * The state register is loaded by `seedWrite` and otherwise updated with the
  * output of the round logic whenever `cnt` is non-zero.
  */
class KeccakCore extends Module with HasKeccakParameters {
  val io = IO(new KeccakCoreIO())

  // count control
  val cnt = RegInit(0.U(5.W))
  val busy = cnt =/= 0.U
  val countBegin = io.valid || busy
  val state = Wire(Vec(5, new stateArray()))
  val stateReg = Reg(new stateArray())
  val completeFlag = cnt === RoundsNum.asUInt()
  val initialized = RegInit(false.B)

  // advance while running (or when started); wrap to idle after the last round
  when(countBegin) {
    when(completeFlag) {
      cnt := 0.U
    }.otherwise {
      cnt := cnt + 1.U
    }
  }

  // seeding has priority over the round update
  when(io.seedWrite) {
    stateReg := io.seed
    initialized := true.B
  }.elsewhen(busy) {
    stateReg := state(4)
  }

  // module instantiation: the five steps of one round, chained combinationally
  val proc1 = Module(new Proc1())
  // feed zeros while idle so the round logic does not see the stored state
  proc1.io.in := Mux(busy, stateReg, 0.U.asTypeOf(new stateArray()))
  state(0) := proc1.io.out

  val proc2 = Module(new Proc2())
  proc2.io.in := state(0)
  state(1) := proc2.io.out

  val proc3 = Module(new Proc3())
  proc3.io.in := state(1)
  state(2) := proc3.io.out

  val proc4 = Module(new Proc4())
  proc4.io.in := state(2)
  state(3) := proc4.io.out

  // one Proc5 per round constant, all fed from the same input; cnt picks the
  // active one (cnt = k selects round k - 1, and round 0 is the default)
  val proc5s = (0 until RoundsNum).map { round => Module(new Proc5(round)).io }
  proc5s.foreach { p => p.in := state(3) }
  state(4) := MuxLookup(
    cnt,
    proc5s.head.out,
    (1 until RoundsNum).map(round => ((round + 1).asUInt() -> proc5s(round).out))
  )

  io.prngNumber := stateReg
  io.done := RegNext(completeFlag)
  io.initialized := initialized
}
183 |
/** [[KeccakCore]] with an internal [[SyncFifo]] that buffers generated words.
  *
  * The core runs whenever it has been seeded and the FIFO is not full; each
  * finished run is written into the FIFO. `io.valid` pops one entry, which is
  * presented as `ML` elements of `DataWidth` bits on `io.prn`.
  */
class KeccakWithFifo extends Module
  with HasCommonParameters
  with HasKeccakParameters {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val seed = Input(new stateArray())
    val seedWrite = Input(Bool())
    val prn = Output(Vec(ML, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
  })

  val prng = Module(new KeccakCore)
  val fifo = Module(new SyncFifo(dep = 32, dataType = UInt((ML * DataWidth).W)))

  // seeding is forwarded straight to the core
  prng.io.seed := io.seed
  prng.io.seedWrite := io.seedWrite
  // fill the buffer if it is not full
  prng.io.valid := !fifo.io.writeFull && prng.io.initialized

  // write prn to buffer
  fifo.io.writeData := prng.io.prngNumber.s.asUInt()
  fifo.io.writeEnable := prng.io.done
  fifo.io.readEnable := io.valid

  // get random number: slice the FIFO word into ML elements, element 0 in the low bits
  io.prn.zipWithIndex.foreach { case (element, i) =>
    element := fifo.io.readData(DataWidth * (i + 1) - 1, DataWidth * i)
  }

  // defaults; the FSM actions below override done/busy per state
  io.wb := true.B
  io.done := false.B
  io.busy := false.B

  val fsm = InstanciateFSM(new FSM {
    entryState(stateName = "Idle")
      .act {
        io.done := false.B
        io.busy := false.B
      }
      .when(io.valid && fifo.io.readEmpty).transferTo(destName = "Wait")
      .when(io.valid && !fifo.io.readEmpty).transferTo(destName = "Read")

    state(stateName = "Wait")
      .act {
        io.done := false.B
        io.busy := true.B
      }
      .when(!fifo.io.readEmpty).transferTo(destName = "Read")

    state(stateName = "Read")
      .act {
        io.done := true.B
        io.busy := false.B
      }
      .when(io.valid && fifo.io.readEmpty).transferTo(destName = "Wait")
      .when(io.valid && !fifo.io.readEmpty).transferTo(destName = "Read")
      .otherwise.transferToEnd
  })

}
244 |
245 | // for synthesis
/** Variant of the buffered generator whose FIFO is outside the module.
  *
  * The FIFO write/read signals are exposed as ports instead of being
  * connected to an internal FIFO instance; the control logic is otherwise
  * the same as in the buffered version.
  */
class KeccakNoFifo extends Module
  with HasCommonParameters
  with HasKeccakParameters {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val seed = Input(new stateArray())
    val seedWrite = Input(Bool())
    val prn = Output(Vec(ML, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
    // from/to syncfifo
    val writeData = Output(UInt((ML * DataWidth).W))
    val writeEnable = Output(Bool())
    val readEnable = Output(Bool())
    val readData = Input(UInt((ML * DataWidth).W))
    val readEmpty = Input(Bool())
    val writeFull = Input(Bool())
  })

  val prng = Module(new KeccakCore)
  prng.io.seed := io.seed
  prng.io.seedWrite := io.seedWrite
  // run the core only once it is seeded and while the external FIFO has room
  prng.io.valid := !io.writeFull && prng.io.initialized

  // write prn to buffer
  io.writeData := prng.io.prngNumber.s.asUInt()
  io.writeEnable := prng.io.done
  io.readEnable := io.valid

  // get random number: slice the FIFO word into ML elements, element 0 in the low bits
  io.prn.zipWithIndex.foreach { case (element, i) =>
    element := io.readData(DataWidth * (i + 1) - 1, DataWidth * i)
  }

  // defaults; the FSM actions below override done/busy per state
  io.wb := true.B
  io.done := false.B
  io.busy := false.B

  val fsm = InstanciateFSM(new FSM {
    entryState(stateName = "Idle")
      .act {
        io.done := false.B
        io.busy := false.B
      }
      .when(io.valid && io.readEmpty).transferTo(destName = "Wait")
      .when(io.valid && !io.readEmpty).transferTo(destName = "Read")

    state(stateName = "Wait")
      .act {
        io.done := false.B
        io.busy := true.B
      }
      .when(!io.readEmpty).transferTo(destName = "Read")

    state(stateName = "Read")
      .act {
        io.done := true.B
        io.busy := false.B
      }
      .when(io.valid && io.readEmpty).transferTo(destName = "Wait")
      .when(io.valid && !io.readEmpty).transferTo(destName = "Read")
      .otherwise.transferToEnd
  })

  // Disabled alternative that bypasses the FIFO entirely:
  // val s = prng.io.prngNumber.s.asUInt()
  // for (i <- 0 until ML) {
  //   io.prn(i) := s(DataWidth * i + DataWidth - 1, DataWidth * i)
  // }
  // io.wb := true.B
  // io.done := RegNext(io.valid)
  // io.busy := false.B
}
317 |
/** Entry point that elaborates [[KeccakNoFifo]] through the Chisel driver. */
object elaborateKeccak extends App {
  chisel3.Driver.execute(args, () => new KeccakNoFifo())
}
321 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/KeccakParameters.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 |
/** Elaboration-time constants and integer helpers for the Keccak modules. */
trait HasKeccakParameters {
  val ArrayLength = 1600
  val RoundsNum = 24

  /** Non-negative remainder of `a` modulo `q` (for positive `q`).
    *
    * Unlike `%`, the result is always in `[0, q)`, including for negative `a`
    * of any magnitude. The previous implementation returned `a + q` for
    * negative `a`, which is only correct when `a >= -q`.
    */
  def mod(a: Int, q: Int): Int = ((a % q) + q) % q

  /** Extracts bits `upBound` down to `downBound` (inclusive) of `a`.
    *
    * @param a         value to slice
    * @param upBound   index of the highest bit to keep, must be below 32
    * @param downBound index of the lowest bit to keep, must not be negative
    * @return the selected bits, right-aligned
    */
  def range(a: Int, upBound: Int, downBound: Int): Int = {
    assert(upBound < 32)
    assert(downBound >= 0)
    // mask with (upBound - downBound + 1) low bits set
    (a >> downBound) & (0xffffffff >>> (31 - upBound + downBound))
  }

  /** Per-lane rotation amounts used by Proc2, indexed as (y)(x); reduced modulo 64 by the user. */
  val proc2param = Array(
    Array(0, 1, 190, 28, 91),
    Array(36, 300, 6, 55, 276),
    Array(3, 10, 171, 153, 231),
    Array(105, 45, 15, 21, 136),
    Array(210, 66, 253, 120, 78)
  )

  /** Round-constant bit number `t`, produced by an 8-bit linear feedback shift register.
    *
    * @param t bit index; only `t % 255` matters
    * @return 0 or 1
    */
  def rc(t: Int): Int = {
    val steps = t % 255
    if (steps == 0) {
      1
    } else {
      var R = 1
      for (_ <- 1 to steps) {
        R = R << 1
        // bit 8 is the bit shifted out of the 8-bit register; feed it back
        // into bits 0, 4, 5 and 6, then drop it again
        val R8 = range(R, 8, 8)
        R = (R ^ (R8 | (R8 << 4) | (R8 << 5) | (R8 << 6))) & 0xff
      }
      R & 1
    }
  }
}
50 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/ModularArithmetic.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 | import chisel3._
4 | import chisel3.util._
5 |
/** Ports common to the modular arithmetic units: c is the reduced result for operands a and b. */
class ModIO extends Bundle with HasCommonParameters {
  // two operands
  val a = Input(UInt(DataWidth.W))
  val b = Input(UInt(DataWidth.W))
  // modulus
  val m = Input(UInt(DataWidth.W))
  // reduced value
  val c = Output(UInt(DataWidth.W))
}
16 |
/** [[ModIO]] extended with the two extra inputs the multiplier's reduction needs. */
class ModMulIO extends ModIO {
  // precomputed value used by the reduction
  val u = Input(UInt((DataWidth + 2).W))
  // base of the two shift amounts used by the reduction (n - 2 and n + 3)
  val n = Input(UInt(5.W))
}
22 |
/** Combinational modular adder: c = a + b, with m subtracted once if that does not go negative. */
class ModAdd extends Module with HasCommonParameters {
  val io = IO(new ModIO)

  // raw sum, one bit wider than the operands so the carry is kept
  val ctmp1 = Wire(UInt((DataWidth + 1).W))
  ctmp1 := io.a +& io.b

  // sum minus modulus, truncated to DataWidth + 1 bits
  val ctmp2 = Wire(UInt((DataWidth + 1).W))
  ctmp2 := ctmp1 -& io.m

  // top bit of the difference: set when the subtraction went below zero
  val flag = ctmp2(DataWidth)

  when(flag) {
    io.c := ctmp1(DataWidth - 1, 0)
  }.otherwise {
    io.c := ctmp2(DataWidth - 1, 0)
  }
}
34 |
object ModAdd {
  /** Instantiates a [[ModAdd]], wires up its inputs and returns its result port. */
  def apply(a: UInt, b: UInt, m: UInt): UInt = {
    val adder = Module(new ModAdd()).io
    adder.a := a
    adder.b := b
    adder.m := m
    adder.c
  }
}
44 |
/** Combinational modular subtractor: c = a - b, with m added back once if the difference went negative. */
class ModSub extends Module with HasCommonParameters {
  val io = IO(new ModIO)

  // raw difference, one bit wider than the operands so the borrow is kept
  val ctmp1 = Wire(UInt((DataWidth + 1).W))
  ctmp1 := io.a -& io.b

  // difference plus modulus, truncated to DataWidth + 1 bits
  val ctmp2 = Wire(UInt((DataWidth + 1).W))
  ctmp2 := ctmp1 +& io.m

  // top bit of the raw difference: set when a - b went below zero
  val flag = ctmp1(DataWidth)

  when(flag) {
    io.c := ctmp2(DataWidth - 1, 0)
  }.otherwise {
    io.c := ctmp1(DataWidth - 1, 0)
  }
}
56 |
object ModSub {
  /** Instantiates a [[ModSub]], wires up its inputs and returns its result port. */
  def apply(a: UInt, b: UInt, m: UInt): UInt = {
    val subtractor = Module(new ModSub()).io
    subtractor.a := a
    subtractor.b := b
    subtractor.m := m
    subtractor.c
  }
}
66 |
/** Combinational modular multiplier.
  *
  * Multiplies a and b, estimates the quotient of the product by m using the
  * precomputed value u and two shifts derived from n, subtracts quotient * m
  * and applies a single conditional correction by m.
  * NOTE(review): this looks like a Barrett-style reduction (the CSR feeding u
  * is named csrBarretu) -- confirm the expected relation between u, n and m.
  */
class ModMul extends Module with HasCommonParameters {
  val io = IO(new ModMulIO)

  val a = io.a
  val b = io.b
  val m = io.m
  val u = io.u
  val n = io.n

  // process 1: full product and its upper part
  val c = a * b
  val shift = c >> (n - 2.U)

  // process 2: scale by the precomputed value
  val mul1 = u * shift.asUInt()

  // process 3: quotient estimate and its multiple of the modulus
  val qGuess = mul1 >> (n + 3.U)
  val mul2 = io.m * qGuess.asUInt()

  // process 4: remainder estimate, corrected by at most one subtraction of m
  val z = c - mul2
  when(z < io.m) {
    io.c := z
  }.otherwise {
    io.c := z - io.m
  }
}
92 |
object ModMul {
  /** Instantiates a [[ModMul]], wires up its inputs and returns its result port. */
  def apply(a: UInt, b: UInt, m: UInt, u: UInt, n: UInt): UInt = {
    val multiplier = Module(new ModMul()).io
    multiplier.a := a
    multiplier.b := b
    multiplier.m := m
    multiplier.u := u
    multiplier.n := n
    multiplier.c
  }
}
104 |
105 | // Note: This Module can be eliminated by sharing the
106 | // hardware resources in NTT Module
/** Ports of [[VectorArith]]: three independent operand pairs and their results, ML elements each. */
class VectorArithIO extends Bundle with HasCommonParameters {
  val valid = Input(Bool())

  // config (taken from the CSRs; the individual ports below are disabled)
  // val m = Input(UInt(DataWidth.W))
  // val u = Input(UInt((DataWidth + 2).W))
  // val n = Input(UInt(5.W))
  val csrs = Input(new CSRIO)

  // input
  val addA = Input(Vec(ML, UInt(DataWidth.W)))
  val addB = Input(Vec(ML, UInt(DataWidth.W)))
  val subA = Input(Vec(ML, UInt(DataWidth.W)))
  val subB = Input(Vec(ML, UInt(DataWidth.W)))
  val mulA = Input(Vec(ML, UInt(DataWidth.W)))
  val mulB = Input(Vec(ML, UInt(DataWidth.W)))

  // output
  val addRes = Output(Vec(ML, UInt(DataWidth.W)))
  val subRes = Output(Vec(ML, UInt(DataWidth.W)))
  val mulRes = Output(Vec(ML, UInt(DataWidth.W)))

  // handshake / status
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
}
131 |
/** Element-wise modular add, subtract and multiply over ML elements.
  *
  * One ModAdd, ModSub and ModMul is instantiated per element; the modulus and
  * the multiplier's reduction parameters come from the CSR inputs.
  */
class VectorArith extends Module with HasCommonParameters {
  val io = IO(new VectorArithIO())

  val adds = VecInit(Seq.fill(ML)(Module(new ModAdd()).io))
  val subs = VecInit(Seq.fill(ML)(Module(new ModSub()).io))
  val muls = VecInit(Seq.fill(ML)(Module(new ModMul()).io))

  (0 until ML).foreach { lane =>
    // adder
    adds(lane).a := io.addA(lane)
    adds(lane).b := io.addB(lane)
    adds(lane).m := io.csrs.csrModulusq
    io.addRes(lane) := adds(lane).c

    // subtractor
    subs(lane).a := io.subA(lane)
    subs(lane).b := io.subB(lane)
    subs(lane).m := io.csrs.csrModulusq
    io.subRes(lane) := subs(lane).c

    // multiplier
    muls(lane).a := io.mulA(lane)
    muls(lane).b := io.mulB(lane)
    muls(lane).m := io.csrs.csrModulusq
    muls(lane).u := io.csrs.csrBarretu
    muls(lane).n := io.csrs.csrModulusLen
    io.mulRes(lane) := muls(lane).c
  }

  // done is valid delayed by one cycle; the unit never reports busy
  io.done := RegNext(io.valid)
  io.busy := false.B
  io.wb := true.B
}
161 |
/** Entry point that elaborates [[VectorArith]] through the Chisel driver. */
object elaborateMod extends App {
  chisel3.Driver.execute(args, () => new VectorArith())
}
--------------------------------------------------------------------------------
/src/main/scala/VPQC/NTT.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 | import chisel3.util._
6 | import scala.math._
7 | import utility._
8 | import fsm._
9 |
10 | /*
11 | class MRFixIO extends Bundle
12 | with HasNTTCommonParameters {
13 | val a = Input(UInt((DataWidth * 2).W))
14 | val ar = Output(UInt(DataWidth.W))
15 | }
16 |
17 | class MRFix extends Module
18 | with HasNTTCommonParameters
19 | with HasMRParameters {
20 | val io = IO(new MRFixIO())
21 |
22 | // 16bit
23 | val tmp1 = io.a >> 12.U
24 | // 20bit
25 | val tmp2 = (tmp1 << 1.U) +& (tmp1 << 3.U)
26 | // 32bit
27 | val mul = Wire(UInt(32.W))
28 | mul := (tmp2 << 12.U) + (tmp2 << 8.U) + (tmp2 << 4.U) + (tmp1 << 3.U) - tmp1
29 | // 15bit
30 | val qGuess = mul >> 17.U
31 | // 28bit
32 | val qM = Wire(UInt((DataWidth * 2).W))
33 | qM := qGuess + (qGuess << 12.U) + (qGuess << 13.U)
34 |
35 | val z = io.a - qM
36 | io.ar := Mux(z < MRq.asUInt(), z, z-MRq.asUInt())
37 | }
38 |
39 | object MRFix {
40 | def apply(a: UInt, ar: UInt): Module = {
41 | val inst = Module(new MRFix())
42 | inst.io.a := a
43 | ar := inst.io.ar
44 | inst
45 | }
46 | def apply(a: UInt): UInt = {
47 | val inst = Module(new MRFix())
48 | inst.io.a := a
49 | inst.io.ar
50 | }
51 | }
52 | */
53 |
54 | // add configurability
55 |
/** Port list of the configurable modular-reduction unit [[MR]]. */
class MRIO extends Bundle with HasCommonParameters {
  // value to be reduced, twice the data width
  val a = Input(UInt((2 * DataWidth).W))
  // 5-bit configuration value from which MR derives its two shift amounts
  val n = Input(UInt(5.W))
  // modulus
  val m = Input(UInt(DataWidth.W))
  // pre-computed reduction constant, two bits wider than the data width
  val u = Input(UInt((2 + DataWidth).W))
  // reduced result
  val ar = Output(UInt(DataWidth.W))
}
64 |
/**
  * Combinational modular reduction with run-time configurable modulus.
  *
  * Computes a quotient estimate `q = (u * (a >> (n - 2))) >> (n + 3)` and
  * returns `a - q * m`, with one conditional subtraction of `m`.
  * NOTE(review): this looks like a Barrett-style reduction in which `u` is a
  * pre-computed constant derived from `m` and `n` -- confirm against the
  * software that programs the CSRs.
  *
  * Inputs (see MRIO): `a` value to reduce, `n` shift configuration,
  * `m` modulus, `u` reduction constant. Output: `ar`.
  */
class MR extends Module with HasCommonParameters {
  val io = IO(new MRIO)

  // NOTE(review): `io.n - 2.U` wraps for n < 2; such values are presumably
  // never programmed -- confirm.
  val shift1 = io.a >> (io.n - 2.U)
  val mul1 = io.u * shift1
  // `io.n` is 5 bits wide and plain `+` keeps that width, so `n + 3` wrapped
  // around for n >= 29 and produced a far too small shift amount. The
  // expanding add `+&` keeps the carry; the width of the shifted result is
  // unchanged, so behaviour for n <= 28 is identical.
  val qGuess = mul1 >> (io.n +& 3.U)
  // TODO: check if pipeline is needed
  val mul2 = qGuess * io.m
  val z = io.a - mul2

  // single correction step: subtract the modulus once if z is not below it
  io.ar := Mux(z < io.m, z, z - io.m)
}
78 |
/** Factory helper for [[MR]]. */
object MR {
  /**
    * Instantiates one MR module, drives its four inputs and hands back the
    * reduced output.
    *
    * @param a value to reduce
    * @param n shift configuration
    * @param m modulus
    * @param u reduction constant
    * @return the `ar` output of the new instance
    */
  def apply(a: UInt, n: UInt, m: UInt, u: UInt): UInt = {
    val reducer = Module(new MR())
    val ports = reducer.io
    ports.a := a
    ports.n := n
    ports.m := m
    ports.u := u
    ports.ar
  }
}
89 |
/** Command-line entry point that elaborates a stand-alone [[MR]]. */
object elaborateMR extends App {
  private val generator = () => new MR
  chisel3.Driver.execute(args, generator)
}
93 |
/** Ports of the butterfly that does not share resources (see [[Butterfly]]). */
class ButterflyIO extends Bundle with HasCommonParameters {
  // the two operands
  val a = Input(UInt(DataWidth.W))
  val b = Input(UInt(DataWidth.W))
  // twiddle factor
  val wn = Input(UInt(DataWidth.W))
  // selects DIT (false) or DIF (true)
  val mode = Input(Bool())
  // configuration forwarded to the modular units
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((2 + DataWidth).W))
  val n = Input(UInt(5.W))
  // the two results
  val aout = Output(UInt(DataWidth.W))
  val bout = Output(UInt(DataWidth.W))
}
112 |
/**
  * Combined DIT / DIF butterfly.
  *
  *   mode = 0 (DIT): aout = a + b*wn,  bout = a - b*wn
  *   mode = 1 (DIF): aout = a + b,     bout = (a - b)*wn
  *
  * All operations go through ModAdd / ModSub / ModMul with modulus `m`.
  */
class Butterfly extends Module with HasCommonParameters {
  val io = IO(new ButterflyIO)

  // Front add/sub pair; its results are only used in DIF mode.
  val preSum = ModAdd(io.a, io.b, io.m)
  val preDiff = ModSub(io.a, io.b, io.m)
  val upper = Mux(io.mode, preSum, io.a)
  val lower = Mux(io.mode, preDiff, io.b)

  // The lower path is always multiplied by the twiddle factor.
  val lowerTw = ModMul(a = lower, b = io.wn, m = io.m, u = io.u, n = io.n)

  // Back add/sub pair; its results are only used in DIT mode.
  val postSum = ModAdd(upper, lowerTw, io.m)
  val postDiff = ModSub(upper, lowerTw, io.m)

  when(io.mode) {
    io.aout := upper
    io.bout := lowerTw
  } .otherwise {
    io.aout := postSum
    io.bout := postDiff
  }
}
140 |
/** Command-line entry point that elaborates a stand-alone [[Butterfly]]. */
object elaborateButterfly extends App {
  private val generator = () => new Butterfly
  chisel3.Driver.execute(args, generator)
}
144 |
/** Ports of the resource-sharing butterfly (see [[ButterflyShare]]). */
class ButterflyShareIO extends Bundle with HasCommonParameters {
  // the two operands
  val a = Input(UInt(DataWidth.W))
  val b = Input(UInt(DataWidth.W))
  // twiddle factor
  val wn = Input(UInt(DataWidth.W))
  // selects DIT (false) or DIF (true)
  val mode = Input(Bool())
  // plain vector-operation selects: (0) a + b, (1) a - b, (2) a * b
  val vecValid = Input(Vec(3, Bool()))
  // configuration forwarded to the modular units
  val m = Input(UInt(DataWidth.W))
  val u = Input(UInt((2 + DataWidth).W))
  val n = Input(UInt(5.W))
  // the two results
  val aout = Output(UInt(DataWidth.W))
  val bout = Output(UInt(DataWidth.W))
}
164 |
/**
  * DIT / DIF butterfly whose adder, subtractor and multiplier can also be
  * used for plain element-wise vector operations.
  *
  *   mode = 0 (DIT): aout = a + b*wn,  bout = a - b*wn
  *   mode = 1 (DIF): aout = a + b,     bout = (a - b)*wn
  *
  * `vecValid` overrides `aout` with, in priority order,
  * (0) a + b, (1) a - b, (2) a * b. `bout` is never overridden.
  */
class ButterflyShare extends Module with HasCommonParameters {
  val io = IO(new ButterflyShareIO)

  // Front add/sub pair, shared with vector add / vector sub.
  val preSum = ModAdd(io.a, io.b, io.m)
  val preDiff = ModSub(io.a, io.b, io.m)
  val upper = Mux(io.mode, preSum, io.a)
  val lower = Mux(io.mode, preDiff, io.b)

  // Multiplier operands: the raw inputs for a vector multiply, otherwise the
  // lower butterfly path and the twiddle factor.
  val mulLhs = Mux(io.vecValid(2), io.a, lower)
  val mulRhs = Mux(io.vecValid(2), io.b, io.wn)
  val product = ModMul(a = mulLhs, b = mulRhs, m = io.m, u = io.u, n = io.n)

  // Back add/sub pair; its results are only used in DIT mode.
  val postSum = ModAdd(upper, product, io.m)
  val postDiff = ModSub(upper, product, io.m)

  // The first matching condition wins; the default is the butterfly result.
  io.aout := MuxCase(
    Mux(io.mode, upper, postSum),
    Seq(
      io.vecValid(0) -> preSum,
      io.vecValid(1) -> preDiff,
      io.vecValid(2) -> product
    )
  )
  io.bout := Mux(io.mode, product, postDiff)
}
197 |
/**
  * Ports shared by the input and output permutation networks:
  * 2 * ButterflyNum data lanes in, the same number out, and one select bit
  * per available permutation.
  */
class PermNetIO extends Bundle with HasCommonParameters with HasNTTParameters {
  val in = Input(Vec(2 * ButterflyNum, UInt(DataWidth.W)))
  val config = Input(Vec(log2Ceil(ButterflyNum), Bool()))
  val out = Output(Vec(2 * ButterflyNum, UInt(DataWidth.W)))
}
/**
  * Permutation network in front of the butterflies.
  *
  * Offers log2Ceil(ButterflyNum) fixed permutations with group sizes
  * 4, 8, ..., selected by `io.config` (used as a one-hot select). With no
  * config bit set the input is passed through unchanged.
  */
class PermNetIn extends Module with HasCommonParameters with HasNTTParameters {
  val io = IO(new PermNetIO)

  /**
    * Interleaves the two halves of every `split`-sized group of inputs:
    * within a group, even output slots take consecutive elements of the
    * first half and odd output slots take consecutive elements of the
    * second half.
    *
    * @param split group size
    * @return a wire vector carrying the permuted inputs
    */
  def perm(split: Int): Vec[UInt] = {
    val shuffled = Wire(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val half = split / 2
    for {
      group <- 0 until ButterflyNum * 2 / split
      offset <- 0 until split
    } {
      val base = group * split
      val source =
        if (offset % 2 == 0) base + offset / 2
        else base + half + (offset - 1) / 2
      shuffled(base + offset) := io.in(source)
    }
    shuffled
  }

  // Candidate k uses group size 2^(k+2).
  val multiPerm = Wire(Vec(log2Ceil(ButterflyNum), Vec(ButterflyNum * 2, UInt(DataWidth.W))))
  for (k <- 0 until log2Ceil(ButterflyNum)) {
    multiPerm(k) := perm(1 << (k + 2))
  }

  val out = Mux1H(io.config, multiPerm)
  // Bypass the network when no permutation is selected.
  io.out := Mux(io.config.reduce(_ || _), out, io.in)
}
/**
  * Command-line entry point that elaborates a stand-alone permutation network.
  * NOTE(review): this elaborates PermNetOut, not PermNetIn -- confirm that is
  * the intended target.
  */
object elaboratePermNet extends App {
  private val generator = () => new PermNetOut
  chisel3.Driver.execute(args, generator)
}
236 |
/**
  * Permutation network behind the butterflies.
  *
  * Offers log2Ceil(ButterflyNum) fixed permutations with group sizes
  * 4, 8, ..., selected by `io.config` (used as a one-hot select). With no
  * config bit set the input is passed through unchanged.
  */
class PermNetOut extends Module with HasCommonParameters with HasNTTParameters {
  val io = IO(new PermNetIO)

  /**
    * De-interleaves every `split`-sized group of inputs: the first half of
    * an output group collects the even-indexed inputs of that group and the
    * second half collects the odd-indexed ones.
    *
    * @param split group size
    * @return a wire vector carrying the permuted inputs
    */
  def perm(split: Int): Vec[UInt] = {
    val gathered = Wire(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val half = split / 2
    for {
      group <- 0 until ButterflyNum * 2 / split
      offset <- 0 until split
    } {
      val base = group * split
      val source =
        if (offset < half) base + 2 * offset
        else base + 1 + 2 * (offset - half)
      gathered(base + offset) := io.in(source)
    }
    gathered
  }

  // Candidate k uses group size 2^(k+2).
  val multiPerm = Wire(Vec(log2Ceil(ButterflyNum), Vec(ButterflyNum * 2, UInt(DataWidth.W))))
  for (k <- 0 until log2Ceil(ButterflyNum)) {
    multiPerm(k) := perm(1 << (k + 2))
  }

  val out = Mux1H(io.config, multiPerm)
  // Bypass the network when no permutation is selected.
  io.out := Mux(io.config.reduce(_ || _), out, io.in)
}
264 |
/** Ports of the [[NTT]] unit that owns its twiddle-factor RAM. */
class NTTIO extends Bundle with HasCommonParameters with HasNTTParameters {
  val valid = Input(Bool())
  val mode = Input(Bool())

  // write port of the twiddle RAM: address, one full row of data, enable
  val wa = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val di = Input(UInt((ButterflyNum * DataWidth).W))
  val we = Input(Bool())

  // operand vectors coming from the register file
  val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
  val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

  // control / status registers
  val csrs = Input(new CSRIO)

  // results and status
  val dataOut = Output(Vec(2 * ButterflyNum, UInt(DataWidth.W)))
  val done = Output(Bool())
  val busy = Output(Bool())
  val wb = Output(Bool())
  // val wpos = Output(Bool())
}
290 |
/**
  * One-stage NTT datapath with an internal twiddle-factor RAM.
  * This variant does not share its arithmetic with the vector units.
  *
  * Data flow: vectorReg1 / vectorReg2 -> PermNetIn -> ButterflyNum
  * butterflies -> PermNetOut -> dataOut. The stage being processed is taken
  * from `csrs.csrButterflyCtrl.stageCfg`.
  *
  * The permutation select and the twiddle-lane select were previously
  * spelled out separately for ButterflyNum = 4, 8, 16, 32 and 64; any other
  * value elaborated without them. Both are now derived from
  * log2Ceil(ButterflyNum), which yields the same logic for those five sizes.
  * ButterflyNum is assumed to be a power of two.
  */
class NTT extends Module
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new NTTIO)

  // Number of stages whose twiddle factors all live in RAM row 0.
  private val localStages = log2Ceil(ButterflyNum)

  private def stageCfg = io.csrs.csrButterflyCtrl.stageCfg

  // Extracts the DataWidth-bit twiddle factor number `lane` from a RAM row.
  private def twiddleLane(row: UInt, lane: Int): UInt =
    row(DataWidth * lane + DataWidth - 1, DataWidth * lane)

  io.wb := true.B
  // `done` is `valid` delayed by one clock cycle
  io.done := RegNext(io.valid)
  io.busy := false.B

  // permutation in: vectorReg1 feeds the lower half, vectorReg2 the upper half
  val permNet = Module(new PermNetIn)
  for (i <- 0 until ButterflyNum * 2) {
    if (i < ButterflyNum) {
      permNet.io.in(i) := io.vectorReg1(i)
    } else {
      permNet.io.in(i) := io.vectorReg2(i - ButterflyNum)
    }
  }

  // to support NTT and INTT : stage i -> stage n-1-i
  // Stage 0 selects no permutation (pass-through), stage k selects
  // permutation k-1, and every stage >= localStages selects the last one.
  for (k <- 0 until localStages) {
    permNet.io.config(k) := (if (k == localStages - 1) stageCfg >= localStages.U
                             else stageCfg === (k + 1).U)
  }

  // wn addr prepare
  val wnBaseAddr1 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val wnBaseAddr2 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  // contain stage info: row 0 for the local stages, otherwise 2^(stage - localStages)
  wnBaseAddr1 := Mux(stageCfg < localStages.U,
    0.U, 1.U << (stageCfg - localStages.U))
  // contain iter info: stages beyond localStages additionally add the iteration index
  wnBaseAddr2 := Mux(stageCfg > localStages.U,
    io.csrs.csrButterflyCtrl.iterCfg + wnBaseAddr1, wnBaseAddr1)

  // wn consts ram
  /**
    * Row layout as described by the original author:
    *
    * (1) {1 1} {1 1/4} {1 1/8 ..} {1 1/16 ... 7/16}
    * (2) 1 1/32 ... 15/32
    * (3) ...
    *
    * 32 line  1 1/1024 ... 511/1024
    *
    * i.e. row 0 packs the twiddle factors of all local stages, with stage s
    * occupying lanes 2^s .. 2^(s+1)-1 (stage 0 is read from lane 0).
    */
  val twiddleRam = Module(new SyncRam(dep = Dimension / ButterflyNum, dw = DataWidth * ButterflyNum))
  twiddleRam.io.re := io.valid
  twiddleRam.io.we := io.we
  twiddleRam.io.ra := wnBaseAddr2
  twiddleRam.io.wa := io.wa
  twiddleRam.io.di := io.di

  // butterfly PEs
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new Butterfly()).io))
  for (i <- 0 until ButterflyNum) {
    PEs(i).a := permNet.io.out(2 * i)
    PEs(i).b := permNet.io.out(2 * i + 1)
    PEs(i).mode := io.mode

    // wn assign: by default butterfly i takes lane i of the row ...
    PEs(i).wn := twiddleLane(twiddleRam.io.dout, i)
    // ... but during a local stage it takes that stage's lane group of row 0.
    for (s <- 0 until localStages) {
      val lane = if (s == 0) 0 else (i % (1 << s)) + (1 << s)
      when(stageCfg === s.U) {
        PEs(i).wn := twiddleLane(twiddleRam.io.dout, lane)
      }
    }

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out: uses the same selection as the input network
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  for (i <- 0 until ButterflyNum) {
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := permNetOut.io.out
}
522 |
/**
  * One-stage NTT datapath that reads its twiddle factors from a RAM outside
  * the module (read address `ra`, read enable `re`, row data `di`).
  *
  * Data flow: vectorReg1 / vectorReg2 -> PermNetIn -> ButterflyNum
  * butterflies -> PermNetOut -> dataOut. The stage being processed is taken
  * from `csrs.csrButterflyCtrl.stageCfg`.
  *
  * The permutation select and the twiddle-lane select were previously
  * spelled out separately for ButterflyNum = 4, 8, 16, 32 and 64; any other
  * value elaborated without them. Both are now derived from
  * log2Ceil(ButterflyNum), which yields the same logic for those five sizes.
  * ButterflyNum is assumed to be a power of two.
  */
class NTTWithoutRam extends Module
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val mode = Input(Bool())

    // ram from outside
    val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
    val re = Output(Bool())
    val di = Input(UInt((ButterflyNum * DataWidth).W))

    // from register
    val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
    val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

    // csr interface
    val csrs = Input(new CSRIO)

    // output
    val dataOut = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
  })

  // Number of stages whose twiddle factors all live in RAM row 0.
  private val localStages = log2Ceil(ButterflyNum)

  private def stageCfg = io.csrs.csrButterflyCtrl.stageCfg

  // Extracts the DataWidth-bit twiddle factor number `lane` from a RAM row.
  private def twiddleLane(row: UInt, lane: Int): UInt =
    row(DataWidth * lane + DataWidth - 1, DataWidth * lane)

  io.wb := true.B
  // `done` is `valid` delayed by one clock cycle
  io.done := RegNext(io.valid)
  io.busy := false.B

  // permutation in: vectorReg1 feeds the lower half, vectorReg2 the upper half
  val permNet = Module(new PermNetIn)
  for (i <- 0 until ButterflyNum * 2) {
    if (i < ButterflyNum) {
      permNet.io.in(i) := io.vectorReg1(i)
    } else {
      permNet.io.in(i) := io.vectorReg2(i - ButterflyNum)
    }
  }

  // to support NTT and INTT : stage i -> stage n-1-i
  // Stage 0 selects no permutation (pass-through), stage k selects
  // permutation k-1, and every stage >= localStages selects the last one.
  for (k <- 0 until localStages) {
    permNet.io.config(k) := (if (k == localStages - 1) stageCfg >= localStages.U
                             else stageCfg === (k + 1).U)
  }

  // wn addr prepare
  val wnBaseAddr1 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val wnBaseAddr2 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  // contain stage info: row 0 for the local stages, otherwise 2^(stage - localStages)
  wnBaseAddr1 := Mux(stageCfg < localStages.U,
    0.U, 1.U << (stageCfg - localStages.U))
  // contain iter info: stages beyond localStages additionally add the iteration index
  wnBaseAddr2 := Mux(stageCfg > localStages.U,
    io.csrs.csrButterflyCtrl.iterCfg + wnBaseAddr1, wnBaseAddr1)

  // drive the external twiddle RAM read port
  io.re := io.valid
  io.ra := wnBaseAddr2

  // butterfly PEs
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new Butterfly()).io))
  for (i <- 0 until ButterflyNum) {
    PEs(i).a := permNet.io.out(2 * i)
    PEs(i).b := permNet.io.out(2 * i + 1)
    PEs(i).mode := io.mode

    // wn assign: by default butterfly i takes lane i of the row ...
    PEs(i).wn := twiddleLane(io.di, i)
    // ... but during a local stage s it takes lane (i mod 2^s) + 2^s of the
    // row (stage 0 always reads lane 0).
    for (s <- 0 until localStages) {
      val lane = if (s == 0) 0 else (i % (1 << s)) + (1 << s)
      when(stageCfg === s.U) {
        PEs(i).wn := twiddleLane(io.di, lane)
      }
    }

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out: uses the same selection as the input network
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  for (i <- 0 until ButterflyNum) {
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := permNetOut.io.out
}
754 |
/**
  * One-stage NTT datapath with external twiddle RAM whose butterflies are
  * also usable for element-wise vector arithmetic.
  *
  * While `valid` is high the unit behaves like an NTT stage:
  * vectorReg1 / vectorReg2 -> PermNetIn -> butterflies -> PermNetOut ->
  * dataOut. While `valid` is low the butterflies receive lane i of both
  * vector registers directly, `vecValid` picks add / sub / mul inside each
  * butterfly, and dataOut carries the `aout` results in its lower half with
  * the upper half forced to zero.
  *
  * The permutation select and the twiddle-lane select were previously
  * spelled out separately for ButterflyNum = 4, 8, 16, 32 and 64; any other
  * value elaborated without them. Both are now derived from
  * log2Ceil(ButterflyNum), which yields the same logic for those five sizes.
  * ButterflyNum is assumed to be a power of two.
  */
class NTTWithoutRamShare extends Module
  with HasCommonParameters
  with HasNTTParameters {
  val io = IO(new Bundle {
    val valid = Input(Bool())
    val mode = Input(Bool())
    val vecValid = Input(Vec(3, Bool()))

    // ram from outside
    val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
    val re = Output(Bool())
    val di = Input(UInt((ButterflyNum * DataWidth).W))

    // from register
    val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
    val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))

    // csr interface
    val csrs = Input(new CSRIO)

    // output
    val dataOut = Output(Vec(ButterflyNum * 2, UInt(DataWidth.W)))
    val done = Output(Bool())
    val busy = Output(Bool())
    val wb = Output(Bool())
  })

  // Number of stages whose twiddle factors all live in RAM row 0.
  private val localStages = log2Ceil(ButterflyNum)

  private def stageCfg = io.csrs.csrButterflyCtrl.stageCfg

  // Extracts the DataWidth-bit twiddle factor number `lane` from a RAM row.
  private def twiddleLane(row: UInt, lane: Int): UInt =
    row(DataWidth * lane + DataWidth - 1, DataWidth * lane)

  io.wb := true.B
  // `done` is `valid` delayed by one clock cycle
  io.done := RegNext(io.valid)
  io.busy := false.B

  // permutation in: vectorReg1 feeds the lower half, vectorReg2 the upper half
  val permNet = Module(new PermNetIn)
  for (i <- 0 until ButterflyNum * 2) {
    if (i < ButterflyNum) {
      permNet.io.in(i) := io.vectorReg1(i)
    } else {
      permNet.io.in(i) := io.vectorReg2(i - ButterflyNum)
    }
  }

  // to support NTT and INTT : stage i -> stage n-1-i
  // Stage 0 selects no permutation (pass-through), stage k selects
  // permutation k-1, and every stage >= localStages selects the last one.
  for (k <- 0 until localStages) {
    permNet.io.config(k) := (if (k == localStages - 1) stageCfg >= localStages.U
                             else stageCfg === (k + 1).U)
  }

  // wn addr prepare
  val wnBaseAddr1 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  val wnBaseAddr2 = Wire(UInt(log2Ceil(Dimension / ButterflyNum).W))
  // contain stage info: row 0 for the local stages, otherwise 2^(stage - localStages)
  wnBaseAddr1 := Mux(stageCfg < localStages.U,
    0.U, 1.U << (stageCfg - localStages.U))
  // contain iter info: stages beyond localStages additionally add the iteration index
  wnBaseAddr2 := Mux(stageCfg > localStages.U,
    io.csrs.csrButterflyCtrl.iterCfg + wnBaseAddr1, wnBaseAddr1)

  // drive the external twiddle RAM read port
  io.re := io.valid
  io.ra := wnBaseAddr2

  // butterfly PEs
  val PEs = VecInit(Seq.fill(ButterflyNum)(Module(new ButterflyShare()).io))
  for (i <- 0 until ButterflyNum) {
    // NTT operation takes permuted data; vector operation takes lane i directly
    PEs(i).a := Mux(io.valid, permNet.io.out(2 * i), io.vectorReg1(i))
    PEs(i).b := Mux(io.valid, permNet.io.out(2 * i + 1), io.vectorReg2(i))
    PEs(i).mode := io.mode
    PEs(i).vecValid := io.vecValid

    // wn assign: by default butterfly i takes lane i of the row ...
    PEs(i).wn := twiddleLane(io.di, i)
    // ... but during a local stage s it takes lane (i mod 2^s) + 2^s of the
    // row (stage 0 always reads lane 0).
    for (s <- 0 until localStages) {
      val lane = if (s == 0) 0 else (i % (1 << s)) + (1 << s)
      when(stageCfg === s.U) {
        PEs(i).wn := twiddleLane(io.di, lane)
      }
    }

    PEs(i).n := io.csrs.csrModulusLen
    PEs(i).m := io.csrs.csrModulusq
    PEs(i).u := io.csrs.csrBarretu
  }

  // permutation out: uses the same selection as the input network
  val permNetOut = Module(new PermNetOut)
  permNetOut.io.config := permNet.io.config
  // Vector-arithmetic result: `aout` of every butterfly, zero-padded on top.
  val vecArithOut = Wire(Vec(2 * ButterflyNum, UInt(DataWidth.W)))
  for (i <- 0 until ButterflyNum) {
    vecArithOut(i) := PEs(i).aout
    vecArithOut(i + ButterflyNum) := 0.U
    permNetOut.io.in(2 * i) := PEs(i).aout
    permNetOut.io.in(2 * i + 1) := PEs(i).bout
  }
  io.dataOut := Mux(io.valid, permNetOut.io.out, vecArithOut)
}
991 |
/** Elaboration entry point: passes the command-line arguments and a generator
  * for `NTTWithoutRamShare` to the Chisel driver.
  */
object elaborateNTTWithoutRamShare extends App {
  private val generator = () => new NTTWithoutRamShare
  chisel3.Driver.execute(args, generator)
}
995 |
/** Elaboration entry point: passes the command-line arguments and a generator
  * for `NTTWithoutRam` to the Chisel driver.
  */
object elaborateNTTWithoutRam extends App {
  private val generator = () => new NTTWithoutRam
  chisel3.Driver.execute(args, generator)
}
999 |
/** Elaboration entry point: passes the command-line arguments and a generator
  * for `NTT` to the Chisel driver.
  */
object elaborateNTT extends App {
  private val generator = () => new NTT
  chisel3.Driver.execute(args, generator)
}
1003 |
1004 | //class SwitchIO extends Bundle
1005 | // with HasCommonParameters {
1006 | // val a = Input(UInt(DataWidth.W))
1007 | // val b = Input(UInt(DataWidth.W))
1008 | // val sel = Input(Bool())
1009 | // val aout = Output(UInt(DataWidth.W))
1010 | // val bout = Output(UInt(DataWidth.W))
1011 | //}
1012 | //
1013 | //class Switch extends Module {
1014 | // val io = IO(new SwitchIO())
1015 | // io.aout := Mux(io.sel, io.b, io.a)
1016 | // io.bout := Mux(io.sel, io.a, io.b)
1017 | //}
1018 | //object Switch {
1019 | // def apply(a: UInt, b: UInt, sel: Bool, aout: UInt, bout: UInt): Module = {
1020 | // val inst = Module(new Switch)
1021 | // inst.io.a := a
1022 | // inst.io.b := b
1023 | // inst.io.sel := sel
1024 | // aout := inst.io.aout
1025 | // bout := inst.io.bout
1026 | // inst
1027 | // }
1028 | //}
1029 |
1030 | //class NTTR2MDCIO extends Bundle
1031 | // with HasNTTCommonParameters {
1032 | // val dIn = Input(UInt(DataWidth.W))
1033 | // val dInValid = Input(Bool())
1034 | // val dOut1 = Output(UInt(DataWidth.W))
1035 | // // val addrOut1 = Output(UInt(AddrWidth.W))
1036 | // val dOut2 = Output(UInt(DataWidth.W))
1037 | // // val addrOut2 = Output(UInt(AddrWidth.W))
1038 | // val dOutValid = Output(Bool())
1039 | //}
1040 | //
1041 | //class NTTR2MDC extends Module
1042 | // with HasMRParameters
1043 | // with HasNTTCommonParameters
1044 | // with HasNTTParameters {
1045 | // val io = IO(new NTTR2MDCIO())
1046 | //
1047 | // val stages: Int = AddrWidth
1048 | //
1049 | // val wn = RegInit(VecInit(Seq.fill(stages)(1.U(DataWidth.W))))
1050 | // val cnt = RegInit(0.U((stages).W))
1051 | // when(io.dInValid){
1052 | // cnt := cnt + 1.U
1053 | // for (i <- 0 until stages - 1) {
1054 | // val res = MRFix(wn(i) * WN(i).asUInt())
1055 | // wn(i) := Mux(res === (MRq.asUInt() - 1.U), 1.U, res)
1056 | // }
1057 | // }
1058 | //
1059 | // val out1 = VecInit(Seq.fill(stages)(0.U(DataWidth.W)))
1060 | // val out2 = VecInit(Seq.fill(stages)(0.U(DataWidth.W)))
1061 | // /***
1062 | // *
1063 | // * pre modular reduction
1064 | // *
1065 | // *
1066 | // */
1067 | //
1068 | // val dIn = Mux(io.dIn < MRq.asUInt(), io.dIn, io.dIn - MRq.asUInt())
1069 | //
1070 | // for (i <- 0 until stages){
1071 | // if (i == 0) {
1072 | // val regIn = ShiftRegister(dIn, VectorLength / 2)
1073 | // val BFOut1 = Wire(UInt(DataWidth.W))
1074 | // val BFOut2 = Wire(UInt(DataWidth.W))
1075 | // BF2(regIn, dIn, BFOut1, BFOut2)
1076 | // /**
1077 | // * can add pipeline if necessary
1078 | // *
1079 | // **/
1080 | // val mulRes = MRFix(BFOut2 * wn(i))
1081 | // val switchIn1 = BFOut1
1082 | // val switchIn2 = ShiftRegister(mulRes, VectorLength/4)
1083 | // val swCtrl = cnt(stages-2)
1084 | // Switch(switchIn1, switchIn2, swCtrl, out1(0), out2(0))
1085 | // }
1086 | // else if (i < stages - 1){
1087 | // val regIn = ShiftRegister(out1(i-1), (VectorLength/pow(2, i + 1)).toInt)
1088 | // val BFOut1 = Wire(UInt(14.W))
1089 | // val BFOut2 = Wire(UInt(14.W))
1090 | // BF2(regIn, out2(i-1), BFOut1, BFOut2)
1091 | // /**
1092 | // * can add pipeline if necessary
1093 | // *
1094 | // **/
1095 | // val mulRes = MRFix(BFOut2 * wn(i))
1096 | // val switchIn1 = BFOut1
1097 | // val switchIn2 = ShiftRegister(mulRes, (VectorLength/pow(2, i + 2)).toInt)
1098 | // val swCtrl = cnt(stages-2-i)
1099 | // Switch(switchIn1, switchIn2, swCtrl, out1(i), out2(i))
1100 | // }
1101 | // else {
1102 | // val regIn = ShiftRegister(out1(i-1), (VectorLength/pow(2, i + 1)).toInt)
1103 | // BF2(regIn, out2(i-1), out1(i), out2(i))
1104 | // }
1105 | // }
1106 | // io.dOut1 := RegNext(out1(stages - 1))
1107 | // io.dOut2 := RegNext(out2(stages - 1))
1108 | // io.dOutValid := ShiftRegister(io.dInValid, VectorLength)
1109 | //
1110 | //}
1111 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/NTTParameters.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3.util._
5 | import scala.math._
6 |
7 | // using Barrett reduction
8 | //trait HasMRParameters {
9 | // val MRalpha = 16 + 1
10 | // val MRbeta = -2
11 | // val MRq = 12289
12 | // val MRu = (2 << (14 + MRalpha)) / MRq
13 | //}
14 |
15 | //trait HasNTTCommonParameters {
16 | // val DataWidth = 16
17 | // val VectorLength = 512
18 | // // 9 bits
19 | // val AddrWidth = log2Ceil(VectorLength)
20 | //}
21 |
/**
  * Compile-time parameters mixed into the NTT datapath and its I/O bundles.
  *
  * Reference values kept from the original note (NEWHOPE512 parameter set,
  * modulus 12289):
  *   r = 10968, w = 3, w^-1 = 8193, r^-1 = 3656, n^-1 = 12265
  * where w is the root of unity: 3^512 = 1, 3^256 = 12289 - 1, 3^128 = 1479.
  */
trait HasNTTParameters {

  // Number of butterfly processing elements instantiated in parallel.
  val ButterflyNum = 32
  // NTT dimension.
  val Dimension = 1024

  // Table of repeated squarings of the root: WN(k) = 3^(2^k) mod 12289 for
  // k = 0..7. The table has nine slots; the last one (index 8) is never
  // written and therefore keeps the array default of 0.
  val WN = new Array[Int](9)
  WN(0) = 3
  (1 until 8).foreach { k =>
    WN(k) = (pow(WN(k - 1), 2) % 12289).toInt
  }
}
--------------------------------------------------------------------------------
/src/main/scala/VPQC/PQCCoprocessor.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 | import chisel3.util._
6 | import utility._
7 |
8 | // standalone coprocessor
9 |
// Port list of the standalone coprocessor (PQCCoprocessor): scalar
// instruction interface, Keccak seed port, twiddle / const RAM ports and the
// external vector load/store interface.
// Field declaration order is part of the hardware interface - do not reorder.
10 | class PQCCoprocessorIO extends Bundle
11 | with HasCommonParameters
12 | with HasNTTParameters
13 | with HasPQCInstructions {
14 | // for decode
// instr / rs1 / rs2 / in_fire are forwarded unchanged to PQCDecode.
15 | val instr = Input(new PQCInstruction)
16 | val rs1 = Input(UInt(32.W))
17 | val rs2 = Input(UInt(32.W))
// rd carries the old CSR value while a CSRRW fires; rdw is high in that cycle.
18 | val rd = Output(UInt(32.W))
19 | val rdw = Output(Bool())
20 | val in_fire = Input(Bool())
// rdy = execution unit done && !busy; busy mirrors the execution unit's busy.
21 | val rdy = Output(Bool())
22 | val busy = Output(Bool())
23 |
24 | // for Keccak core
25 | val seed = Input(new stateArray())
26 | val seedWrite = Input(Bool())
27 |
28 | // for twiddle factors / const RAM, from outside
// twiddleData / twiddleAddr feed both the execution unit's twiddle port and
// the write port of the const RAM inside PQCCoprocessor.
29 | val twiddleData = Input(UInt((DataWidth * ButterflyNum).W))
30 | val twiddleAddr = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
31 | val twiddleWrite = Input(Bool())
32 |
// constReadEnable also selects the const RAM output as the second vector
// operand; constWrite is the const RAM write enable.
33 | val constReadEnable = Input(Bool())
34 | val constReadAddr = Input(UInt(log2Ceil(2 * Dimension / ButterflyNum).W))
35 | val constWrite = Input(Bool())
36 |
37 | // vector interface
// Store side: vectorOut is the first vector operand, the address is the
// load/store base-address CSR, the enable is asserted when a VST fires.
38 | val vectorOut = Output(Vec(ML, UInt(DataWidth.W)))
39 | val vectorWriteAddr = Output(UInt(32.W))
40 | val vectorWriteEnable = Output(Bool())
41 |
// Load side: vectorIn is written to the vector register file when a VLD
// fires; the address is the same base-address CSR.
42 | val vectorIn = Input(Vec(ML, UInt(DataWidth.W)))
43 | val vectorReadaddr = Output(UInt(32.W))
44 | val vectorReadEnable = Output(Bool())
45 | }
46 |
// Bundle carrying the current values of the coprocessor CSRs to the
// execution units. Directions are applied by the user (e.g. Input(new CSRIO)).
47 | class CSRIO extends Bundle
48 | with HasCommonParameters {
// Widths mirror the CSR registers declared in the coprocessor modules.
49 | val csrBarretu = UInt((DataWidth + 2).W)
50 | val csrBound = UInt((DataWidth * 2).W)
51 | val csrBinomialk = UInt(3.W)
52 | val csrModulusq = UInt(DataWidth.W)
53 | val csrModulusLen = UInt(5.W)
54 | /**
55 | *
56 | * 9 ... 6 5 ... 0
57 | * stage_cfg iter_cfg
58 | */
59 |
// Layout of the 10-bit butterfly control CSR: bits 9..6 are stageCfg and
// bits 5..0 are iterCfg (the coprocessor slices the register exactly so).
60 | val csrButterflyCtrl = new Bundle {
61 | val stageCfg = UInt(4.W)
62 | val iterCfg = UInt(6.W)
63 | }
64 | val csrLsBaseAddr = UInt(32.W)
65 | }
66 |
// Standalone coprocessor top level: instantiates the decoder, the CSR
// registers, the two-port vector register file, the const RAM and the
// execution unit (PQCExu), and wires them to PQCCoprocessorIO.
// NOTE: several outputs are given a default and overridden later inside a
// when-block (Chisel last-connect semantics), so statement order matters.
67 | class PQCCoprocessor extends Module
68 | with HasCommonParameters
69 | with HasPQCInstructions {
70 | val io = IO(new PQCCoprocessorIO)
71 |
72 | // decode
73 | val decoder = Module(new PQCDecode)
74 | decoder.io.instr := io.instr
75 | decoder.io.rs1 := io.rs1
76 | decoder.io.rs2 := io.rs2
77 | decoder.io.in_fire := io.in_fire
// doinstr(k) is high when the instruction's funct field equals k.
78 | val doinstr = decoder.io.result
79 |
80 | // CSRs
// All CSRs reset to 0. CSR numbers used by CSRRW / CSRRWI below:
// 0 Barretu, 1 Bound, 2 Binomialk, 3 Modulusq, 4 ModulusLen,
// 5 ButterflyCtrl, 6 LsBaseAddr.
81 | val csrBarretu = RegInit(0.U((DataWidth + 2).W))
82 | val csrBound = RegInit(0.U((DataWidth * 2).W))
83 | val csrBinomialk = RegInit(0.U(3.W))
84 | val csrModulusq = RegInit(0.U(DataWidth.W))
85 | val csrModulusLen = RegInit(0.U(5.W))
86 | val csrButterflyCtrl = RegInit(0.U(10.W))
87 | val csrLsBaseAddr = RegInit(0.U(32.W))
88 |
89 | // vector SRAM
// Read addresses are the rs1 / rs2 fields of the current instruction.
90 | val vecSRAM = Module(new VectorRegister2Port)
91 | vecSRAM.io.vectorReadAddr1 := decoder.io.vrs1idx
92 | vecSRAM.io.vectorReadAddr2 := decoder.io.vrs2idx
93 | val vecop1 = vecSRAM.io.vectorReadValue1
94 | val vecop2 = vecSRAM.io.vectorReadValue2
95 |
96 | // const RAM
// Written through the twiddle address/data ports (gated by constWrite) and
// read through constReadAddr / constReadEnable.
97 | val constRAM = Module(new SyncRam(dep = 64, dw = DataWidth * ML))
98 | constRAM.io.re := io.constReadEnable
99 | constRAM.io.we := io.constWrite
100 | constRAM.io.ra := io.constReadAddr
101 | constRAM.io.wa := io.twiddleAddr
102 | constRAM.io.di := io.twiddleData
// Split the flat RAM word into ML lanes of DataWidth bits, lane 0 in the LSBs.
103 | val constVal = Wire(Vec(ML, UInt(DataWidth.W)))
104 | for (i <- 0 until ML) {
105 | constVal(i) := constRAM.io.dout(DataWidth * i + DataWidth - 1, DataWidth * i)
106 | }
107 |
108 | // ======================================
109 | // EXU Stage
110 | // ======================================
111 | val exunit = Module(new PQCExu)
112 |
113 | exunit.io.seed := io.seed
114 | exunit.io.seedWrite := io.seedWrite
115 | exunit.io.twiddleData := io.twiddleData
116 | exunit.io.twiddleAddr := io.twiddleAddr
117 | exunit.io.twiddleWrite := io.twiddleWrite
// Per-instruction valid: decoded instruction AND the fire handshake.
118 | for (i <- 0 until INSTR_QUANTITY) {
119 | exunit.io.valid(i) := doinstr(i) && decoder.io.fire
120 | }
121 |
122 | exunit.io.vrs1 := vecop1
// The second operand comes from the const RAM while constReadEnable is high.
// NOTE(review): SyncRam is presumably a synchronous-read RAM, so dout may lag
// constReadEnable by a cycle - confirm the intended timing.
123 | exunit.io.vrs2 := Mux(io.constReadEnable, constVal, vecop2)
124 | exunit.io.csrs.csrBarretu := csrBarretu
125 | exunit.io.csrs.csrBound := csrBound
126 | exunit.io.csrs.csrBinomialk := csrBinomialk
127 | exunit.io.csrs.csrModulusq := csrModulusq
128 | exunit.io.csrs.csrModulusLen := csrModulusLen
129 | exunit.io.csrs.csrButterflyCtrl.stageCfg := csrButterflyCtrl(9, 6)
130 | exunit.io.csrs.csrButterflyCtrl.iterCfg := csrButterflyCtrl(5, 0)
131 | exunit.io.csrs.csrLsBaseAddr := csrLsBaseAddr
132 |
// Write-back. Port 1 is written when the execution unit reports done && wb;
// port 2 additionally requires wpos. These are defaults: the VLD block near
// the end of the module overrides port 1 data/enable.
133 | vecSRAM.io.vectorWriteEnable1 := exunit.io.done && exunit.io.wb
134 | vecSRAM.io.vectorWriteEnable2 := exunit.io.done && exunit.io.wb && exunit.io.wpos
135 | vecSRAM.io.vectorWriteData1 := exunit.io.vres1
136 | vecSRAM.io.vectorWriteData2 := exunit.io.vres2
// Write address 1 is the previous cycle's rd field, unless a (I)BUTTERFLY
// fired in the previous cycle, in which case it is the previous cycle's rs1
// field. Write address 2 is always the previous cycle's rs2 field.
137 | vecSRAM.io.vectorWriteAddr1 := Mux(!RegNext((doinstr(INSTR_BUTTERFLY) || doinstr(INSTR_IBUTTERFLY))
138 | && decoder.io.fire), RegNext(decoder.io.vrdidx), RegNext(decoder.io.vrs1idx))
139 | vecSRAM.io.vectorWriteAddr2 := RegNext(decoder.io.vrs2idx)
140 |
// Defaults for the scalar result; only CSRRW overrides them.
141 | io.rd := 0.U
142 | io.rdw := false.B
143 | // CSRRW
// The rs2 field selects the CSR. The old CSR value is returned on rd and the
// CSR is overwritten with the scalar operand srs1 (truncated to CSR width).
144 | when(decoder.io.fire && doinstr(INSTR_CSRRW)){
145 | io.rdw:= true.B
146 | switch(decoder.io.vrs2idx) {
147 | is(0.U) {
148 | io.rd := csrBarretu
149 | csrBarretu := decoder.io.srs1
150 | }
151 | is(1.U) {
152 | io.rd := csrBound
153 | csrBound := decoder.io.srs1
154 | }
155 | is(2.U) {
156 | io.rd := csrBinomialk
157 | csrBinomialk := decoder.io.srs1
158 | }
159 | is(3.U) {
160 | io.rd := csrModulusq
161 | csrModulusq := decoder.io.srs1
162 | }
163 | is(4.U) {
164 | io.rd := csrModulusLen
165 | csrModulusLen := decoder.io.srs1
166 | }
167 | is(5.U) {
168 | io.rd := csrButterflyCtrl
169 | csrButterflyCtrl := decoder.io.srs1
170 | }
171 | is(6.U) {
172 | io.rd := csrLsBaseAddr
173 | csrLsBaseAddr := decoder.io.srs1
174 | }
175 | }
176 | }
177 |
178 | // CSRRWI
// Immediate form: the 5-bit rs1 field itself is the value written. rd / rdw
// are not driven here and keep the defaults assigned above.
179 | when(decoder.io.fire && doinstr(INSTR_CSRRWI)){
180 | switch(decoder.io.vrs2idx) {
181 | is(0.U) {
182 | csrBarretu := decoder.io.vrs1idx
183 | }
184 | is(1.U) {
185 | csrBound := decoder.io.vrs1idx
186 | }
187 | is(2.U) {
188 | csrBinomialk := decoder.io.vrs1idx
189 | }
190 | is(3.U) {
191 | csrModulusq := decoder.io.vrs1idx
192 | }
193 | is(4.U) {
194 | csrModulusLen := decoder.io.vrs1idx
195 | }
196 | is(5.U) {
197 | csrButterflyCtrl := decoder.io.vrs1idx
198 | }
199 | is(6.U) {
200 | csrLsBaseAddr := decoder.io.vrs1idx
201 | }
202 | }
203 | }
204 |
205 | // additional connect
206 | decoder.io.busy := exunit.io.busy
207 | io.rdy := exunit.io.done && !exunit.io.busy
// NOTE(review): this reads back the decoder's busy *input* driven two lines
// above; the decoder's pqcBusy output carries the same value.
208 | io.busy := decoder.io.busy
209 |
210 | // memory
// Both external addresses are the load/store base-address CSR; the data
// presented for stores is the first vector operand.
211 | io.vectorReadaddr := csrLsBaseAddr
212 | io.vectorWriteAddr := csrLsBaseAddr
213 | io.vectorOut := vecop1
214 |
// VLD: overrides the default port-1 write data/enable assigned earlier
// (last connect wins). The write address is still the registered one from
// above - NOTE(review): confirm the one-cycle offset is intended.
215 | when(decoder.io.fire && doinstr(INSTR_VLD)) {
216 | vecSRAM.io.vectorWriteData1 := io.vectorIn
217 | vecSRAM.io.vectorWriteEnable1 := true.B
218 | }
219 |
220 | io.vectorReadEnable := decoder.io.fire && doinstr(INSTR_VLD)
221 | io.vectorWriteEnable := decoder.io.fire && doinstr(INSTR_VST)
222 | }
223 |
/** Elaboration entry point: passes the command-line arguments and a generator
  * for `PQCCoprocessor` to the Chisel driver.
  */
object elaboratePQCCoprocessor extends App {
  private val generator = () => new PQCCoprocessor()
  chisel3.Driver.execute(args, generator)
}
227 |
228 | /**
229 | *
230 | * ----------------------------------------------------------------------
231 | *
232 | *
233 | **/
234 | // for synthesis
// Port list of the memory-less coprocessor variant (PQCCoprocessorNoMem):
// the vector register file, the const RAM and the SHA3 FIFO are outside the
// module, so their read/write ports are exposed here.
// Field declaration order is part of the hardware interface - do not reorder.
235 | class PQCCoprocessorNoMemIO extends Bundle
236 | with HasCommonParameters
237 | with HasNTTParameters
238 | with HasPQCInstructions {
239 | // for decode
// instr / rs1 / rs2 / in_fire are forwarded unchanged to PQCDecode.
240 | val instr = Input(new PQCInstruction)
241 | val rs1 = Input(UInt(32.W))
242 | val rs2 = Input(UInt(32.W))
// rd carries the old CSR value while a CSRRW fires; rdw is high in that cycle.
243 | val rd = Output(UInt(32.W))
244 | val rdw = Output(Bool())
245 | val in_fire = Input(Bool())
// rdy = execution unit done && !busy; busy mirrors the execution unit's busy.
246 | val rdy = Output(Bool())
247 | val busy = Output(Bool())
248 |
249 | // for SHA3
// seed / seedWrite go to the execution unit; the remaining six signals are
// passed straight between the execution unit and the external FIFO.
250 | val seed = Input(new stateArray())
251 | val seedWrite = Input(Bool())
252 | val writeData = Output(UInt((ML * DataWidth).W))
253 | val writeEnable = Output(Bool())
254 | val readEnable = Output(Bool())
255 | val readData = Input(UInt((ML * DataWidth).W))
256 | val readEmpty = Input(Bool())
257 | val writeFull = Input(Bool())
258 |
259 | // for vecop, from outside
// External vector register file: two read ports (addresses are the rs1 / rs2
// instruction fields) and two write ports driven by the write-back logic.
260 | val vectorReadAddr1 = Output(UInt(5.W))
261 | val vectorReadAddr2 = Output(UInt(5.W))
262 | val vectorWriteData1 = Output(Vec(ML, UInt(DataWidth.W)))
263 | val vectorWriteAddr1 = Output(UInt(32.W))
264 | val vectorWriteEnable1 = Output(Bool())
265 | val vectorWriteData2 = Output(Vec(ML, UInt(DataWidth.W)))
266 | val vectorWriteAddr2 = Output(UInt(32.W))
267 | val vectorWriteEnable2 = Output(Bool())
268 | val vectorReadValue1 = Input(Vec(ML, UInt(DataWidth.W)))
269 | val vectorReadValue2 = Input(Vec(ML, UInt(DataWidth.W)))
270 |
271 | // for ConstRam, from outside
272 | // ram from outside
// Read-only port: the execution unit drives address/enable, data comes back
// on di.
273 | val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
274 | val re = Output(Bool())
275 | val di = Input(UInt((ButterflyNum * DataWidth).W))
276 |
277 | // vector interface
// Store side: vectorOut is the first vector operand, the address is the
// load/store base-address CSR, the enable is asserted when a VST fires.
278 | val vectorOut = Output(Vec(ML, UInt(DataWidth.W)))
279 | val vectorWriteAddr = Output(UInt(32.W))
280 | val vectorWriteEnable = Output(Bool())
281 |
// Load side: vectorIn is routed to vectorWriteData1 when a VLD fires.
282 | val vectorIn = Input(Vec(ML, UInt(DataWidth.W)))
283 | val vectorReadaddr = Output(UInt(32.W))
284 | val vectorReadEnable = Output(Bool())
285 | }
286 |
// Coprocessor top level without internal memories: same decode / CSR /
// write-back structure as PQCCoprocessor, but the vector register file, the
// const RAM and the SHA3 FIFO are reached through PQCCoprocessorNoMemIO and
// the execution unit is PQCExuNoMem.
// NOTE: several outputs are given a default and overridden later inside a
// when-block (Chisel last-connect semantics), so statement order matters.
287 | class PQCCoprocessorNoMem extends Module
288 | with HasCommonParameters
289 | with HasPQCInstructions {
290 | val io = IO(new PQCCoprocessorNoMemIO)
291 |
292 | // decode
293 | val decoder = Module(new PQCDecode)
294 | decoder.io.instr := io.instr
295 | decoder.io.rs1 := io.rs1
296 | decoder.io.rs2 := io.rs2
297 | decoder.io.in_fire := io.in_fire
// doinstr(k) is high when the instruction's funct field equals k.
298 | val doinstr = decoder.io.result
299 |
300 | // CSRs
// All CSRs reset to 0. CSR numbers used by CSRRW / CSRRWI below:
// 0 Barretu, 1 Bound, 2 Binomialk, 3 Modulusq, 4 ModulusLen,
// 5 ButterflyCtrl, 6 LsBaseAddr.
301 | val csrBarretu = RegInit(0.U((DataWidth + 2).W))
302 | val csrBound = RegInit(0.U((DataWidth * 2).W))
303 | val csrBinomialk = RegInit(0.U(3.W))
304 | val csrModulusq = RegInit(0.U(DataWidth.W))
305 | val csrModulusLen = RegInit(0.U(5.W))
306 | val csrButterflyCtrl = RegInit(0.U(10.W))
307 | val csrLsBaseAddr = RegInit(0.U(32.W))
308 |
309 | // vector SRAM
// The register file is external: only addresses go out, operands come back.
310 | io.vectorReadAddr1 := decoder.io.vrs1idx
311 | io.vectorReadAddr2 := decoder.io.vrs2idx
312 | val vecop1 = io.vectorReadValue1
313 | val vecop2 = io.vectorReadValue2
314 |
315 | // ======================================
316 | // EXU Stage
317 | // ======================================
318 | val exunit = Module(new PQCExuNoMem)
319 |
// Seed and external FIFO signals are passed straight through.
320 | exunit.io.seed := io.seed
321 | exunit.io.seedWrite := io.seedWrite
322 | exunit.io.readEmpty := io.readEmpty
323 | exunit.io.readData := io.readData
324 | exunit.io.writeFull := io.writeFull
325 | io.writeData := exunit.io.writeData
326 | io.writeEnable := exunit.io.writeEnable
327 | io.readEnable := exunit.io.readEnable
328 |
// External const RAM read port, passed straight through.
329 | exunit.io.di := io.di
330 | io.re := exunit.io.re
331 | io.ra := exunit.io.ra
// Per-instruction valid: decoded instruction AND the fire handshake.
332 | for (i <- 0 until INSTR_QUANTITY) {
333 | exunit.io.valid(i) := doinstr(i) && decoder.io.fire
334 | }
335 |
336 | exunit.io.vrs1 := vecop1
337 | exunit.io.vrs2 := vecop2
338 | exunit.io.csrs.csrBarretu := csrBarretu
339 | exunit.io.csrs.csrBound := csrBound
340 | exunit.io.csrs.csrBinomialk := csrBinomialk
341 | exunit.io.csrs.csrModulusq := csrModulusq
342 | exunit.io.csrs.csrModulusLen := csrModulusLen
343 | exunit.io.csrs.csrButterflyCtrl.stageCfg := csrButterflyCtrl(9, 6)
344 | exunit.io.csrs.csrButterflyCtrl.iterCfg := csrButterflyCtrl(5, 0)
345 | exunit.io.csrs.csrLsBaseAddr := csrLsBaseAddr
346 |
// Write-back. Port 1 is written when the execution unit reports done && wb;
// port 2 additionally requires wpos. These are defaults: the VLD block near
// the end of the module overrides port 1 data/enable.
347 | io.vectorWriteEnable1 := exunit.io.done && exunit.io.wb
348 | io.vectorWriteEnable2 := exunit.io.done && exunit.io.wb && exunit.io.wpos
349 | io.vectorWriteData1 := exunit.io.vres1
350 | io.vectorWriteData2 := exunit.io.vres2
// Write address 1 is the previous cycle's rd field, unless a (I)BUTTERFLY
// fired in the previous cycle, in which case it is the previous cycle's rs1
// field. Write address 2 is always the previous cycle's rs2 field.
351 | io.vectorWriteAddr1 := Mux(!RegNext((doinstr(INSTR_BUTTERFLY) || doinstr(INSTR_IBUTTERFLY))
352 | && decoder.io.fire), RegNext(decoder.io.vrdidx), RegNext(decoder.io.vrs1idx))
353 | io.vectorWriteAddr2 := RegNext(decoder.io.vrs2idx)
354 |
// Defaults for the scalar result; only CSRRW overrides them.
355 | io.rd := 0.U
356 | io.rdw := false.B
357 | // CSRRW
// The rs2 field selects the CSR. The old CSR value is returned on rd and the
// CSR is overwritten with the scalar operand srs1 (truncated to CSR width).
358 | when(decoder.io.fire && doinstr(INSTR_CSRRW)){
359 | io.rdw := true.B
360 | switch(decoder.io.vrs2idx) {
361 | is(0.U) {
362 | io.rd := csrBarretu
363 | csrBarretu := decoder.io.srs1
364 | }
365 | is(1.U) {
366 | io.rd := csrBound
367 | csrBound := decoder.io.srs1
368 | }
369 | is(2.U) {
370 | io.rd := csrBinomialk
371 | csrBinomialk := decoder.io.srs1
372 | }
373 | is(3.U) {
374 | io.rd := csrModulusq
375 | csrModulusq := decoder.io.srs1
376 | }
377 | is(4.U) {
378 | io.rd := csrModulusLen
379 | csrModulusLen := decoder.io.srs1
380 | }
381 | is(5.U) {
382 | io.rd := csrButterflyCtrl
383 | csrButterflyCtrl := decoder.io.srs1
384 | }
385 | is(6.U) {
386 | io.rd := csrLsBaseAddr
387 | csrLsBaseAddr := decoder.io.srs1
388 | }
389 | }
390 | }
391 |
392 | // CSRRWI
// Immediate form: the 5-bit rs1 field itself is the value written. rd / rdw
// are not driven here and keep the defaults assigned above.
393 | when(decoder.io.fire && doinstr(INSTR_CSRRWI)){
394 | switch(decoder.io.vrs2idx) {
395 | is(0.U) {
396 | csrBarretu := decoder.io.vrs1idx
397 | }
398 | is(1.U) {
399 | csrBound := decoder.io.vrs1idx
400 | }
401 | is(2.U) {
402 | csrBinomialk := decoder.io.vrs1idx
403 | }
404 | is(3.U) {
405 | csrModulusq := decoder.io.vrs1idx
406 | }
407 | is(4.U) {
408 | csrModulusLen := decoder.io.vrs1idx
409 | }
410 | is(5.U) {
411 | csrButterflyCtrl := decoder.io.vrs1idx
412 | }
413 | is(6.U) {
414 | csrLsBaseAddr := decoder.io.vrs1idx
415 | }
416 | }
417 | }
418 |
419 | // additional connect
420 | decoder.io.busy := exunit.io.busy
421 | io.rdy := exunit.io.done && !exunit.io.busy
// NOTE(review): this reads back the decoder's busy *input* driven two lines
// above; the decoder's pqcBusy output carries the same value.
422 | io.busy := decoder.io.busy
423 |
424 | // memory
// Both external addresses are the load/store base-address CSR; the data
// presented for stores is the first vector operand.
425 | io.vectorReadaddr := csrLsBaseAddr
426 | io.vectorWriteAddr := csrLsBaseAddr
427 | io.vectorOut := vecop1
428 |
// VLD: overrides the default port-1 write data/enable assigned earlier
// (last connect wins). The write address is still the registered one from
// above - NOTE(review): confirm the one-cycle offset is intended.
429 | when(decoder.io.fire && doinstr(INSTR_VLD)) {
430 | io.vectorWriteData1 := io.vectorIn
431 | io.vectorWriteEnable1 := true.B
432 | }
433 |
434 | io.vectorReadEnable := decoder.io.fire && doinstr(INSTR_VLD)
435 | io.vectorWriteEnable := decoder.io.fire && doinstr(INSTR_VST)
436 | }
437 |
/** Elaboration entry point: passes the command-line arguments and a generator
  * for `PQCCoprocessorNoMem` to the Chisel driver.
  */
object elaboratePQCCoprocessorNoMem extends App {
  private val generator = () => new PQCCoprocessorNoMem()
  chisel3.Driver.execute(args, generator)
}
441 |
442 | //class PQCCoprocessorNoMemIO extends Bundle
443 | // with HasCommonParameters
444 | // with HasNTTParameters
445 | // with HasPQCInstructions {
446 | // // for decode
447 | // val instr = Input(new PQCInstruction)
448 | // val rs1 = Input(UInt(32.W))
449 | // val rs2 = Input(UInt(32.W))
450 | // val rd = Output(UInt(32.W))
451 | // val rdw = Output(Bool())
452 | // val in_fire = Input(Bool())
453 | // val rdy = Output(Bool())
454 | // val busy = Output(Bool())
455 | //
456 | // // for SHA3
457 | // val seed = Input(new stateArray())
458 | // val seedWrite = Input(Bool())
459 | // val writeData = Output(UInt((ML * DataWidth).W))
460 | // val writeEnable = Output(Bool())
461 | // val readEnable = Output(Bool())
462 | // val readData = Input(UInt((ML * DataWidth).W))
463 | // val readEmpty = Input(Bool())
464 | // val writeFull = Input(Bool())
465 | //
466 | // // for vecop, from outside
467 | // val vectorReadAddr1 = Output(UInt(5.W))
468 | // val vectorReadAddr2 = Output(UInt(5.W))
469 | // val vectorWriteData1 = Output(Vec(ML, UInt(DataWidth.W)))
470 | // val vectorWriteAddr1 = Output(UInt(32.W))
471 | // val vectorWriteEnable1 = Output(Bool())
472 | // val vectorWriteData2 = Output(Vec(ML, UInt(DataWidth.W)))
473 | // val vectorWriteAddr2 = Output(UInt(32.W))
474 | // val vectorWriteEnable2 = Output(Bool())
475 | // val vectorReadValue1 = Input(Vec(ML, UInt(DataWidth.W)))
476 | // val vectorReadValue2 = Input(Vec(ML, UInt(DataWidth.W)))
477 | //
478 | // // for ConstRam, from outside
479 | // // ram from outside
480 | // val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
481 | // val re = Output(Bool())
482 | // val di = Input(UInt((ButterflyNum * DataWidth).W))
483 | //
484 | // // vector interface
485 | // val vectorOut = Output(Vec(ML, UInt(DataWidth.W)))
486 | // val vectorWriteAddr = Output(UInt(32.W))
487 | // val vectorWriteEnable = Output(Bool())
488 | //
489 | // val vectorIn = Input(Vec(ML, UInt(DataWidth.W)))
490 | // val vectorReadaddr = Output(UInt(32.W))
491 | // val vectorReadEnable = Output(Bool())
492 | //}
493 | //
494 | //class PQCCoprocessorNoMem extends Module
495 | // with HasCommonParameters
496 | // with HasPQCInstructions {
497 | // val io = IO(new PQCCoprocessorNoMemIO)
498 | //
499 | // // decode
500 | // val decoder = Module(new PQCDecode)
501 | // decoder.io.instr := io.instr
502 | // decoder.io.rs1 := io.rs1
503 | // decoder.io.rs2 := io.rs2
504 | // decoder.io.in_fire := io.in_fire
505 | // val doinstr = decoder.io.result
506 | //
507 | // // CSRs
508 | // val csrBarretu = RegInit(0.U((DataWidth + 2).W))
509 | // val csrBound = RegInit(0.U((DataWidth * 2).W))
510 | // val csrBinomialk = RegInit(0.U(3.W))
511 | // val csrModulusq = RegInit(0.U(DataWidth.W))
512 | // val csrModulusLen = RegInit(0.U(5.W))
513 | // val csrButterflyCtrl = RegInit(0.U(10.W))
514 | // val csrLsBaseAddr = RegInit(0.U(32.W))
515 | //
516 | // // vector SRAM
517 | // io.vectorReadAddr1 := decoder.io.vrs1idx
518 | // io.vectorReadAddr2 := decoder.io.vrs2idx
519 | // val vecop1 = io.vectorReadValue1
520 | // val vecop2 = io.vectorReadValue2
521 | //
522 | // // ======================================
523 | // // EXU Stage
524 | // // ======================================
525 | // val exunit = Module(new PQCExuNoMem)
526 | //
527 | // exunit.io.seed := io.seed
528 | // exunit.io.seedWrite := io.seedWrite
529 | // exunit.io.readEmpty := io.readEmpty
530 | // exunit.io.readData := io.readData
531 | // exunit.io.writeFull := io.writeFull
532 | // io.writeData := exunit.io.writeData
533 | // io.writeEnable := exunit.io.writeEnable
534 | // io.readEnable := exunit.io.readEnable
535 | //
536 | // exunit.io.di := io.di
537 | // io.re := exunit.io.re
538 | // io.ra := exunit.io.ra
539 | // for (i <- 0 until INSTR_QUANTITY) {
540 | // exunit.io.valid(i) := doinstr(i) && decoder.io.fire
541 | // }
542 | //
543 | // exunit.io.vrs1 := vecop1
544 | // exunit.io.vrs2 := vecop2
545 | // exunit.io.csrs.csrBarretu := csrBarretu
546 | // exunit.io.csrs.csrBound := csrBound
547 | // exunit.io.csrs.csrBinomialk := csrBinomialk
548 | // exunit.io.csrs.csrModulusq := csrModulusq
549 | // exunit.io.csrs.csrModulusLen := csrModulusLen
550 | // exunit.io.csrs.csrButterflyCtrl.stageCfg := csrButterflyCtrl(9, 6)
551 | // exunit.io.csrs.csrButterflyCtrl.iterCfg := csrButterflyCtrl(5, 0)
552 | // exunit.io.csrs.csrLsBaseAddr := csrLsBaseAddr
553 | //
554 | // io.vectorWriteEnable1 := exunit.io.done && exunit.io.wb
555 | // io.vectorWriteEnable2 := exunit.io.done && exunit.io.wb && exunit.io.wpos
556 | // io.vectorWriteData1 := exunit.io.vres1
557 | // io.vectorWriteData2 := exunit.io.vres2
558 | // io.vectorWriteAddr1 := Mux(!RegNext((doinstr(INSTR_BUTTERFLY) || doinstr(INSTR_IBUTTERFLY))
559 | // && decoder.io.fire), RegNext(decoder.io.vrdidx), RegNext(decoder.io.vrs1idx))
560 | // io.vectorWriteAddr2 := RegNext(decoder.io.vrs2idx)
561 | //
562 | // io.rd := 0.U
563 | // io.rdw := false.B
564 | // // CSRRW
565 | // when(decoder.io.fire && doinstr(INSTR_CSRRW)){
566 | // io.rdw := true.B
567 | // switch(decoder.io.vrs2idx) {
568 | // is(0.U) {
569 | // io.rd := csrBarretu
570 | // csrBarretu := decoder.io.srs1
571 | // }
572 | // is(1.U) {
573 | // io.rd := csrBound
574 | // csrBound := decoder.io.srs1
575 | // }
576 | // is(2.U) {
577 | // io.rd := csrBinomialk
578 | // csrBinomialk := decoder.io.srs1
579 | // }
580 | // is(3.U) {
581 | // io.rd := csrModulusq
582 | // csrModulusq := decoder.io.srs1
583 | // }
584 | // is(4.U) {
585 | // io.rd := csrModulusLen
586 | // csrModulusLen := decoder.io.srs1
587 | // }
588 | // is(5.U) {
589 | // io.rd := csrButterflyCtrl
590 | // csrButterflyCtrl := decoder.io.srs1
591 | // }
592 | // is(6.U) {
593 | // io.rd := csrLsBaseAddr
594 | // csrLsBaseAddr := decoder.io.srs1
595 | // }
596 | // }
597 | // }
598 | //
599 | // // CSRRWI
600 | // when(decoder.io.fire && doinstr(INSTR_CSRRWI)){
601 | // switch(decoder.io.vrs2idx) {
602 | // is(0.U) {
603 | // csrBarretu := decoder.io.vrs1idx
604 | // }
605 | // is(1.U) {
606 | // csrBound := decoder.io.vrs1idx
607 | // }
608 | // is(2.U) {
609 | // csrBinomialk := decoder.io.vrs1idx
610 | // }
611 | // is(3.U) {
612 | // csrModulusq := decoder.io.vrs1idx
613 | // }
614 | // is(4.U) {
615 | // csrModulusLen := decoder.io.vrs1idx
616 | // }
617 | // is(5.U) {
618 | // csrButterflyCtrl := decoder.io.vrs1idx
619 | // }
620 | // is(6.U) {
621 | // csrLsBaseAddr := decoder.io.vrs1idx
622 | // }
623 | // }
624 | // }
625 | //
626 | // // additional connect
627 | // decoder.io.busy := exunit.io.busy
628 | // io.rdy := exunit.io.done && !exunit.io.busy
629 | // io.busy := decoder.io.busy
630 | //
631 | // // memory
632 | // io.vectorReadaddr := csrLsBaseAddr
633 | // io.vectorWriteAddr := csrLsBaseAddr
634 | // io.vectorOut := vecop1
635 | //
636 | // when(decoder.io.fire && doinstr(INSTR_VLD)) {
637 | // io.vectorWriteData1 := io.vectorIn
638 | // io.vectorWriteEnable1 := true.B
639 | // }
640 | //
641 | // io.vectorReadEnable := decoder.io.fire && doinstr(INSTR_VLD)
642 | // io.vectorWriteEnable := decoder.io.fire && doinstr(INSTR_VST)
643 | //}
644 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/PQCDecode.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 |
// Field view of one coprocessor instruction word. The widths add up to
// 7 + 5 + 5 + 1 + 1 + 1 + 5 + 7 = 32 bits. Declaration order fixes the bit
// layout of the bundle, so the fields must not be reordered.
6 | class PQCInstruction extends Bundle {
// funct selects the operation: PQCDecode compares it against each
// instruction index.
7 | val funct = UInt(7.W)
// rs2 / rs1 / rd double as vector register indices (and rs2 as the CSR
// number, rs1 as the CSRRWI immediate) in the coprocessor top level.
8 | val rs2 = UInt(5.W)
9 | val rs1 = UInt(5.W)
// xd / xs1 / xs2 are not read by the decoder in this file.
10 | val xd = Bool()
11 | val xs1 = Bool()
12 | val xs2 = Bool()
13 | val rd = UInt(5.W)
14 | val opcode = UInt(7.W)
15 | }
16 |
/**
  * Instruction decoder of the coprocessor.
  *
  * Purely combinational: the `funct` field is compared against every
  * instruction index to build the `result` vector (at most one entry is
  * high), while register fields, scalar operands and the handshake / busy
  * signals are passed through unchanged.
  */
class PQCDecode extends Module
  with HasPQCInstructions
  with HasCommonParameters {
  val io = IO(new Bundle {
    val instr = Input(new PQCInstruction)
    val rs1 = Input(UInt(32.W))
    val rs2 = Input(UInt(32.W))
    val in_fire = Input(Bool())
    val result = Output(Vec(INSTR_QUANTITY, Bool()))
    val busy = Input(Bool())
    val pqcBusy = Output(Bool())
    val srs1 = Output(UInt(32.W))
    val srs2 = Output(UInt(32.W))
    val vrs1idx = Output(UInt(5.W))
    val vrs2idx = Output(UInt(5.W))
    val vrdidx = Output(UInt(5.W))
    val fire = Output(Bool())
  })

  // Handshake and busy status are forwarded as-is.
  io.fire := io.in_fire
  io.pqcBusy := io.busy

  // Scalar operands are forwarded as-is.
  io.srs1 := io.rs1
  io.srs2 := io.rs2

  // Register index fields taken directly from the instruction word.
  io.vrdidx := io.instr.rd
  io.vrs1idx := io.instr.rs1
  io.vrs2idx := io.instr.rs2

  // result(k) is high exactly when funct equals k.
  io.result.zipWithIndex.foreach { case (hit, code) =>
    hit := io.instr.funct === code.U
  }
}
46 |
47 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/PQCExu.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 | import chisel3.util._
6 |
7 | // Current simplified version does not apply
8 | // resource sharing between ntt module and vector arithmetic module
// Port list of the execution unit with internal memories (PQCExu).
// Field declaration order is part of the hardware interface - do not reorder.
9 | class PQCExuIO extends Bundle
10 | with HasPQCInstructions
11 | with HasCommonParameters
12 | with HasNTTParameters {
13 | // for SHA3, from outside
14 | val seed = Input(new stateArray())
15 | val seedWrite = Input(Bool())
16 | // for ConstRam, from outside
// Twiddle write port, forwarded to the NTT module (wa / di / we).
17 | val twiddleData = Input(UInt((DataWidth * ButterflyNum).W))
18 | val twiddleAddr = Input(UInt(log2Ceil(Dimension / ButterflyNum).W))
19 | val twiddleWrite = Input(Bool())
// High when the NTT module reports done; the top level uses it to enable the
// second write port of the vector register file.
20 | val wpos = Output(Bool())
21 |
22 | // pipeline
// valid(k) requests execution of instruction k.
23 | val valid = Input(Vec(INSTR_QUANTITY, Bool()))
24 | // val rs1 = Input(UInt(64.W))
25 | // val rs2 = Input(UInt(64.W))
26 | val vrs1 = Input(Vec(ML, UInt(DataWidth.W)))
27 | val vrs2 = Input(Vec(ML, UInt(DataWidth.W)))
28 | val csrs = Input(new CSRIO)
// done / busy / wb are combined from the individual functional units;
// vres1 / vres2 are the two result vectors.
29 | val done = Output(Bool())
30 | val busy = Output(Bool())
31 | val wb = Output(Bool())
32 | val vres1 = Output(Vec(ML, UInt(DataWidth.W)))
33 | val vres2 = Output(Vec(ML, UInt(DataWidth.W)))
34 | }
35 |
// Execution unit with internal memories: instantiates the Keccak core (with
// FIFO), the samplers, the vector arithmetic unit and the NTT, starts the one
// selected by io.valid, and multiplexes their results and status flags onto
// the shared outputs.
36 | class PQCExu extends Module
37 | with HasPQCInstructions
38 | with HasCommonParameters
39 | with HasNTTParameters {
40 | val io = IO(new PQCExuIO())
41 |
42 | val vrs1 = io.vrs1
43 | val vrs2 = io.vrs2
44 |
// Keccak / SHA3 core, started by FETCHRN.
45 | val SHA3Core = Module(new KeccakWithFifo)
46 | SHA3Core.io.seed := io.seed
47 | SHA3Core.io.seedWrite := io.seedWrite
48 | SHA3Core.io.valid := io.valid(INSTR_FETCHRN)
49 |
// Samplers, started by either sampling instruction.
50 | val Samplers = Module(new Samplers)
51 | Samplers.io.vectorReg1 := vrs1
52 | Samplers.io.vectorReg2 := vrs2
53 | Samplers.io.csrs := io.csrs
54 | Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
// mode is low only when SAMPLEREJECTION is valid and SAMPLEBINOMIAL is not;
// in every other case (including idle) it is high.
55 | Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)
56 |
// Vector arithmetic: all three operations see the same operands; the result
// is selected below according to the valid instruction.
57 | val VecArith = Module(new VectorArith)
58 | VecArith.io.addA := vrs1
59 | VecArith.io.addB := vrs2
60 | VecArith.io.subA := vrs1
61 | VecArith.io.subB := vrs2
62 | VecArith.io.mulA := vrs1
63 | VecArith.io.mulB := vrs2
64 | VecArith.io.csrs := io.csrs
65 | VecArith.io.valid := io.valid(INSTR_VADD) || io.valid(INSTR_VSUB) || io.valid(INSTR_VMUL)
// Priority VADD > VSUB > VMUL; all-zero vector when none is valid.
66 | val vecData = Mux(io.valid(INSTR_VADD), VecArith.io.addRes,
67 | Mux(io.valid(INSTR_VSUB), VecArith.io.subRes,
68 | Mux(io.valid(INSTR_VMUL), VecArith.io.mulRes, VecInit(Seq.fill(ML)(0.U(DataWidth.W))))))
69 |
// NTT with its twiddle write port; mode is high only for IBUTTERFLY.
70 | val NTT = Module(new NTT)
71 | NTT.io.wa := io.twiddleAddr
72 | NTT.io.di := io.twiddleData
73 | NTT.io.we := io.twiddleWrite
74 | NTT.io.vectorReg1 := vrs1
75 | NTT.io.vectorReg2 := vrs2
76 | NTT.io.csrs := io.csrs
77 | NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
78 | NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY)
79 |
80 |
// Combined status: OR over all functional units.
81 | val done = SHA3Core.io.done || Samplers.io.done || NTT.io.done || VecArith.io.done
// NOTE(review): the last term is VecArith.io.done while the other three use
// .busy - possibly a copy-paste slip. The top level computes
// rdy = done && !busy, so as written rdy stays low in the cycle VecArith is
// done. Confirm against VectorArith's port list before changing.
82 | val busy = SHA3Core.io.busy || Samplers.io.busy || NTT.io.busy || VecArith.io.done
// The NTT output has 2 * ButterflyNum lanes: the lower half goes to result
// vector 1, the upper half to result vector 2.
83 | val NTTData1 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
84 | val NTTData2 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
85 | for (i <- 0 until ButterflyNum) {
86 | NTTData1(i) := NTT.io.dataOut(i)
87 | }
88 | for (i <- ButterflyNum until ButterflyNum*2) {
89 | NTTData2(i-ButterflyNum) := NTT.io.dataOut(i)
90 | }
91 |
// Result selection: priority Samplers > VecArith > NTT, falling back to the
// SHA3 output when none of them is done. vres2 differs from vres1 only in the
// NTT case.
92 | val vres1 = MuxCase(SHA3Core.io.prn, Array(
93 | Samplers.io.done -> Samplers.io.sampledData,
94 | VecArith.io.done -> vecData,
95 | NTT.io.done -> NTTData1
96 | ))
97 | val vres2 = MuxCase(SHA3Core.io.prn, Array(
98 | Samplers.io.done -> Samplers.io.sampledData,
99 | VecArith.io.done -> vecData,
100 | NTT.io.done -> NTTData2
101 | ))
102 |
// Write-back flag of whichever unit is done (priority SHA3 > Samplers >
// VecArith > NTT); false when no unit is done.
103 | val wb = MuxCase(false.B, Array(
104 | SHA3Core.io.done -> SHA3Core.io.wb,
105 | Samplers.io.done -> Samplers.io.wb,
106 | VecArith.io.done -> VecArith.io.wb,
107 | NTT.io.done -> NTT.io.wb
108 | ))
109 |
110 | io.done := done
111 | io.vres1 := vres1
112 | io.vres2 := vres2
113 | io.busy := busy
114 | io.wb := wb
// Second result vector is only meaningful when the NTT finished.
115 | io.wpos := NTT.io.done
116 | }
117 |
118 | class PQCExuNoMemIO extends Bundle // port list of PQCExuNoMem
119 | with HasPQCInstructions
120 | with HasCommonParameters
121 | with HasNTTParameters {
122 | // SHA3 seed state and its write strobe, supplied from outside (forwarded to the Keccak core)
123 | val seed = Input(new stateArray())
124 | val seedWrite = Input(Bool())
125 | // prefetch FIFO interface: the FIFO is not inside the execution unit, only its signals cross this boundary
126 | val writeData = Output(UInt((ML * DataWidth).W))
127 | val writeEnable = Output(Bool())
128 | val readEnable = Output(Bool())
129 | val readData = Input(UInt((ML * DataWidth).W))
130 | val readEmpty = Input(Bool())
131 | val writeFull = Input(Bool())
132 |
133 | // constant RAM (ConstRam) is external as well:
134 | // read address and read enable go out, the read word (ButterflyNum entries wide) comes back on di
135 | val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
136 | val re = Output(Bool())
137 | val di = Input(UInt((ButterflyNum * DataWidth).W))
138 |
139 | val wpos = Output(Bool()) // PQCExuNoMem drives this from the NTT unit's done flag
140 |
141 | // pipeline interface
142 | val valid = Input(Vec(INSTR_QUANTITY, Bool())) // one valid bit per instruction, indexed with the INSTR_* constants
143 | // val rs1 = Input(UInt(64.W))
144 | // val rs2 = Input(UInt(64.W))
145 | val vrs1 = Input(Vec(ML, UInt(DataWidth.W))) // first vector source operand
146 | val vrs2 = Input(Vec(ML, UInt(DataWidth.W))) // second vector source operand
147 | val csrs = Input(new CSRIO)
148 | val done = Output(Bool())
149 | val busy = Output(Bool())
150 | val wb = Output(Bool())
151 | val vres1 = Output(Vec(ML, UInt(DataWidth.W))) // first result vector
152 | val vres2 = Output(Vec(ML, UInt(DataWidth.W))) // second result vector
153 | }
154 |
155 | // Execution unit without internal memories: the prefetch FIFO and the constant RAM are reached through io.
156 | class PQCExuNoMem extends Module
157 | with HasPQCInstructions
158 | with HasCommonParameters
159 | with HasNTTParameters {
160 | val io = IO(new PQCExuNoMemIO())
161 |
162 | val vrs1 = io.vrs1
163 | val vrs2 = io.vrs2
164 | // Keccak core without its own FIFO: FIFO signals are passed straight through to the module ports
165 | val SHA3Core = Module(new KeccakNoFifo)
166 | SHA3Core.io.seed := io.seed
167 | SHA3Core.io.seedWrite := io.seedWrite
168 | SHA3Core.io.valid := io.valid(INSTR_FETCHRN)
169 | SHA3Core.io.readEmpty := io.readEmpty
170 | SHA3Core.io.readData := io.readData
171 | SHA3Core.io.writeFull := io.writeFull
172 | io.writeData := SHA3Core.io.writeData
173 | io.writeEnable := SHA3Core.io.writeEnable
174 | io.readEnable := SHA3Core.io.readEnable
175 | // samplers: started by either sampling instruction; mode is 1 (binomial) unless only the rejection instruction is valid
176 | val Samplers = Module(new Samplers)
177 | Samplers.io.vectorReg1 := vrs1
178 | Samplers.io.vectorReg2 := vrs2
179 | Samplers.io.csrs := io.csrs
180 | Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
181 | Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)
182 | // NTT unit without its own RAM: the RAM read port is routed to the module ports
183 | val NTT = Module(new NTTWithoutRam)
184 | NTT.io.di := io.di
185 | io.ra := NTT.io.ra
186 | io.re := NTT.io.re
187 |
188 | NTT.io.vectorReg1 := vrs1
189 | NTT.io.vectorReg2 := vrs2
190 | NTT.io.csrs := io.csrs
191 | NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
192 | NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY) // 1 only for the inverse butterfly instruction
193 |
194 | // done/busy are the OR of the three functional units
195 | val done = SHA3Core.io.done || Samplers.io.done || NTT.io.done
196 | val busy = SHA3Core.io.busy || Samplers.io.busy || NTT.io.busy
197 | val NTTData1 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
198 | val NTTData2 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
199 | for (i <- 0 until ButterflyNum) { // lower half of the NTT output
200 | NTTData1(i) := NTT.io.dataOut(i)
201 | }
202 | for (i <- ButterflyNum until ButterflyNum*2) { // upper half of the NTT output
203 | NTTData2(i-ButterflyNum) := NTT.io.dataOut(i)
204 | }
205 | // result selection: MuxCase picks the first true condition, so Samplers win over NTT; the Keccak output is the default
206 | val vres1 = MuxCase(SHA3Core.io.prn, Array(
207 | Samplers.io.done -> Samplers.io.sampledData,
208 | NTT.io.done -> NTTData1
209 | ))
210 | val vres2 = MuxCase(SHA3Core.io.prn, Array(
211 | Samplers.io.done -> Samplers.io.sampledData,
212 | NTT.io.done -> NTTData2
213 | ))
214 | // write-back request of whichever unit reports done (priority: SHA3, Samplers, NTT); false when none is done
215 | val wb = MuxCase(false.B, Array(
216 | SHA3Core.io.done -> SHA3Core.io.wb,
217 | Samplers.io.done -> Samplers.io.wb,
218 | NTT.io.done -> NTT.io.wb
219 | ))
220 |
221 | io.done := done
222 | io.vres1 := vres1
223 | io.vres2 := vres2
224 | io.busy := busy
225 | io.wb := wb
226 | io.wpos := NTT.io.done
227 | }
228 |
229 | class PQCExuNoMemIO2 extends Bundle // port list of PQCExuNoMem2; declares the same fields as PQCExuNoMemIO
230 | with HasPQCInstructions
231 | with HasCommonParameters
232 | with HasNTTParameters {
233 | // SHA3 seed state and its write strobe, supplied from outside (forwarded to the Keccak core)
234 | val seed = Input(new stateArray())
235 | val seedWrite = Input(Bool())
236 | // prefetch FIFO interface: the FIFO is not inside the execution unit, only its signals cross this boundary
237 | val writeData = Output(UInt((ML * DataWidth).W))
238 | val writeEnable = Output(Bool())
239 | val readEnable = Output(Bool())
240 | val readData = Input(UInt((ML * DataWidth).W))
241 | val readEmpty = Input(Bool())
242 | val writeFull = Input(Bool())
243 |
244 | // constant RAM (ConstRam) is external as well:
245 | // read address and read enable go out, the read word (ButterflyNum entries wide) comes back on di
246 | val ra = Output(UInt(log2Ceil(Dimension / ButterflyNum).W))
247 | val re = Output(Bool())
248 | val di = Input(UInt((ButterflyNum * DataWidth).W))
249 |
250 | val wpos = Output(Bool()) // PQCExuNoMem2 raises this only when the NTT unit is done for a butterfly instruction
251 |
252 | // pipeline interface
253 | val valid = Input(Vec(INSTR_QUANTITY, Bool())) // one valid bit per instruction, indexed with the INSTR_* constants
254 | // val rs1 = Input(UInt(64.W))
255 | // val rs2 = Input(UInt(64.W))
256 | val vrs1 = Input(Vec(ML, UInt(DataWidth.W))) // first vector source operand
257 | val vrs2 = Input(Vec(ML, UInt(DataWidth.W))) // second vector source operand
258 | val csrs = Input(new CSRIO)
259 | val done = Output(Bool())
260 | val busy = Output(Bool())
261 | val wb = Output(Bool())
262 | val vres1 = Output(Vec(ML, UInt(DataWidth.W))) // first result vector
263 | val vres2 = Output(Vec(ML, UInt(DataWidth.W))) // second result vector
264 | }
265 |
266 | // resource sharing: vector add/sub/mul requests are handed to the NTT unit (NTTWithoutRamShare); no separate vector ALU is instantiated
267 | class PQCExuNoMem2 extends Module
268 | with HasPQCInstructions
269 | with HasCommonParameters
270 | with HasNTTParameters {
271 | val io = IO(new PQCExuNoMemIO2())
272 |
273 | val vrs1 = io.vrs1
274 | val vrs2 = io.vrs2
275 | // Keccak core without its own FIFO: FIFO signals are passed straight through to the module ports
276 | val SHA3Core = Module(new KeccakNoFifo)
277 | SHA3Core.io.seed := io.seed
278 | SHA3Core.io.seedWrite := io.seedWrite
279 | SHA3Core.io.valid := io.valid(INSTR_FETCHRN)
280 | SHA3Core.io.readEmpty := io.readEmpty
281 | SHA3Core.io.readData := io.readData
282 | SHA3Core.io.writeFull := io.writeFull
283 | io.writeData := SHA3Core.io.writeData
284 | io.writeEnable := SHA3Core.io.writeEnable
285 | io.readEnable := SHA3Core.io.readEnable
286 | // samplers: started by either sampling instruction; mode is 1 (binomial) unless only the rejection instruction is valid
287 | val Samplers = Module(new Samplers)
288 | Samplers.io.vectorReg1 := vrs1
289 | Samplers.io.vectorReg2 := vrs2
290 | Samplers.io.csrs := io.csrs
291 | Samplers.io.valid := io.valid(INSTR_SAMPLEBINOMIAL) || io.valid(INSTR_SAMPLEREJECTION)
292 | Samplers.io.mode := io.valid(INSTR_SAMPLEBINOMIAL) || !io.valid(INSTR_SAMPLEREJECTION)
293 | // NTT unit without its own RAM; it also receives the vector-arithmetic requests
294 | val NTT = Module(new NTTWithoutRamShare)
295 | NTT.io.di := io.di
296 | io.ra := NTT.io.ra
297 | io.re := NTT.io.re
298 |
299 | NTT.io.vectorReg1 := vrs1
300 | NTT.io.vectorReg2 := vrs2
301 | NTT.io.csrs := io.csrs
302 | NTT.io.valid := io.valid(INSTR_BUTTERFLY) || io.valid(INSTR_IBUTTERFLY)
303 | NTT.io.vecValid(0) := io.valid(INSTR_VADD) // vecValid index 0: vector add
304 | NTT.io.vecValid(1) := io.valid(INSTR_VSUB) // vecValid index 1: vector subtract
305 | NTT.io.vecValid(2) := io.valid(INSTR_VMUL) // vecValid index 2: vector multiply
306 | NTT.io.mode := !io.valid(INSTR_BUTTERFLY) && io.valid(INSTR_IBUTTERFLY) // 1 only for the inverse butterfly instruction
307 |
308 | // done/busy are the OR of the three functional units
309 | val done = SHA3Core.io.done || Samplers.io.done || NTT.io.done
310 | val busy = SHA3Core.io.busy || Samplers.io.busy || NTT.io.busy
311 | val NTTData1 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
312 | val NTTData2 = Wire(Vec(ButterflyNum , UInt(DataWidth.W)))
313 | for (i <- 0 until ButterflyNum) { // lower half of the NTT output
314 | NTTData1(i) := NTT.io.dataOut(i)
315 | }
316 | for (i <- ButterflyNum until ButterflyNum*2) { // upper half of the NTT output
317 | NTTData2(i-ButterflyNum) := NTT.io.dataOut(i)
318 | }
319 | // result selection: MuxCase picks the first true condition, so Samplers win over NTT; the Keccak output is the default
320 | val vres1 = MuxCase(SHA3Core.io.prn, Array(
321 | Samplers.io.done -> Samplers.io.sampledData,
322 | NTT.io.done -> NTTData1
323 | ))
324 | val vres2 = MuxCase(SHA3Core.io.prn, Array(
325 | Samplers.io.done -> Samplers.io.sampledData,
326 | NTT.io.done -> NTTData2
327 | ))
328 | // write-back request of whichever unit reports done (priority: SHA3, Samplers, NTT); false when none is done
329 | val wb = MuxCase(false.B, Array(
330 | SHA3Core.io.done -> SHA3Core.io.wb,
331 | Samplers.io.done -> Samplers.io.wb,
332 | NTT.io.done -> NTT.io.wb
333 | ))
334 |
335 | io.done := done
336 | io.vres1 := vres1
337 | io.vres2 := vres2
338 | io.busy := busy
339 | io.wb := wb
340 | io.wpos := NTT.io.done && NTT.io.valid // qualified with the butterfly valid, so vector add/sub/mul completions do not raise wpos
341 | }
--------------------------------------------------------------------------------
/src/main/scala/VPQC/SamplerParameters.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | trait HasSamplerParameters { // compile-time parameters mixed into the sampler bundles and modules
5 | // binomial sampler: no data width is defined in this trait (the sampler code uses DataWidth from elsewhere);
6 | // supported k = 1, 2, 4, 8, 16
7 |
8 | // number of parallel sampler lanes (Samplers builds one rejection and one binomial sampler per lane)
9 | val SamplerNum = 32
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/Samplers.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 | import chisel3.util._
6 |
7 | class BinomialSamplerIO extends Bundle // ports of one binomial sampler lane
8 | with HasSamplerParameters
9 | with HasCommonParameters {
10 | val data1 = Input(UInt(DataWidth.W)) // random bits whose popcount is the minuend
11 | val data2 = Input(UInt(DataWidth.W)) // random bits whose popcount is the subtrahend
12 | val k = Input(UInt(3.W)) // encoded k: 0, 1, 2, 3, 4 select k = 1, 2, 4, 8, 16
13 | val q = Input(UInt(DataWidth.W)) // modulus added back when the difference would be negative
14 | val dataOut = Output(UInt(DataWidth.W)) // sampled value
15 | }
16 |
17 | class BinomialSampler extends Module
18 |   with HasSamplerParameters
19 |   with HasCommonParameters {
20 |   val io = IO(new BinomialSamplerIO())
21 |   // Binomial sample: dataOut = popcount(low k bits of data1) - popcount(low k bits of data2), with q added back when negative.
22 |   val data1pop2 = Wire(Vec(DataWidth / 2, UInt(2.W)))
23 |   val data2pop2 = Wire(Vec(DataWidth / 2, UInt(2.W)))
24 |   for (i <- 0 until DataWidth / 2) {
25 |     data1pop2(i) := io.data1(2 * i + 1) +& io.data1(2 * i)
26 |     data2pop2(i) := io.data2(2 * i + 1) +& io.data2(2 * i)
27 |   }
28 |   // Every tree level uses `+&` (width-expanding add): plain `+` keeps the operand width, so 1+1, 2+2, 4+4 and 8+8 would wrap to 0.
29 |   val data1pop4 = Wire(Vec(DataWidth / 4, UInt(3.W)))
30 |   val data2pop4 = Wire(Vec(DataWidth / 4, UInt(3.W)))
31 |   for (i <- 0 until DataWidth / 4) {
32 |     data1pop4(i) := data1pop2(2 * i + 1) +& data1pop2(2 * i)
33 |     data2pop4(i) := data2pop2(2 * i + 1) +& data2pop2(2 * i)
34 |   }
35 |
36 |   val data1pop8 = Wire(Vec(DataWidth / 8, UInt(4.W)))
37 |   val data2pop8 = Wire(Vec(DataWidth / 8, UInt(4.W)))
38 |   for (i <- 0 until DataWidth / 8) {
39 |     data1pop8(i) := data1pop4(2 * i + 1) +& data1pop4(2 * i)
40 |     data2pop8(i) := data2pop4(2 * i + 1) +& data2pop4(2 * i)
41 |   }
42 |   // 16-bit popcount: only a real sum when the data path is 16 bits wide, otherwise it falls back to the 8-bit count
43 |   val data1pop16 = Wire(UInt(5.W))
44 |   val data2pop16 = Wire(UInt(5.W))
45 |   if (DataWidth == 16) {
46 |     data1pop16 := data1pop8(1) +& data1pop8(0)
47 |     data2pop16 := data2pop8(1) +& data2pop8(0)
48 |
49 |   } else {
50 |     data1pop16 := data1pop8(0)
51 |     data2pop16 := data2pop8(0)
52 |
53 |   }
54 |
55 |   // minuend (op1) and subtrahend (op2); both stay 0 for the unused k encodings 5..7
56 |   val op1 = WireInit(0.U(5.W))
57 |   val op2 = WireInit(0.U(5.W))
58 |
59 |   switch(io.k) {
60 |     // k = 1
61 |     is(0.U) {
62 |       op1 := io.data1(0)
63 |       op2 := io.data2(0)
64 |     }
65 |     // k = 2
66 |     is(1.U) {
67 |       op1 := data1pop2(0)
68 |       op2 := data2pop2(0)
69 |     }
70 |     // k = 4
71 |     is(2.U) {
72 |       op1 := data1pop4(0)
73 |       op2 := data2pop4(0)
74 |     }
75 |     // k = 8
76 |     is(3.U) {
77 |       op1 := data1pop8(0)
78 |       op2 := data2pop8(0)
79 |     }
80 |     // k = 16: use the whole 5-bit count (indexing a UInt with (0) would extract only bit 0)
81 |     is(4.U) {
82 |       op1 := data1pop16
83 |       op2 := data2pop16
84 |     }
85 |   }
86 |   // modular subtraction: add q back when the plain difference would be negative
87 |   val diff = op1 < op2
88 |   val out = Mux(diff, op1 + io.q - op2, op1 - op2)
89 |   io.dataOut := out
90 | }
91 |
92 | class RejectionSamplerIO extends Bundle // ports of one rejection sampler lane
93 | with HasSamplerParameters
94 | with HasCommonParameters {
95 | val dataIn = Input(UInt((DataWidth * 2).W)) // double-width candidate value
96 | val n = Input(UInt(5.W)) // Samplers connects the modulus-length CSR here
97 | val m = Input(UInt(DataWidth.W)) // Samplers connects the modulus q CSR here
98 | val u = Input(UInt((DataWidth + 2).W)) // Samplers connects the Barrett constant CSR here
99 | val bound = Input(UInt((DataWidth * 2).W)) // candidates >= bound are rejected
100 | val dataOut = Output(UInt(DataWidth.W)) // reduced value, or 0 when the candidate is rejected
101 | val valid = Output(Bool()) // high when the candidate is accepted
102 | }
103 |
104 | // Rejection sampler: accepts dataIn only when it is below bound (small rejection rate).
105 | class RejectionSampler extends Module
106 | with HasSamplerParameters {
107 | val io = IO(new RejectionSamplerIO())
108 | // a candidate fails when it is not strictly below the bound
109 | val failed = !(io.dataIn < io.bound)
110 | io.dataOut := Mux(failed, 0.U, MR(io.dataIn, io.n, io.m, io.u)) // MR is defined elsewhere; presumably a modular reduction of dataIn by m - TODO confirm
111 | io.valid := !failed
112 | }
113 |
114 | class SamplersIO extends Bundle // ports of the Samplers block
115 | with HasSamplerParameters
116 | with HasKeccakParameters
117 | with HasCommonParameters {
118 | // control: `mode` selects the sampler - 0: rejection sample, 1: binomial sample
119 | val valid = Input(Bool()) // request strobe; `done` is this signal delayed by one cycle
120 | val mode = Input(Bool())
121 |
122 | // source operands from the vector register file
123 | val vectorReg1 = Input(Vec(ML, UInt(DataWidth.W)))
124 | val vectorReg2 = Input(Vec(ML, UInt(DataWidth.W)))
125 |
126 | // csr interface
127 | val csrs = Input(new CSRIO)
128 |
129 | // outputs
130 | val sampledData = Output(Vec(SamplerNum, UInt(DataWidth.W))) // one sampled value per lane
131 | val sampledDataValid = Output(Bool()) // always true in binomial mode; in rejection mode true only if every lane accepted
132 | val done = Output(Bool())
133 | val busy = Output(Bool()) // tied low by Samplers
134 | val wb = Output(Bool()) // tied high by Samplers
135 | }
136 |
137 | class Samplers extends Module
138 | with HasSamplerParameters
139 | with HasCommonParameters {
140 | val io = IO(new SamplersIO())
141 | // SamplerNum lanes, each with one rejection and one binomial sampler; both run all the time and io.mode picks the output
142 | val rejectionSamplers = VecInit(Seq.fill(SamplerNum)(Module(new RejectionSampler()).io))
143 | val rejectionData = Wire(Vec(SamplerNum, UInt(DataWidth.W)))
144 | for (i <- 0 until SamplerNum) {
145 | rejectionSamplers(i).dataIn := Cat(io.vectorReg2(i), io.vectorReg1(i)) // double-width candidate: vectorReg2 element in the high half
146 | rejectionSamplers(i).bound := io.csrs.csrBound
147 | rejectionSamplers(i).n := io.csrs.csrModulusLen
148 | rejectionSamplers(i).m := io.csrs.csrModulusq
149 | rejectionSamplers(i).u := io.csrs.csrBarretu
150 | rejectionData(i) := rejectionSamplers(i).dataOut
151 | }
152 |
153 | val binomialSamplers = VecInit(Seq.fill(SamplerNum)(Module(new BinomialSampler()).io))
154 | val binomialData = Wire(Vec(SamplerNum, UInt(DataWidth.W)))
155 | for (i <- 0 until SamplerNum) {
156 | binomialSamplers(i).data1 := io.vectorReg1(i)
157 | binomialSamplers(i).data2 := io.vectorReg2(i)
158 | binomialSamplers(i).k := io.csrs.csrBinomialk
159 | binomialSamplers(i).q := io.csrs.csrModulusq
160 | binomialData(i) := binomialSamplers(i).dataOut
161 | }
162 | // mode = 1 selects the binomial results, mode = 0 the rejection results
163 | io.sampledData := Mux(io.mode, binomialData, rejectionData)
164 | val rejectionValid = (0 until SamplerNum).map(i => rejectionSamplers(i).valid).reduce(_ && _) // AND over all lanes
165 | io.sampledDataValid := Mux(io.mode, true.B, rejectionValid)
166 | io.done := RegNext(io.valid) // done follows the request by one cycle
167 | io.busy := false.B
168 | io.wb := true.B
169 | }
170 |
171 | object elaborateSamplers extends App { // entry point: elaborates the Samplers module, forwarding the command-line args to the Chisel driver
172 | chisel3.Driver.execute(args, () => new Samplers())
173 | }
174 |
--------------------------------------------------------------------------------
/src/main/scala/VPQC/VectorRegister.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3._
5 |
6 | // Behavioral model of the vector register file (5-bit addresses, one ML-element vector per register);
7 | // meant to be implemented with an SRAM.
8 | class VectorRegisterIO extends Bundle // ports: two read ports and one write port
9 | with HasCommonParameters {
10 | val vectorReadAddr1 = Input(UInt(5.W)) // register index for read port 1
11 | val vectorReadAddr2 = Input(UInt(5.W)) // register index for read port 2
12 | val vectorWriteAddr = Input(UInt(5.W)) // register index for the write port
13 | val vectorWriteData = Input(Vec(ML, UInt(DataWidth.W))) // vector to store
14 | val vectorWriteEnable = Input(Bool()) // the write happens only while this is high
15 | val vectorReadValue1 = Output(Vec(ML, UInt(DataWidth.W))) // data of read port 1
16 | val vectorReadValue2 = Output(Vec(ML, UInt(DataWidth.W))) // data of read port 2
17 | }
18 |
19 | class VectorRegister extends Module
20 |   with HasCommonParameters {
21 |   val io = IO(new VectorRegisterIO)
22 |   // Vector register file: 32 entries, each one whole vector packed into ML * DataWidth bits (element 0 in the low bits).
23 |   val vectorSRAM = SyncReadMem(32, UInt((ML * DataWidth).W))
24 |   // One write port.
25 |   when (io.vectorWriteEnable) { vectorSRAM.write(io.vectorWriteAddr, io.vectorWriteData.asUInt()) }
26 |   // Exactly two read ports: each read() call creates a port, so read once per address here rather than once per element in the loop.
27 |   val readWord1 = vectorSRAM.read(io.vectorReadAddr1, true.B)
28 |   val readWord2 = vectorSRAM.read(io.vectorReadAddr2, true.B)
29 |   // Unpack the two words into their ML elements.
30 |   for (i <- 0 until ML) {
31 |     io.vectorReadValue1(i) := readWord1(DataWidth * i + DataWidth - 1, DataWidth * i)
32 |     io.vectorReadValue2(i) := readWord2(DataWidth * i + DataWidth - 1, DataWidth * i)
33 |   }
34 | }
35 |
36 | class VectorRegister2PortIO extends Bundle // ports of the register file variant with two read and two write ports
37 | with HasCommonParameters {
38 | val vectorReadAddr1 = Input(UInt(5.W)) // register index for read port 1
39 | val vectorReadAddr2 = Input(UInt(5.W)) // register index for read port 2
40 | val vectorWriteAddr1 = Input(UInt(5.W)) // register index for write port 1
41 | val vectorWriteData1 = Input(Vec(ML, UInt(DataWidth.W))) // vector stored through write port 1
42 | val vectorWriteAddr2 = Input(UInt(5.W)) // register index for write port 2
43 | val vectorWriteData2 = Input(Vec(ML, UInt(DataWidth.W))) // vector stored through write port 2
44 | val vectorWriteEnable1 = Input(Bool()) // enables write port 1
45 | val vectorWriteEnable2 = Input(Bool()) // enables write port 2
46 | val vectorReadValue1 = Output(Vec(ML, UInt(DataWidth.W))) // data of read port 1
47 | val vectorReadValue2 = Output(Vec(ML, UInt(DataWidth.W))) // data of read port 2
48 | }
49 |
50 | class VectorRegister2Port extends Module
51 |   with HasCommonParameters {
52 |   val io = IO(new VectorRegister2PortIO)
53 |   // Vector register file: 32 entries, each one whole vector packed into ML * DataWidth bits (element 0 in the low bits).
54 |   val vectorSRAM = SyncReadMem(32, UInt((ML * DataWidth).W))
55 |
56 |   // Two write ports; nothing here arbitrates two simultaneous writes to the same address.
57 |   when (io.vectorWriteEnable1) { vectorSRAM.write(io.vectorWriteAddr1, io.vectorWriteData1.asUInt()) }
58 |   when (io.vectorWriteEnable2) { vectorSRAM.write(io.vectorWriteAddr2, io.vectorWriteData2.asUInt()) }
59 |
60 |   // Exactly two read ports: each read() call creates a port, so read once per address here
61 |   // rather than once per element inside the loop.
62 |   val readWord1 = vectorSRAM.read(io.vectorReadAddr1, true.B)
63 |   val readWord2 = vectorSRAM.read(io.vectorReadAddr2, true.B)
64 |   for (i <- 0 until ML) {
65 |     io.vectorReadValue1(i) := readWord1(DataWidth * i + DataWidth - 1, DataWidth * i)
66 |     io.vectorReadValue2(i) := readWord2(DataWidth * i + DataWidth - 1, DataWidth * i)
67 |   }
68 | }
69 |
70 | object elaborateVectorRegister extends App { // entry point: elaborates the VectorRegister module, forwarding the command-line args to the Chisel driver
71 | chisel3.Driver.execute(args, () => new VectorRegister())
72 | }
--------------------------------------------------------------------------------
/src/main/scala/utility/ShiftRegs.scala:
--------------------------------------------------------------------------------
1 |
2 | package utility
3 |
4 | import chisel3._
5 |
6 | // Parameterized shift register: n stages, each w bits wide.
7 | class ShiftRegs(val n: Int, val w: Int) extends Module {
8 |   val io = IO(new Bundle {
9 |     val in = Input(UInt(w.W))
10 |     val out = Output(UInt(w.W))
11 |   })
12 |
13 |   // the n stage registers, all reset to zero
14 |   val delays = RegInit(VecInit(Seq.fill(n)(0.U(w.W))))
15 |
16 |   // stage 0 samples the input; every later stage samples its predecessor
17 |   delays(0) := io.in
18 |   for (stage <- 1 until n) {
19 |     delays(stage) := delays(stage - 1)
20 |   }
21 |
22 |   // the last stage drives the output, so out lags in by n clock cycles
23 |   io.out := delays(n - 1)
24 |
25 | }
26 |
27 | // Companion object: functional-style constructor, used as ShiftRegs(n, w)(signal).
28 | object ShiftRegs {
29 | def apply (n : Int, w : Int) (in: UInt) : UInt = { // instantiates an n-stage, w-bit ShiftRegs and returns its output
30 | val inst = Module(new ShiftRegs (n, w))
31 | inst.io.in := in
32 | inst.io.out
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/scala/utility/SyncFifo.scala:
--------------------------------------------------------------------------------
1 |
2 | package utility
3 |
4 | import chisel3._
5 | import chisel3.util._
6 |
7 | class SyncFifo[T <: Data](dep: Int, dataType: T) extends Module { // synchronous FIFO with dep entries of dataType; dep must be a power of two >= 2
8 |   val io = IO(new Bundle {
9 |     val writeData = Input(dataType)
10 |     val writeEnable = Input(Bool()) // ignored while writeFull is high
11 |     val readEnable = Input(Bool()) // ignored while readEmpty is high
12 |     val readData = Output(dataType) // synchronous read of the entry addressed by the read pointer
13 |     val readEmpty = Output(Bool())
14 |     val writeFull = Output(Bool())
15 |   })
16 |   require(dep >= 2 && isPow2(dep), s"SyncFifo depth must be a power of two >= 2, got $dep") // pointers wrap at 2^log2Ceil(dep); other depths would address entries the memory does not have
17 |   val fifo = SyncReadMem(dep, dataType)
18 |   val readPtr = RegInit(0.U((log2Ceil(dep) + 1).W))
19 |   val writePtr = RegInit(0.U((log2Ceil(dep) + 1).W))
20 |   val readAddr = readPtr(log2Ceil(dep)-1, 0)
21 |   val writeAddr = writePtr(log2Ceil(dep)-1, 0)
22 |   // the pointers carry one extra top bit: equal pointers mean empty, equal addresses with different top bits mean full
23 |   io.readEmpty := (readPtr === writePtr)
24 |   io.writeFull := (readAddr === writeAddr) && (readPtr(log2Ceil(dep)) ^ writePtr(log2Ceil(dep)))
25 |   // the read port is always active; readEnable only advances the read pointer
26 |   io.readData := fifo.read(readAddr)
27 |   when(io.readEnable && !io.readEmpty) {
28 |     readPtr := readPtr + 1.U
29 |   }
30 |   when(io.writeEnable && !io.writeFull) {
31 |     fifo.write(writeAddr, io.writeData)
32 |     writePtr := writePtr + 1.U
33 |   }
34 | }
35 |
36 | object elaborateSyncFifo extends App { // entry point: elaborates a 32-entry FIFO of 512-bit words
37 | chisel3.Driver.execute(args, () => new SyncFifo(dep = 32, dataType = UInt(512.W)))
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/scala/utility/SyncRam.scala:
--------------------------------------------------------------------------------
1 |
2 | package utility
3 |
4 | import chisel3._
5 | import chisel3.util._
6 |
7 | // Synchronous-read, synchronous-write memory with dep words of dw bits, one read port and one write port.
8 |
9 | class SyncRam(dep: Int, dw: Int) extends Module{
10 |
11 | val io = IO(new Bundle {
12 | // control signals: read enable and write enable
13 | val re = Input(Bool())
14 | val we = Input(Bool())
15 |
16 | // data signals: read address, write address, write data, read data
17 | val ra = Input(UInt(log2Ceil(dep).W))
18 | val wa = Input(UInt(log2Ceil(dep).W))
19 | val di = Input(UInt(dw.W))
20 | val dout = Output(UInt(dw.W))
21 | })
22 |
23 | // Create a synchronous-read, synchronous-write memory (like in FPGAs).
24 | val mem = SyncReadMem(dep, UInt(dw.W))
25 | // Create one write port and one read port.
26 | when (io.we) {
27 | mem.write(io.wa, io.di)
28 | }
29 | // the read port is enabled by re
30 | io.dout := mem.read(io.ra, io.re)
31 | }
32 |
--------------------------------------------------------------------------------
/src/test/scala/VPQC/KeccakTest.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3.iotesters
5 | import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
6 |
7 | // Smoke test for KeccakCore; the input pattern is taken from the SHAKE128 example values at
8 | // https://csrc.nist.gov/projects/cryptographic-standards-and-guidelines/example-values
9 | class KeccakTest(c: KeccakCore) extends PeekPokeTester(c) {
10 |   // 25 lanes of 64 bits: the first 21 lanes carry the 0xA3 byte pattern, the remaining lanes are zero
11 |   var message = Array.tabulate[Long](1600 / 64) { lane =>
12 |     if (lane < 21) 0xA3A3A3A3A3A3A3A3L else 0L
13 |   }
14 |
15 |   // present the lanes on the seed port, row by row (index = y * 5 + x)
16 |   for (y <- 0 until 5; x <- 0 until 5) {
17 |     poke(c.io.seed.s(y)(x), message(y * 5 + x))
18 |   }
19 |
20 |   // pulse seedWrite for one cycle
21 |   poke(c.io.seedWrite, 1)
22 |   step(1)
23 |   poke(c.io.seedWrite, 0)
24 |   // pulse valid for one cycle, then wait 24 more cycles before checking done
25 |   poke(c.io.valid, 1)
26 |   step(1)
27 |   poke(c.io.valid, 0)
28 |   step(24)
29 |
30 |   expect(c.io.done, 1)
31 |
32 |   // dump the resulting state; the lanes are printed only, not compared against reference values
33 |   var res = new Array[BigInt](1600 / 64)
34 |   for (y <- 0 until 5; x <- 0 until 5) {
35 |     val lane = y * 5 + x
36 |     res(lane) = peek(c.io.prngNumber.s(y)(x))
37 |     printf("res[%d] = %x\n", lane, res(lane))
38 |   }
39 |
40 | }
41 | object KeccakTestMain extends App { // runs KeccakTest against KeccakCore with the Verilator backend
42 | iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new KeccakCore) {
43 | c => new KeccakTest(c)
44 | }
45 | }
46 |
47 |
--------------------------------------------------------------------------------
/src/test/scala/VPQC/NTTTest.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3.iotesters
5 | import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
6 | import chisel3.util._
7 |
8 | import scala.util._
9 |
10 |
11 | //class MRTest(c: MR) extends PeekPokeTester(c) {
12 | // var a = new scala.util.Random
13 | // poke(c.io.n, 14)
14 | // for (i <- 0 until 1000) {
15 | // var aIn = a.nextInt(1 << 28)
16 | // poke(c.io.a, aIn)
17 | // var m = (1 << 13) + a.nextInt(1 << 13)
18 | // poke(c.io.m, m)
19 | // poke(c.io.u, scala.math.pow(2, 29).toLong / m)
20 | // step(1)
21 | // expect(c.io.ar, aIn % m)
22 | // }
23 | //}
24 | //
25 | //object MRTestMain extends App {
26 | // iotesters.Driver.execute(Array(), () => new MR) {
27 | // c => new MRTest(c)
28 | // }
29 | //}
30 |
31 | class BTest(c: Butterfly) extends PeekPokeTester(c) {
32 |   // Randomised check of a single butterfly over 1000 trials:
33 |   // aout must equal (a + b*wn) mod m and bout must equal (a - b*wn) mod m.
34 |   var s = new scala.util.Random
35 |   poke(c.io.n, 14) // n stays at 14; every modulus drawn below lies in [2^13, 2^14)
36 |
37 |   for (_ <- 0 until 1000) {
38 |     // random modulus plus the matching constant floor(2^29 / m) for the u port
39 |     val modulus = (1 << 13) + s.nextInt(1 << 13)
40 |     poke(c.io.m, modulus)
41 |     poke(c.io.u, scala.math.pow(2, 29).toLong / modulus)
42 |
43 |     val a = s.nextInt(modulus)
44 |     val b = s.nextInt(modulus)
45 |     val wn = s.nextInt(modulus)
46 |     poke(c.io.a, a)
47 |     poke(c.io.b, b)
48 |     poke(c.io.wn, wn)
49 |
50 |     step(1)
51 |     // software reference; b * wn stays below 2^28, so Int arithmetic cannot overflow
52 |     val product = (b * wn) % modulus
53 |     val expectedDiff = if (a < product) a + modulus - product else a - product
54 |
55 |     expect(c.io.aout, (a + product) % modulus)
56 |     expect(c.io.bout, expectedDiff)
57 |   }
58 | }
59 |
60 | object BTestMain extends App { // runs BTest against Butterfly with an empty driver argument list
61 | iotesters.Driver.execute(Array(), () => new Butterfly) {
62 | c => new BTest(c)
63 | }
64 | }
65 |
66 | class PermTest(c: PermNetIn) extends PeekPokeTester(c)
67 |   with HasNTTParameters {
68 |   // Visual check of the input permutation network: for each of the four one-hot config settings,
69 |   // drive random 16-bit values on all lanes and print the in -> out mapping (nothing is asserted).
70 |   var r = scala.util.Random
71 |   (0 until 4).foreach { selected =>
72 |     // fresh random data on every input lane
73 |     (0 until ButterflyNum * 2).foreach(lane => poke(c.io.in(lane), r.nextInt(1 << 16)))
74 |     // clear all four config bits, then raise only the selected one
75 |     (0 until 4).foreach(bit => poke(c.io.config(bit), 0))
76 |     poke(c.io.config(selected), 1)
77 |     print("Config:\t" + peek(c.io.config) + "\n")
78 |     for (lane <- 0 until ButterflyNum * 2) {
79 |       val before = peek(c.io.in(lane))
80 |       val after = peek(c.io.out(lane))
81 |       print(lane + ":\t" + before + "->" + after + "\n")
82 |     }
83 |   }
84 | }
85 |
86 | object PermTestMain extends App { // runs PermTest against PermNetIn with an empty driver argument list
87 | iotesters.Driver.execute(Array(), () => new PermNetIn) {
88 | c => new PermTest(c)
89 | }
90 | }
91 |
92 | class NTTTest(c: NTT) extends PeekPokeTester(c)
93 | with HasNTTParameters {
94 | // Runs a 256-point NTT on the NTT module one butterfly stage at a time and prints the result next to two software models (values are printed, not checked with expect).
95 |
96 | var r = Random
97 | // configure csrs
98 | poke(c.io.mode, 0)
99 | poke(c.io.valid, 0)
100 |
101 | poke(c.io.csrs.csrBinomialk, log2Ceil(8))
102 | poke(c.io.csrs.csrModulusq, 12289) // modulus q
103 | poke(c.io.csrs.csrBarretu, scala.math.pow(2, 29).toLong / 12289) // floor(2^29 / q)
104 | poke(c.io.csrs.csrModulusLen, 14)
105 | poke(c.io.csrs.csrBound, scala.math.pow(2, 32).toLong / 12289 * 12289) // largest multiple of q not above 2^32
106 | poke(c.io.csrs.csrButterflyCtrl.stageCfg, 0)
107 | poke(c.io.csrs.csrButterflyCtrl.iterCfg, 0)
108 |
109 | // 1: DIF
110 | // 0: DIT
111 | var s = new Array[Int](256) // working data: 256 random 16-bit values, overwritten in place by the hardware results
112 | for (i <- 0 until 8) {
113 | for (j <- 0 until 32) {
114 | s(32*i + j) = r.nextInt() & 0xffff
115 | }
116 | }
117 |
118 | // write const ram: w256(i) = 9^i mod 12289
119 | var w256 = new Array[Int](256)
120 | for (i <- 0 until 256) {
121 | if (i == 0) {
122 | w256(i) = 1
123 | } else {
124 | w256(i) = (w256(i-1) * 9) % 12289
125 | }
126 | }
127 | var dataBuf = Array.ofDim[BigInt](256 / ButterflyNum, ButterflyNum) // one row per RAM word, ButterflyNum constants per row
128 | var idx = 0
129 | var div = 64
130 | for (i <- 0 until 256 / ButterflyNum) {
131 | for (j <- 0 until ButterflyNum) {
132 | if (i == 0) { // row 0 is written with fixed indices 0..31 (so it needs ButterflyNum >= 32); it is re-assigned on every j
133 | dataBuf(i)(0) = w256(0)
134 | dataBuf(i)(1) = w256(0)
135 | dataBuf(i)(2) = w256(0)
136 | dataBuf(i)(3) = w256(256/4)
137 | dataBuf(i)(4) = w256(0)
138 | dataBuf(i)(5) = w256(256/8)
139 | dataBuf(i)(6) = w256(256*2/8)
140 | dataBuf(i)(7) = w256(256*3/8)
141 | dataBuf(i)(8) = w256(0)
142 | dataBuf(i)(9) = w256(256/16)
143 | dataBuf(i)(10) = w256(256*2/16)
144 | dataBuf(i)(11) = w256(256*3/16)
145 | dataBuf(i)(12) = w256(256*4/16)
146 | dataBuf(i)(13) = w256(256*5/16)
147 | dataBuf(i)(14) = w256(256*6/16)
148 | dataBuf(i)(15) = w256(256*7/16)
149 | dataBuf(i)(16) = w256(0)
150 | dataBuf(i)(17) = w256(256/32)
151 | dataBuf(i)(18) = w256(256*2/32)
152 | dataBuf(i)(19) = w256(256*3/32)
153 | dataBuf(i)(20) = w256(256*4/32)
154 | dataBuf(i)(21) = w256(256*5/32)
155 | dataBuf(i)(22) = w256(256*6/32)
156 | dataBuf(i)(23) = w256(256*7/32)
157 | dataBuf(i)(24) = w256(256*8/32)
158 | dataBuf(i)(25) = w256(256*9/32)
159 | dataBuf(i)(26) = w256(256*10/32)
160 | dataBuf(i)(27) = w256(256*11/32)
161 | dataBuf(i)(28) = w256(256*12/32)
162 | dataBuf(i)(29) = w256(256*13/32)
163 | dataBuf(i)(30) = w256(256*14/32)
164 | dataBuf(i)(31) = w256(256*15/32)
165 | } else { // later rows: consecutive constants w256(256 * idx / div), with div starting at 64
166 | dataBuf(i)(j) = w256(256 * idx / div)
167 | idx = idx + 1
168 | }
169 | }
170 | if (idx == (div / 2)) { // div/2 constants emitted for this div: restart with the doubled div
171 | idx = 0
172 | div = div * 2
173 | }
174 | }
175 | // pack every row into one wide word (16 bits per constant, entry 0 in the low bits) and write it through wa/di/we
176 | var a:BigInt = 0
177 | var fill_time = 0
178 | for (i <- 0 until 256 / ButterflyNum) {
179 | poke(c.io.wa, i)
180 | a = 0
181 | for (j <- 0 until ButterflyNum) {
182 | a = a | ((dataBuf(i)(j) & 0xffff) << (16 * j))
183 | }
184 | poke(c.io.di, a)
185 | poke(c.io.we, true)
186 | step(1)
187 | fill_time += 1
188 | }
189 | poke(c.io.we, false)
190 | step(1)
191 | fill_time += 1
192 |
193 | // raw ntt: direct O(len^2) evaluation, res(i) = sum over j of x(j) * w256(i*j mod len), mod 12289
194 | var len = 256
195 | def ntt(x: Array[Int]): Array[Int] = {
196 | var res = new Array[Int](len)
197 | for (i <- 0 until len) {
198 | res(i) = 0
199 | }
200 | for (i <- 0 until len) {
201 | for (j <- 0 until len) {
202 | res(i) += (x(j) * w256((i*j) % len)) % 12289
203 | }
204 | res(i) = res(i) % 12289
205 | }
206 | res
207 | }
208 | def range(a: Int, upBound: Int, downBound: Int) : Int = { // extracts bits upBound..downBound of a
209 | assert(upBound < 32)
210 | assert(downBound >= 0)
211 | return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound))
212 | }
213 | def reverse(a: Int, len: Int): Int = { // reverses the low len bits of a
214 | var res: Int = 0
215 | for(i <- 0 until len) {
216 | res = res | range(a, i, i) << (len-1-i)
217 | }
218 | res
219 | }
220 | var sInOrder = new Array[Int](256) // s permuted into bit-reversed index order, input of the raw reference
221 | for (i <- 0 until 256) {
222 | sInOrder(reverse(i, log2Ceil(len))) = s(i)
223 | }
224 | val rawRes = ntt(sInOrder)
225 |
226 | // iterative ntt (assume bit reverse): software model that works in place on a copy of s
227 | var sit = new Array[Int](256)
228 | for (i <- 0 until 256) {
229 | sit(i) = s(i)
230 | }
231 | for(ii <- 1 to 8) {
232 | var i = 1 << ii // butterfly group size of this stage
233 | if(ButterflyNum >= i/2) { // a whole group fits inside one 2*ButterflyNum block
234 | for(j <- 0 to 255 by 2*ButterflyNum) {
235 | for(k1 <- 0 to 2*ButterflyNum-1 by i) {
236 | for(k2 <- 0 to i/2 - 1) {
237 | var aIn = sit(k1 + k2 + j)
238 | var bIn = sit(k1 + k2 + j + i / 2)
239 | var mul = (bIn * w256(256 * k2 / i)) % 12289
240 | if (aIn < mul) {
241 | sit(k1 + k2 + j + i / 2) = aIn + 12289 - mul
242 | } else {
243 | sit(k1 + k2 + j + i / 2) = aIn - mul
244 | }
245 | sit(k1 + k2 + j) = (aIn + mul) % 12289
246 | }
247 | }
248 | }
249 | }
250 | else { // larger groups: handled in chunks of ButterflyNum butterflies
251 | for(k1 <- 0 to i/2 - 1 by ButterflyNum) {
252 | for (j <- 0 to 255 by i) {
253 | for (k2 <- 0 to ButterflyNum - 1) {
254 | var aIn = sit(k1 + k2 + j)
255 | var bIn = sit(k1 + k2 + j + i / 2)
256 | var mul = (bIn * w256(256 * (k1+k2) / i)) % 12289
257 | if (aIn < mul) {
258 | sit(k1 + k2 + j + i / 2) = aIn + 12289 - mul
259 | } else {
260 | sit(k1 + k2 + j + i / 2) = aIn - mul
261 | }
262 | sit(k1 + k2 + j) = (aIn + mul) % 12289
263 | }
264 | }
265 | }
266 | }
267 | }
268 | // hardware run: per stage, set stageCfg/iterCfg with valid low for one cycle, then feed pairs of 32-element vectors and read back both halves of dataOut
269 | var ntt_time = 0
270 | poke(c.io.mode, 0)
271 | poke(c.io.valid, 1)
272 | for (stage <- 0 until 8) {
273 | var i = 2 << stage
274 | if(stage < 6) { // stages 0..5: operands are adjacent 32-element blocks, iterCfg stays 0
275 | poke(c.io.valid, 0)
276 | poke(c.io.csrs.csrButterflyCtrl.stageCfg, stage)
277 | poke(c.io.csrs.csrButterflyCtrl.iterCfg, 0)
278 | step(1)
279 | ntt_time += 1
280 | poke(c.io.valid, 1)
281 | for (iter <- 0 until 4) {
282 | for (j <- 0 until 32) {
283 | poke(c.io.vectorReg1(j), s(iter * 2 * 32 + j))
284 | }
285 | for (j <- 0 until 32) {
286 | poke(c.io.vectorReg2(j), s((iter * 2 + 1) * 32 + j))
287 | }
288 | step(1)
289 | ntt_time += 1
290 | var res = peek(c.io.dataOut)
291 | for (j <- 0 until 32) {
292 | s(iter * 2 * 32 + j) = res(j).toInt & 0xffff
293 | }
294 | for (j <- 0 until 32) {
295 | s((iter * 2 + 1) * 32 + j) = res(32 + j).toInt & 0xffff
296 | }
297 | }
298 | }
299 | else { // stages 6..7: operand blocks are i / (2 * ButterflyNum) blocks apart, iterCfg selects the chunk
300 | for (iter <- 0 until i / (2 * ButterflyNum)) {
301 | poke(c.io.valid, 0)
302 | poke(c.io.csrs.csrButterflyCtrl.stageCfg, stage)
303 | poke(c.io.csrs.csrButterflyCtrl.iterCfg, iter)
304 | step(1)
305 | ntt_time += 1
306 | poke(c.io.valid, 1)
307 | for (k <- 0 until 256 / i) {
308 | for (j <- 0 until 32) {
309 | poke(c.io.vectorReg1(j), s((2 * k * i / (2 * ButterflyNum) + iter) * 32 + j))
310 | }
311 | for (j <- 0 until 32) {
312 | poke(c.io.vectorReg2(j), s(((2 * k + 1) * i / (2 * ButterflyNum) + iter) * 32 + j ))
313 | }
314 | step(1)
315 | ntt_time += 1
316 | var res = peek(c.io.dataOut)
317 | for (j <- 0 until 32) {
318 | s((2 * k * i / (2 * ButterflyNum) + iter) * 32 + j) = res(j).toInt & 0xffff
319 | }
320 | for (j <- 0 until 32) {
321 | s(((2 * k + 1) * i / (2 * ButterflyNum) + iter) * 32 + j ) = res(32+j).toInt & 0xffff
322 | }
323 | }
324 | }
325 | }
326 | }
327 | poke(c.io.valid, 0)
328 | step(1)
329 | ntt_time += 1
330 |
331 | // print the three results side by side for manual comparison
332 | for (i <- 0 until 256) {
333 | // iterative | raw reference | hardware
334 | printf("%d: 0x%x\t%d: 0x%x\t%d: 0x%x\n", i, sit(i), i, rawRes(i), i, s(i))
335 | }
336 |
337 | // performance: cycles stepped during the hardware run (the constant-RAM fill is counted separately in fill_time)
338 | printf("NTT time(Dimension = 256): %d cycles\n", ntt_time)
339 | }
340 |
341 | object NTTTestSimple extends App { // runs NTTTest against NTT with an empty driver argument list
342 | iotesters.Driver.execute(Array(), () => new NTT) {
343 | c => new NTTTest(c)
344 | }
345 | }
346 |
347 | object NTTTestMain extends App { // runs NTTTest against NTT with the Verilator backend
348 | iotesters.Driver.execute(Array("--backend-name", "verilator"), () => new NTT) {
349 | c => new NTTTest(c)
350 | }
351 | }
352 |
353 | //class NTTTest(c:NTTR2MDC) extends PeekPokeTester(c) {
354 | //
355 | // var len = 512
356 | // def ntt(x: Array[Int]): Array[Int] = {
357 | // var wn = new Array[Int](len)
358 | // for (i <- 0 until len) {
359 | // if (i == 0){
360 | // wn(i) = 1
361 | // } else {
362 | // wn(i) = (wn(i-1) * 3) % 12289
363 | // }
364 | // }
365 | // var res = new Array[Int](len)
366 | // for (i <- 0 until len) {
367 | // res(i) = 0
368 | // }
369 | // for (i <- 0 until len) {
370 | // for (j <- 0 until len) {
371 | // res(i) += (x(j) * wn((i*j) % len)) % 12289
372 | // }
373 | // res(i) = res(i) % 12289
374 | // }
375 | // res
376 | // }
377 | //
378 | // def range(a: Int, upBound: Int, downBound: Int) : Int = {
379 | // assert(upBound < 32)
380 | // assert(downBound >= 0)
381 | // return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound))
382 | // }
383 | //
384 | // def reverse(a: Int, len: Int): Int = {
385 | // var res: Int = 0
386 | // for(i <- 0 until len) {
387 | // res = res | range(a, i, i) << (len-1-i)
388 | // }
389 | // res
390 | // }
391 | //
392 | // var l = 14
393 | // val r = new scala.util.Random
394 | // var bound: Double = math.pow(2.0, l)
395 | // var iterNum: Int = 100
396 | //
397 | // for (t <- 0 until iterNum) {
398 | // var a = new Array[Int](len)
399 | // for (i <- 0 until len) {
400 | // a(i) = r.nextInt(bound.toInt)
401 | // poke(c.io.dIn, a(i) & 0x3fff)
402 | // poke(c.io.dInValid, 1)
403 | // step(1)
404 | // }
405 | // var ref = ntt(a)
406 | //
407 | // for (i <- 0 until len / 2) {
408 | // var ref1 = ref(reverse(i * 2, log2Ceil(len)))
409 | // expect(c.io.dOut1, ref1)
410 | //
411 | // var ref2 = ref(reverse(i * 2 + 1, log2Ceil(len)))
412 | // expect(c.io.dOut2, ref2)
413 | //
414 | // expect(c.io.dOutValid, 1)
415 | //
416 | // a(reverse(i * 2, log2Ceil(len))) = peek(c.io.dOut1).toInt
417 | // a(reverse(i * 2 + 1, log2Ceil(len))) = peek(c.io.dOut2).toInt
418 | // step(1)
419 | // }
420 | //// for (i <- 0 until len) {
421 | //// print(ref(i) + "\n")
422 | //// }
423 | //// for (i <- 0 until len) {
424 | //// print(a(i) + "\n")
425 | //// }
426 | // }
427 | //}
428 | //
429 | //object NTTTestMain extends App {
430 | // iotesters.Driver.execute(Array(), () => new NTTR2MDC) {
431 | // c => new NTTTest(c)
432 | // }
433 | //}
--------------------------------------------------------------------------------
/src/test/scala/VPQC/SamplersTest.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3.iotesters
5 | import chisel3.iotesters.PeekPokeTester
6 | import chisel3.util._
7 | import scala.util._
8 |
9 | // Smoke test for Samplers: programs the sampler CSRs, then drives random
10 | // 16-bit lanes through both sampling modes. No outputs are checked.
11 | class SamplersTest(c: Samplers) extends PeekPokeTester(c){
12 |
13 |   var r = Random
14 |   var iterNum = 16
15 |   private val q = 12289
16 |
17 |   // CSR configuration
18 |   poke(c.io.csrs.csrBinomialk, log2Ceil(8))
19 |   poke(c.io.csrs.csrModulusq, q)
20 |   poke(c.io.csrs.csrBarretu, scala.math.pow(2, 29).toLong / q)
21 |   poke(c.io.csrs.csrModulusLen, 14)
22 |   poke(c.io.csrs.csrBound, scala.math.pow(2, 32).toLong / q * q)
23 |
24 |   // Selects `mode`, fills every lane of both vector inputs with a fresh
25 |   // random 16-bit value, then advances the clock by one cycle.
26 |   private def driveRandomLanes(mode: Int): Unit = {
27 |     poke(c.io.mode, mode)
28 |     (0 until 32).foreach { lane =>
29 |       poke(c.io.vectorReg1(lane), r.nextInt() & 0xffff)
30 |       poke(c.io.vectorReg2(lane), r.nextInt() & 0xffff)
31 |     }
32 |     step(1)
33 |   }
34 |
35 |   // rejection sample test (mode 0)
36 |   (0 until iterNum).foreach(_ => driveRandomLanes(0))
37 |
38 |   // binomial(gauss) sample test (mode 1)
39 |   (0 until iterNum).foreach(_ => driveRandomLanes(1))
40 | }
41 |
42 | // Entry point: runs SamplersTest on a freshly elaborated Samplers with no extra driver arguments.
43 | object SamplersTestSimple extends App {
44 |   private val driverArgs = Array.empty[String]
45 |   iotesters.Driver.execute(driverArgs, () => new Samplers)(dut => new SamplersTest(dut))
46 | }
47 |
48 | // Entry point: runs SamplersTest on a freshly elaborated Samplers, selecting the verilator backend.
49 | object SamplersTestMain extends App {
50 |   private val driverArgs = Array("--backend-name", "verilator")
51 |   iotesters.Driver.execute(driverArgs, () => new Samplers)(dut => new SamplersTest(dut))
52 | }
53 |
--------------------------------------------------------------------------------
/src/test/scala/VPQC/procTest.scala:
--------------------------------------------------------------------------------
1 |
2 | package VPQC
3 |
4 | import chisel3.iotesters
5 | import chisel3.iotesters.PeekPokeTester
6 | import chisel3.util._
7 | import scala.util._
8 |
9 | class ProcTest(c: PQCCoprocessor) extends PeekPokeTester(c)
10 | with HasNTTParameters
11 | with HasCommonParameters
12 | with HasPQCInstructions {
13 | // End-to-end test of PQCCoprocessor: writes CSRs and a seed, issues FETCHRN and SAMPLEBINOMIAL, loads the twiddle table, runs the BUTTERFLY sequence and checks the vectors read back with VST against a software NTT.
14 | /* instruction encoding style */
15 | // func7 rs2 rs1 reserved rd 00010 11
16 | def instr(f: Int, rs1: Int, rs2: Int, rd: Int): Int= { // packs the fields into one Int; NOTE(review): not called anywhere in this class
17 | (f << 25) | (rs2 << 20) | (rs1 << 15) | (rd << 7) | 11 // 11 == 0b0001011, the opcode bits shown above
18 | }
19 |
20 | // init: drive the instruction and twiddle inputs to zero/idle with in_fire low, then advance one cycle
21 | var r = Random
22 | poke(c.io.instr.rs1, 0)
23 | poke(c.io.instr.rs2, 0)
24 | poke(c.io.instr.rd, 0)
25 | poke(c.io.instr.funct, INSTR_FETCHRN) // nop
26 | poke(c.io.in_fire, 0)
27 |
28 | poke(c.io.twiddleData, 0)
29 | poke(c.io.twiddleAddr, 0)
30 | poke(c.io.twiddleWrite, false)
31 | step(1)
32 |
33 | // configure csrs: one CSRRW per cycle; instr.rs2 takes the values 0..5 (presumably the CSR index - confirm against PQCDecode) and io.rs1 carries the data
34 | poke(c.io.in_fire, 1)
35 | poke(c.io.instr.funct, INSTR_CSRRW)
36 | poke(c.io.instr.rs2, 0)
37 | poke(c.io.rs1, scala.math.pow(2, 29).toLong / 12289) // floor(2^29 / 12289); same value SamplersTest pokes into csrBarretu
38 | step(1)
39 | poke(c.io.instr.rs2, 1)
40 | poke(c.io.rs1, scala.math.pow(2, 32).toLong / 12289 * 12289) // largest multiple of 12289 not above 2^32; same value SamplersTest pokes into csrBound
41 | step(1)
42 | poke(c.io.instr.rs2, 2)
43 | poke(c.io.rs1, log2Ceil(8)) // same value SamplersTest pokes into csrBinomialk
44 | step(1)
45 | poke(c.io.instr.rs2, 3)
46 | poke(c.io.rs1, 12289) // same value SamplersTest pokes into csrModulusq
47 | step(1)
48 | poke(c.io.instr.rs2, 4)
49 | poke(c.io.rs1, 14) // same value SamplersTest pokes into csrModulusLen
50 | step(1)
51 | poke(c.io.instr.rs2, 5)
52 | poke(c.io.rs1, 0 << 6 | 0) // stage 0, iter 0 in the (stage << 6 | iter) layout used by the butterfly test below
53 | step(1)
54 | poke(c.io.in_fire, 0)
55 |
56 | // write seed: fill the 5x5 seed array with random Long values and hold seedWrite high for one cycle
57 | for (y <- 0 until 5) {
58 | for (x <- 0 until 5) {
59 | poke(c.io.seed.s(y)(x), r.nextLong()) // NOTE(review): nextLong() can be negative - confirm the backend accepts that on this port
60 | }
61 | }
62 | poke(c.io.seedWrite, 1)
63 | step(1)
64 | poke(c.io.seedWrite, 0)
65 |
66 |
67 | // wait for the buffer-filling (24 * 16 + 1 cycles, also counted in sample_time)
68 | var sample_time = 0
69 | step(24 * 16 + 1)
70 | sample_time += 24*16 + 1
71 |
72 | // test fetchrn: issue FETCHRN once per cycle with rd = 0..15, then one cycle with in_fire low
73 | poke(c.io.in_fire, 1)
74 | poke(c.io.instr.funct, INSTR_FETCHRN)
75 | poke(c.io.instr.rs1, 0)
76 | poke(c.io.instr.rs2, 0)
77 | poke(c.io.instr.rd, 0)
78 |
79 | for (i <- 0 until 16) {
80 | poke(c.io.instr.rd, i)
81 | step(1)
82 | sample_time += 1
83 | }
84 | poke(c.io.in_fire, 0)
85 | step(1)
86 | sample_time += 1
87 |
88 | // test sample: 8 SAMPLEBINOMIAL instructions with rs1 = 2i, rs2 = 2i + 1, rd = i, then one cycle with in_fire low
89 | poke(c.io.in_fire, 1)
90 | poke(c.io.instr.funct, INSTR_SAMPLEBINOMIAL)
91 |
92 | for (i <- 0 until 8) {
93 | poke(c.io.instr.rs1, 2 * i)
94 | poke(c.io.instr.rs2, 2 * i + 1)
95 | poke(c.io.instr.rd, i)
96 | step(1)
97 | sample_time += 1
98 | }
99 | poke(c.io.in_fire, 0)
100 | step(1)
101 | sample_time += 1
102 |
103 | // peek sampled value: issue VST for rs1 = 0..7 and copy the 32 vectorOut lanes into s after each step
104 | var s = new Array[Int](256) // 8 registers x 32 lanes; later overwritten in place by the software NTT
105 | poke(c.io.in_fire, 1)
106 | for (i <- 0 until 8) {
107 | poke(c.io.instr.funct, INSTR_VST)
108 | poke(c.io.instr.rs1, i)
109 | step(1)
110 | for (j <- 0 until 32) {
111 | s(32*i + j) = peek(c.io.vectorOut(j)).toInt
112 | }
113 | }
114 | poke(c.io.in_fire, 0)
115 |
116 | // write const ram: w256(i) = 9^i mod 12289; the table is packed below into rows of ButterflyNum entries
117 | var w256 = new Array[Int](256)
118 | for (i <- 0 until 256) {
119 | if (i == 0) {
120 | w256(i) = 1
121 | } else {
122 | w256(i) = (w256(i-1) * 9) % 12289
123 | }
124 | }
125 | // dataBuf: row 0 uses the explicit layout below; rows i > 0 store w256(256 * idx / div), where idx restarts and div doubles once idx reaches div / 2
126 | var dataBuf = Array.ofDim[BigInt](256 / ButterflyNum, ButterflyNum)
127 | var idx = 0
128 | var div = 64
129 | for (i <- 0 until 256 / ButterflyNum) {
130 | for (j <- 0 until ButterflyNum) {
131 | if (i == 0) { // all 32 entries of row 0 are reassigned on every j iteration (redundant); indexes 0..31 need ButterflyNum >= 32
132 | dataBuf(i)(0) = w256(0)
133 | dataBuf(i)(1) = w256(0)
134 | dataBuf(i)(2) = w256(0)
135 | dataBuf(i)(3) = w256(256/4)
136 | dataBuf(i)(4) = w256(0)
137 | dataBuf(i)(5) = w256(256/8)
138 | dataBuf(i)(6) = w256(256*2/8)
139 | dataBuf(i)(7) = w256(256*3/8)
140 | dataBuf(i)(8) = w256(0)
141 | dataBuf(i)(9) = w256(256/16)
142 | dataBuf(i)(10) = w256(256*2/16)
143 | dataBuf(i)(11) = w256(256*3/16)
144 | dataBuf(i)(12) = w256(256*4/16)
145 | dataBuf(i)(13) = w256(256*5/16)
146 | dataBuf(i)(14) = w256(256*6/16)
147 | dataBuf(i)(15) = w256(256*7/16)
148 | dataBuf(i)(16) = w256(0)
149 | dataBuf(i)(17) = w256(256/32)
150 | dataBuf(i)(18) = w256(256*2/32)
151 | dataBuf(i)(19) = w256(256*3/32)
152 | dataBuf(i)(20) = w256(256*4/32)
153 | dataBuf(i)(21) = w256(256*5/32)
154 | dataBuf(i)(22) = w256(256*6/32)
155 | dataBuf(i)(23) = w256(256*7/32)
156 | dataBuf(i)(24) = w256(256*8/32)
157 | dataBuf(i)(25) = w256(256*9/32)
158 | dataBuf(i)(26) = w256(256*10/32)
159 | dataBuf(i)(27) = w256(256*11/32)
160 | dataBuf(i)(28) = w256(256*12/32)
161 | dataBuf(i)(29) = w256(256*13/32)
162 | dataBuf(i)(30) = w256(256*14/32)
163 | dataBuf(i)(31) = w256(256*15/32)
164 | } else {
165 | dataBuf(i)(j) = w256(256 * idx / div)
166 | idx = idx + 1
167 | }
168 | }
169 | if (idx == (div / 2)) {
170 | idx = 0
171 | div = div * 2
172 | }
173 | }
174 | // pack each row into one word (entry j occupies bits 16*j .. 16*j+15) and write it at twiddleAddr = row index, one row per cycle
175 | var a:BigInt = 0
176 | for (i <- 0 until 256 / ButterflyNum) {
177 | poke(c.io.twiddleAddr, i)
178 | a = 0
179 | for (j <- 0 until ButterflyNum) {
180 | a = a | ((dataBuf(i)(j) & 0xffff) << (16 * j))
181 | }
182 | poke(c.io.twiddleData, a)
183 | poke(c.io.twiddleWrite, true)
184 | step(1)
185 | }
186 | poke(c.io.twiddleWrite, false)
187 | step(1)
188 |
189 | // raw ntt: direct O(len^2) evaluation, res(i) = (sum over j of x(j) * w256(i*j mod len) mod 12289) mod 12289; its result is only printed, never asserted
190 | var len = 256
191 | def ntt(x: Array[Int]): Array[Int] = {
192 | var res = new Array[Int](len)
193 | for (i <- 0 until len) {
194 | res(i) = 0
195 | }
196 | for (i <- 0 until len) {
197 | for (j <- 0 until len) {
198 | res(i) += (x(j) * w256((i*j) % len)) % 12289
199 | }
200 | res(i) = res(i) % 12289
201 | }
202 | res
203 | }
204 | def range(a: Int, upBound: Int, downBound: Int) : Int = { // extracts bits upBound..downBound of a
205 | assert(upBound < 32)
206 | assert(downBound >= 0)
207 | return (a >> downBound) & (0xffffffff >>> (31-upBound+downBound))
208 | }
209 | def reverse(a: Int, len: Int): Int = { // reverses the low `len` bits of a (this parameter shadows the outer var len)
210 | var res: Int = 0
211 | for(i <- 0 until len) {
212 | res = res | range(a, i, i) << (len-1-i)
213 | }
214 | res
215 | }
216 | var sInOrder = new Array[Int](256) // bit-reversal permutation of s, used as input of the raw reference
217 | for (i <- 0 until 256) {
218 | sInOrder(reverse(i, log2Ceil(len))) = s(i)
219 | }
220 | val rawRes = ntt(sInOrder)
221 |
222 | // iterative ntt (assume bit reverse): in-place software butterflies on s for group sizes i = 2, 4, ..., 256; s then holds the values expected from the hardware
223 | for(ii <- 1 to 8) {
224 | var i = 1 << ii
225 | if(ButterflyNum >= i/2) { // half a group fits in ButterflyNum: process blocks of 2*ButterflyNum elements
226 | for(j <- 0 to 255 by 2*ButterflyNum) {
227 | for(k1 <- 0 to 2*ButterflyNum-1 by i) {
228 | for(k2 <- 0 to i/2 - 1) {
229 | var aIn = s(k1 + k2 + j)
230 | var bIn = s(k1 + k2 + j + i / 2)
231 | var mul = (bIn * w256(256 * k2 / i)) % 12289
232 | if (aIn < mul) { // upper output = (aIn - mul) mod 12289, kept non-negative
233 | s(k1 + k2 + j + i / 2) = aIn + 12289 - mul
234 | } else {
235 | s(k1 + k2 + j + i / 2) = aIn - mul
236 | }
237 | s(k1 + k2 + j) = (aIn + mul) % 12289
238 | }
239 | }
240 | }
241 | }
242 | else { // half a group exceeds ButterflyNum: walk the twiddle offsets k1 in chunks of ButterflyNum
243 | for(k1 <- 0 to i/2 - 1 by ButterflyNum) {
244 | for (j <- 0 to 255 by i) {
245 | for (k2 <- 0 to ButterflyNum - 1) {
246 | var aIn = s(k1 + k2 + j)
247 | var bIn = s(k1 + k2 + j + i / 2)
248 | var mul = (bIn * w256(256 * (k1+k2) / i)) % 12289
249 | if (aIn < mul) {
250 | s(k1 + k2 + j + i / 2) = aIn + 12289 - mul
251 | } else {
252 | s(k1 + k2 + j + i / 2) = aIn - mul
253 | }
254 | s(k1 + k2 + j) = (aIn + mul) % 12289
255 | }
256 | }
257 | }
258 | }
259 | }
260 |
261 | // test butterfly: per stage, write (stage << 6 | iter) with CSRRW and rs2 = 5, then issue BUTTERFLY instructions; every step is counted in ntt_time
262 | var ntt_time = 0
263 | poke(c.io.in_fire, 1)
264 | for (stage <- 0 until 8) {
265 | var i = 2 << stage
266 | if(stage < 6) { // stages 0-5: one CSR write, then 4 BUTTERFLY ops on register pairs (2*iter, 2*iter + 1)
267 | poke(c.io.instr.funct, INSTR_CSRRW)
268 | poke(c.io.instr.rs2, 5)
269 | poke(c.io.rs1, stage << 6 | 0)
270 | step(1)
271 | ntt_time += 1
272 | poke(c.io.instr.funct, INSTR_BUTTERFLY)
273 | for (iter <- 0 until 4) {
274 | poke(c.io.instr.rs1, iter * 2)
275 | poke(c.io.instr.rs2, iter * 2 + 1)
276 | step(1)
277 | ntt_time += 1
278 | }
279 | }
280 | else { // stages 6-7: one CSR write per iter, each followed by 256 / i BUTTERFLY ops
281 | for (iter <- 0 until i / (2 * ButterflyNum)) {
282 | poke(c.io.instr.funct, INSTR_CSRRW)
283 | poke(c.io.instr.rs2, 5)
284 | poke(c.io.rs1, stage << 6 | iter)
285 | // simulate CSRRW
286 | step(1)
287 | ntt_time += 1
288 | poke(c.io.instr.funct, INSTR_BUTTERFLY)
289 | for (k <- 0 until 256 / i) {
290 | poke(c.io.instr.rs1, 2 * k * i / (2 * ButterflyNum) + iter)
291 | poke(c.io.instr.rs2, (2 * k + 1) * i / (2 * ButterflyNum) + iter)
292 | step(1)
293 | ntt_time += 1
294 | }
295 | }
296 | }
297 | }
298 | poke(c.io.in_fire, 0)
299 | step(1)
300 | ntt_time += 1
301 |
302 | // peek ntt value: VST registers 0..7, check every lane against the software result s and keep a copy in nttRes
303 | var nttRes = new Array[Int](256)
304 | poke(c.io.in_fire, 1)
305 | for (i <- 0 until 8) {
306 | poke(c.io.instr.funct, INSTR_VST)
307 | poke(c.io.instr.rs1, i)
308 | step(1)
309 | for (j <- 0 until 32) {
310 | expect(c.io.vectorOut(j), s(32 * i + j))
311 | nttRes(32 * i + j) = peek(c.io.vectorOut(j)).toInt
312 | }
313 | }
314 | poke(c.io.in_fire, 0)
315 |
316 | // print the three results side by side (no assertions here)
317 | for (i <- 0 until 256) {
318 | // iterative (software) | raw O(n^2) reference | hardware
319 | printf("%d: 0x%x\t%d: 0x%x\t%d: 0x%x\n", i, s(i), i, rawRes(i), i, nttRes(i))
320 | }
321 | // performance: cycle counts accumulated above
322 | printf("Binomial Sampling time(Dimension = 256): %d cycles\n", sample_time)
323 | printf("NTT time(Dimension = 256): %d cycles\n", ntt_time)
324 | }
325 |
326 | // Entry point: runs ProcTest on a freshly elaborated PQCCoprocessor with no extra driver arguments.
327 | object TestTopTestSimple extends App {
328 |   private val driverArgs = Array.empty[String]
329 |   iotesters.Driver.execute(driverArgs, () => new PQCCoprocessor)(dut => new ProcTest(dut))
330 | }
331 |
332 | // Entry point: runs ProcTest on a freshly elaborated PQCCoprocessor, selecting the verilator backend.
333 | object TestTopTestMain extends App {
334 |   private val driverArgs = Array("--backend-name", "verilator")
335 |   iotesters.Driver.execute(driverArgs, () => new PQCCoprocessor)(dut => new ProcTest(dut))
336 | }
337 |
--------------------------------------------------------------------------------