├── .cargo └── config.toml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md └── src ├── aligned_memory.rs ├── arith.rs ├── client.rs ├── discrete_gaussian.rs ├── gadget.rs ├── key_value.rs ├── lib.rs ├── noise_estimate.rs ├── ntt.rs ├── number_theory.rs ├── params.rs ├── poly.rs ├── server.rs └── util.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = [ 3 | "-C", 4 | "target-cpu=native", 5 | ] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "autocfg" 7 | version = "1.1.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 10 | 11 | [[package]] 12 | name = "block-buffer" 13 | version = "0.10.3" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" 16 | dependencies = [ 17 | "generic-array", 18 | ] 19 | 20 | [[package]] 21 | name = "bumpalo" 22 | version = "3.11.1" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" 25 | 26 | [[package]] 27 | name = "cfg-if" 28 | version = "1.0.0" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 31 | 32 | [[package]] 33 | name = "cpufeatures" 34 | version = "0.2.5" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" 37 | dependencies = [ 38 | "libc", 39 | ] 40 | 41 | [[package]] 42 | name = "crossbeam-channel" 43 | version = "0.5.6" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" 46 | dependencies = [ 47 | "cfg-if", 48 | "crossbeam-utils", 49 | ] 50 | 51 | [[package]] 52 | name = "crossbeam-deque" 53 | version = "0.8.2" 54 | source = "registry+https://github.com/rust-lang/crates.io-index" 55 | checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" 56 | dependencies = [ 57 | "cfg-if", 58 | "crossbeam-epoch", 59 | "crossbeam-utils", 60 | ] 61 | 62 | [[package]] 63 | name = "crossbeam-epoch" 64 | version = "0.9.13" 65 | source = "registry+https://github.com/rust-lang/crates.io-index" 66 | checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" 67 | dependencies = [ 68 | "autocfg", 69 | "cfg-if", 70 | "crossbeam-utils", 71 | "memoffset", 72 | "scopeguard", 73 | ] 74 | 75 | [[package]] 76 | name = "crossbeam-utils" 77 | version = "0.8.14" 78 | source = "registry+https://github.com/rust-lang/crates.io-index" 79 | checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" 80 | dependencies = [ 81 | "cfg-if", 82 | ] 83 | 84 | [[package]] 85 | name = "crypto-common" 86 | version = "0.1.6" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 
| checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" 89 | dependencies = [ 90 | "generic-array", 91 | "typenum", 92 | ] 93 | 94 | [[package]] 95 | name = "digest" 96 | version = "0.10.6" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" 99 | dependencies = [ 100 | "block-buffer", 101 | "crypto-common", 102 | ] 103 | 104 | [[package]] 105 | name = "either" 106 | version = "1.8.1" 107 | source = "registry+https://github.com/rust-lang/crates.io-index" 108 | checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" 109 | 110 | [[package]] 111 | name = "fastrand" 112 | version = "2.0.1" 113 | source = "registry+https://github.com/rust-lang/crates.io-index" 114 | checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" 115 | 116 | [[package]] 117 | name = "generic-array" 118 | version = "0.14.6" 119 | source = "registry+https://github.com/rust-lang/crates.io-index" 120 | checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" 121 | dependencies = [ 122 | "typenum", 123 | "version_check", 124 | ] 125 | 126 | [[package]] 127 | name = "getrandom" 128 | version = "0.2.8" 129 | source = "registry+https://github.com/rust-lang/crates.io-index" 130 | checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" 131 | dependencies = [ 132 | "cfg-if", 133 | "js-sys", 134 | "libc", 135 | "wasi", 136 | "wasm-bindgen", 137 | ] 138 | 139 | [[package]] 140 | name = "hermit-abi" 141 | version = "0.2.6" 142 | source = "registry+https://github.com/rust-lang/crates.io-index" 143 | checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" 144 | dependencies = [ 145 | "libc", 146 | ] 147 | 148 | [[package]] 149 | name = "itoa" 150 | version = "1.0.1" 151 | source = "registry+https://github.com/rust-lang/crates.io-index" 152 | checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" 153 | 154 | [[package]] 155 | name = "js-sys" 156 | version = "0.3.60" 157 | source = "registry+https://github.com/rust-lang/crates.io-index" 158 | checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" 159 | dependencies = [ 160 | "wasm-bindgen", 161 | ] 162 | 163 | [[package]] 164 | name = "libc" 165 | version = "0.2.137" 166 | source = "registry+https://github.com/rust-lang/crates.io-index" 167 | checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" 168 | 169 | [[package]] 170 | name = "log" 171 | version = "0.4.17" 172 | source = "registry+https://github.com/rust-lang/crates.io-index" 173 | checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" 174 | dependencies = [ 175 | "cfg-if", 176 | ] 177 | 178 | [[package]] 179 | name = "memoffset" 180 | version = "0.7.1" 181 | source = "registry+https://github.com/rust-lang/crates.io-index" 182 | checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" 183 | dependencies = [ 184 | "autocfg", 185 | ] 186 | 187 | [[package]] 188 | name = "num_cpus" 189 | version = "1.15.0" 190 | source = "registry+https://github.com/rust-lang/crates.io-index" 191 | checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" 192 | dependencies = [ 193 | "hermit-abi", 194 | "libc", 195 | ] 196 | 197 | [[package]] 198 | name = "once_cell" 199 | version = "1.16.0" 200 | source = "registry+https://github.com/rust-lang/crates.io-index" 201 | checksum = 
"86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" 202 | 203 | [[package]] 204 | name = "ppv-lite86" 205 | version = "0.2.16" 206 | source = "registry+https://github.com/rust-lang/crates.io-index" 207 | checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" 208 | 209 | [[package]] 210 | name = "proc-macro2" 211 | version = "1.0.47" 212 | source = "registry+https://github.com/rust-lang/crates.io-index" 213 | checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" 214 | dependencies = [ 215 | "unicode-ident", 216 | ] 217 | 218 | [[package]] 219 | name = "quote" 220 | version = "1.0.21" 221 | source = "registry+https://github.com/rust-lang/crates.io-index" 222 | checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" 223 | dependencies = [ 224 | "proc-macro2", 225 | ] 226 | 227 | [[package]] 228 | name = "rand" 229 | version = "0.8.5" 230 | source = "registry+https://github.com/rust-lang/crates.io-index" 231 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 232 | dependencies = [ 233 | "libc", 234 | "rand_chacha", 235 | "rand_core", 236 | ] 237 | 238 | [[package]] 239 | name = "rand_chacha" 240 | version = "0.3.1" 241 | source = "registry+https://github.com/rust-lang/crates.io-index" 242 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 243 | dependencies = [ 244 | "ppv-lite86", 245 | "rand_core", 246 | ] 247 | 248 | [[package]] 249 | name = "rand_core" 250 | version = "0.6.3" 251 | source = "registry+https://github.com/rust-lang/crates.io-index" 252 | checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" 253 | dependencies = [ 254 | "getrandom", 255 | ] 256 | 257 | [[package]] 258 | name = "rayon" 259 | version = "1.6.1" 260 | source = "registry+https://github.com/rust-lang/crates.io-index" 261 | checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" 262 | dependencies = [ 263 | "either", 264 | "rayon-core", 265 | ] 266 | 267 | [[package]] 268 | name = "rayon-core" 269 | version = "1.10.2" 270 | source = "registry+https://github.com/rust-lang/crates.io-index" 271 | checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b" 272 | dependencies = [ 273 | "crossbeam-channel", 274 | "crossbeam-deque", 275 | "crossbeam-utils", 276 | "num_cpus", 277 | ] 278 | 279 | [[package]] 280 | name = "ryu" 281 | version = "1.0.9" 282 | source = "registry+https://github.com/rust-lang/crates.io-index" 283 | checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" 284 | 285 | [[package]] 286 | name = "scopeguard" 287 | version = "1.1.0" 288 | source = "registry+https://github.com/rust-lang/crates.io-index" 289 | checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" 290 | 291 | [[package]] 292 | name = "serde" 293 | version = "1.0.137" 294 | source = "registry+https://github.com/rust-lang/crates.io-index" 295 | checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1" 296 | 297 | [[package]] 298 | name = "serde_json" 299 | version = "1.0.80" 300 | source = "registry+https://github.com/rust-lang/crates.io-index" 301 | checksum = "f972498cf015f7c0746cac89ebe1d6ef10c293b94175a243a2d9442c163d9944" 302 | dependencies = [ 303 | "itoa", 304 | "ryu", 305 | "serde", 306 | ] 307 | 308 | [[package]] 309 | name = "sha2" 310 | version = "0.10.6" 311 | source = "registry+https://github.com/rust-lang/crates.io-index" 312 | checksum = 
"82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" 313 | dependencies = [ 314 | "cfg-if", 315 | "cpufeatures", 316 | "digest", 317 | ] 318 | 319 | [[package]] 320 | name = "spiral-rs" 321 | version = "0.3.0" 322 | dependencies = [ 323 | "fastrand", 324 | "getrandom", 325 | "rand", 326 | "rand_chacha", 327 | "rayon", 328 | "serde_json", 329 | "sha2", 330 | "subtle", 331 | ] 332 | 333 | [[package]] 334 | name = "subtle" 335 | version = "2.4.1" 336 | source = "registry+https://github.com/rust-lang/crates.io-index" 337 | checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" 338 | 339 | [[package]] 340 | name = "syn" 341 | version = "1.0.103" 342 | source = "registry+https://github.com/rust-lang/crates.io-index" 343 | checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" 344 | dependencies = [ 345 | "proc-macro2", 346 | "quote", 347 | "unicode-ident", 348 | ] 349 | 350 | [[package]] 351 | name = "typenum" 352 | version = "1.15.0" 353 | source = "registry+https://github.com/rust-lang/crates.io-index" 354 | checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" 355 | 356 | [[package]] 357 | name = "unicode-ident" 358 | version = "1.0.5" 359 | source = "registry+https://github.com/rust-lang/crates.io-index" 360 | checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" 361 | 362 | [[package]] 363 | name = "version_check" 364 | version = "0.9.4" 365 | source = "registry+https://github.com/rust-lang/crates.io-index" 366 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 367 | 368 | [[package]] 369 | name = "wasi" 370 | version = "0.11.0+wasi-snapshot-preview1" 371 | source = "registry+https://github.com/rust-lang/crates.io-index" 372 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 373 | 374 | [[package]] 375 | name = "wasm-bindgen" 376 | version = "0.2.83" 377 | source = "registry+https://github.com/rust-lang/crates.io-index" 378 | checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" 379 | dependencies = [ 380 | "cfg-if", 381 | "wasm-bindgen-macro", 382 | ] 383 | 384 | [[package]] 385 | name = "wasm-bindgen-backend" 386 | version = "0.2.83" 387 | source = "registry+https://github.com/rust-lang/crates.io-index" 388 | checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" 389 | dependencies = [ 390 | "bumpalo", 391 | "log", 392 | "once_cell", 393 | "proc-macro2", 394 | "quote", 395 | "syn", 396 | "wasm-bindgen-shared", 397 | ] 398 | 399 | [[package]] 400 | name = "wasm-bindgen-macro" 401 | version = "0.2.83" 402 | source = "registry+https://github.com/rust-lang/crates.io-index" 403 | checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" 404 | dependencies = [ 405 | "quote", 406 | "wasm-bindgen-macro-support", 407 | ] 408 | 409 | [[package]] 410 | name = "wasm-bindgen-macro-support" 411 | version = "0.2.83" 412 | source = "registry+https://github.com/rust-lang/crates.io-index" 413 | checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" 414 | dependencies = [ 415 | "proc-macro2", 416 | "quote", 417 | "syn", 418 | "wasm-bindgen-backend", 419 | "wasm-bindgen-shared", 420 | ] 421 | 422 | [[package]] 423 | name = "wasm-bindgen-shared" 424 | version = "0.2.83" 425 | source = "registry+https://github.com/rust-lang/crates.io-index" 426 | checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" 427 | 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "spiral-rs" 3 | version = "0.3.0" 4 | edition = "2021" 5 | authors = ["Samir Menon ", "Neil Movva "] 6 | homepage = "https://blyss.dev" 7 | repository = "https://github.com/menonsamir/spiral-rs" 8 | description = "Rust implementation of the Spiral PIR scheme" 9 | keywords = ["privacy", "fhe", "cryptography"] 10 | categories = ["cryptography"] 11 | readme = "README.md" 12 | license = "MIT" 13 | 14 | [features] 15 | server = ["rayon"] 16 | 17 | [dependencies] 18 | rayon = { version = "1.6.1", optional = true } 19 | getrandom = { features = ["js"], version = "0.2.8" } 20 | rand = { version = "0.8.5", features = ["small_rng"] } 21 | serde_json = "1.0" 22 | rand_chacha = "0.3.1" 23 | sha2 = "0.10" 24 | subtle = "2.4" 25 | fastrand = "2.0.1" 26 | 27 | [profile.release-with-debug] 28 | inherits = "release" 29 | debug = true 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spiral-rs 2 | 3 | This is a Rust implementation of some functionality in the [Spiral PIR scheme](https://eprint.iacr.org/2022/368). This fork provides core routines for use in [YPIR](https://github.com/menonsamir/ypir). 4 | 5 | For a complete, working version of spiral-rs, please see [this repository](https://github.com/blyssprivacy/sdk/tree/main/lib). 6 | 7 | ## Building 8 | 9 | You must have AVX-512 support to build the main branch of this repository. For an implementation that does not require AVX-512, switch to the `avoid-avx512` branch. -------------------------------------------------------------------------------- /src/aligned_memory.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | alloc::{alloc_zeroed, dealloc, Layout}, 3 | mem::size_of, 4 | ops::{Index, IndexMut}, 5 | slice::{from_raw_parts, from_raw_parts_mut}, 6 | }; 7 | 8 | const ALIGN_SIMD: usize = 64; // enough to support AVX-512 9 | pub type AlignedMemory64 = AlignedMemory<ALIGN_SIMD>; 10 | 11 | pub struct AlignedMemory<const ALIGN: usize> { 12 | p: *mut u64, 13 | sz_u64: usize, 14 | layout: Layout, 15 | } 16 | 17 | impl<const ALIGN: usize> AlignedMemory<{ ALIGN }> { 18 | pub fn new(sz_u64: usize) -> Self { 19 | let sz_bytes = sz_u64 * size_of::<u64>(); 20 | let layout = Layout::from_size_align(sz_bytes, ALIGN).unwrap(); 21 | 22 | let ptr; 23 | unsafe { 24 | ptr = alloc_zeroed(layout); 25 | } 26 | 27 | Self { 28 | p: ptr as *mut u64, 29 | sz_u64, 30 | layout, 31 | } 32 | } 33 | 34 | // pub fn from(data: &[u8]) -> Self { 35 | // let sz_u64 = (data.len() + size_of::<u64>() - 1) / size_of::<u64>(); 36 | // let mut out = Self::new(sz_u64); 37 | // let out_slice = out.as_mut_slice(); 38 | // let mut i = 0; 39 | // for chunk in data.chunks(size_of::<u64>()) { 40 | // out_slice[i] = u64::from_ne_bytes(chunk); 41 | // i += 1; 42 | // } 43 | // out 44 | // } 45 | 46 | pub fn as_slice(&self) -> &[u64] { 47 | unsafe { from_raw_parts(self.p, self.sz_u64) } 48 | } 49 | 50 | pub fn as_mut_slice(&mut self) -> &mut [u64] { 51 | unsafe { from_raw_parts_mut(self.p, self.sz_u64) } 52 | } 53 | 54 | pub unsafe fn as_ptr(&self) -> *const u64 { 55 | self.p 56 | } 57 | 58 | pub unsafe fn as_mut_ptr(&mut self) -> *mut u64 { 59 | self.p 60 | } 61 | 62 | pub fn len(&self) -> usize { 63 | self.sz_u64 64 | } 65 | } 66 | 67 | unsafe impl<const ALIGN: usize> Send for AlignedMemory<{ ALIGN }> {} 68 | unsafe impl<const ALIGN: usize>
Sync for AlignedMemory<{ ALIGN }> {} 69 | 70 | impl<const ALIGN: usize> Drop for AlignedMemory<{ ALIGN }> { 71 | fn drop(&mut self) { 72 | unsafe { 73 | dealloc(self.p as *mut u8, self.layout); 74 | } 75 | } 76 | } 77 | 78 | impl<const ALIGN: usize> Index<usize> for AlignedMemory<{ ALIGN }> { 79 | type Output = u64; 80 | 81 | fn index(&self, index: usize) -> &Self::Output { 82 | &self.as_slice()[index] 83 | } 84 | } 85 | 86 | impl<const ALIGN: usize> IndexMut<usize> for AlignedMemory<{ ALIGN }> { 87 | fn index_mut(&mut self, index: usize) -> &mut Self::Output { 88 | &mut self.as_mut_slice()[index] 89 | } 90 | } 91 | 92 | impl<const ALIGN: usize> Clone for AlignedMemory<{ ALIGN }> { 93 | fn clone(&self) -> Self { 94 | let mut out = Self::new(self.sz_u64); 95 | out.as_mut_slice().copy_from_slice(self.as_slice()); 96 | out 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/arith.rs: -------------------------------------------------------------------------------- 1 | use crate::params::*; 2 | use std::mem; 3 | use std::slice; 4 | 5 | pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 { 6 | (((a as u128) * (b as u128)) % (modulus as u128)) as u64 7 | } 8 | 9 | pub const fn log2(a: u64) -> u64 { 10 | std::mem::size_of::<u64>() as u64 * 8 - a.leading_zeros() as u64 - 1 11 | } 12 | 13 | pub fn log2_ceil(a: u64) -> u64 { 14 | f64::ceil(f64::log2(a as f64)) as u64 15 | } 16 | 17 | pub fn log2_ceil_usize(a: usize) -> usize { 18 | f64::ceil(f64::log2(a as f64)) as usize 19 | } 20 | 21 | pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 { 22 | if params.crt_count == 1 { 23 | return multiply_uint_mod(a, b, params.moduli[c]); 24 | } 25 | barrett_coeff_u64(params, a * b, c) 26 | } 27 | 28 | pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 { 29 | if params.crt_count == 1 { 30 | return multiply_uint_mod(a, b, params.moduli[c]); 31 | } 32 | barrett_coeff_u64(params, a * b + x, c) 33 | } 34 | 35 | pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 { 36 | barrett_coeff_u64(params, a + b, c) 37 | } 38 | 39 | pub fn sub_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 { 40 | let val = a + params.moduli[c] - barrett_coeff_u64(params, b, c); 41 | barrett_coeff_u64(params, val, c) 42 | } 43 | 44 | pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 { 45 | (params.moduli[c] - a) % params.moduli[c] 46 | } 47 | 48 | pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 { 49 | barrett_coeff_u64(params, x, c) 50 | } 51 | 52 | pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 { 53 | if exponent == 0 { 54 | return 1; 55 | } 56 | 57 | if exponent == 1 { 58 | return operand; 59 | } 60 | 61 | let mut power = operand; 62 | let mut product; 63 | let mut intermediate = 1u64; 64 | 65 | loop { 66 | if (exponent % 2) == 1 { 67 | product = multiply_uint_mod(power, intermediate, modulus); 68 | mem::swap(&mut product, &mut intermediate); 69 | } 70 | exponent >>= 1; 71 | if exponent == 0 { 72 | break; 73 | } 74 | product = multiply_uint_mod(power, power, modulus); 75 | mem::swap(&mut product, &mut power); 76 | } 77 | intermediate 78 | } 79 | 80 | pub fn reverse_bits(x: u64, bit_count: usize) -> u64 { 81 | if bit_count == 0 { 82 | return 0; 83 | } 84 | 85 | let r = x.reverse_bits(); 86 | r >> (mem::size_of::<u64>() * 8 - bit_count) 87 | } 88 | 89 | pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 { 90 | if operand & 1 == 1 { 91 | let res = operand.overflowing_add(modulus); 92 | if res.1 { 93 | return (res.0 >> 1) | (1u64 <<
63); 94 | } else { 95 | return res.0 >> 1; 96 | } 97 | } else { 98 | return operand >> 1; 99 | } 100 | } 101 | 102 | pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 { 103 | assert!(from_modulus >= to_modulus); 104 | 105 | let from_modulus_i64 = from_modulus as i64; 106 | let to_modulus_i64 = to_modulus as i64; 107 | 108 | let mut a_val = val as i64; 109 | if val >= from_modulus / 2 { 110 | a_val -= from_modulus_i64; 111 | } 112 | a_val = a_val + (from_modulus_i64 / to_modulus_i64) * to_modulus_i64 + 2 * to_modulus_i64; 113 | a_val %= to_modulus_i64; 114 | a_val as u64 115 | } 116 | 117 | pub fn get_barrett_crs(modulus: u64) -> (u64, u64) { 118 | let numerator = [0, 0, 1]; 119 | let (_, quotient) = divide_uint192_inplace(numerator, modulus); 120 | 121 | (quotient[0], quotient[1]) 122 | } 123 | 124 | pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) { 125 | let mut cr0 = [0u64; MAX_MODULI]; 126 | let mut cr1 = [0u64; MAX_MODULI]; 127 | for i in 0..moduli.len() { 128 | (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]); 129 | } 130 | (cr0, cr1) 131 | } 132 | 133 | pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 { 134 | let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64; 135 | 136 | // Barrett subtraction 137 | let res = input - tmp * modulus; 138 | 139 | // One more subtraction is enough 140 | if res >= modulus { 141 | res - modulus 142 | } else { 143 | res 144 | } 145 | } 146 | 147 | pub fn barrett_u64(params: &Params, val: u64) -> u64 { 148 | barrett_raw_u64(val, params.barrett_cr_1_modulus, params.modulus) 149 | } 150 | 151 | pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 { 152 | barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n]) 153 | } 154 | 155 | fn split(x: u128) -> (u64, u64) { 156 | let lo = x & ((1u128 << 64) - 1); 157 | let hi = x >> 64; 158 | (lo as u64, hi as u64) 159 | } 160 | 161 | fn mul_u128(a: u64, b: u64) -> (u64, u64) { 162 | let prod = (a as u128) * (b as u128); 163 | split(prod) 164 | } 165 | 166 | fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 { 167 | match op1.checked_add(op2) { 168 | Some(x) => { 169 | *out = x; 170 | 0 171 | } 172 | None => 1, 173 | } 174 | } 175 | 176 | fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 { 177 | let (zx, zy) = split(val); 178 | 179 | let mut tmp1 = 0; 180 | let mut tmp3; 181 | let mut carry; 182 | let (_, prody) = mul_u128(zx, cr0); 183 | carry = prody; 184 | let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1); 185 | tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1); 186 | (tmp2x, tmp2y) = mul_u128(zy, cr0); 187 | carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1); 188 | tmp1 = zy * cr1 + tmp3 + carry; 189 | tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus)); 190 | 191 | tmp3 192 | 193 | // uint64_t zx = val & (((__uint128_t)1 << 64) - 1); 194 | // uint64_t zy = val >> 64; 195 | 196 | // uint64_t tmp1, tmp3, carry; 197 | // ulonglong2_h prod = umul64wide(zx, const_ratio_0); 198 | // carry = prod.y; 199 | // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1); 200 | // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1); 201 | // tmp2 = umul64wide(zy, const_ratio_0); 202 | // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1); 203 | // tmp1 = zy * const_ratio_1 + tmp3 + carry; 204 | // tmp3 = zx - tmp1 * modulus; 205 | 206 | // return tmp3; 207 | } 208 | 209 | pub fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 { 210 | let mut reduced_val = barrett_raw_u128(val, cr0, 
cr1, modulus); 211 | reduced_val -= (modulus) * ((reduced_val >= modulus) as u64); 212 | reduced_val 213 | } 214 | 215 | pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 { 216 | let modulus = params.modulus; 217 | let cr0 = params.barrett_cr_0_modulus; 218 | let cr1 = params.barrett_cr_1_modulus; 219 | barrett_reduction_u128_raw(modulus, cr0, cr1, val) 220 | } 221 | 222 | // Following code is ported from SEAL (github.com/microsoft/SEAL) 223 | 224 | pub fn get_significant_bit_count(val: &[u64]) -> usize { 225 | for i in (0..val.len()).rev() { 226 | for j in (0..64).rev() { 227 | if (val[i] & (1u64 << j)) != 0 { 228 | return i * 64 + j + 1; 229 | } 230 | } 231 | } 232 | 0 233 | } 234 | 235 | fn divide_round_up(num: usize, denom: usize) -> usize { 236 | (num + (denom - 1)) / denom 237 | } 238 | 239 | const BITS_PER_U64: usize = u64::BITS as usize; 240 | 241 | fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] { 242 | let mut result = [0u64; 3]; 243 | if (shift_amount & (BITS_PER_U64 << 1)) != 0 { 244 | result[2] = operand[0]; 245 | result[1] = 0; 246 | result[0] = 0; 247 | } else if (shift_amount & BITS_PER_U64) != 0 { 248 | result[2] = operand[1]; 249 | result[1] = operand[0]; 250 | result[0] = 0; 251 | } else { 252 | result[2] = operand[2]; 253 | result[1] = operand[1]; 254 | result[0] = operand[0]; 255 | } 256 | 257 | let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1); 258 | 259 | if bit_shift_amount != 0 { 260 | let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount; 261 | 262 | result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount); 263 | result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount); 264 | result[0] = result[0] << bit_shift_amount; 265 | } 266 | 267 | result 268 | } 269 | 270 | fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] { 271 | let mut result = [0u64; 3]; 272 | 273 | if (shift_amount & (BITS_PER_U64 << 1)) != 0 { 274 | result[0] = operand[2]; 275 | result[1] = 0; 276 | result[2] = 0; 277 | } else if (shift_amount & BITS_PER_U64) != 0 { 278 | result[0] = operand[1]; 279 | result[1] = operand[2]; 280 | result[2] = 0; 281 | } else { 282 | result[2] = operand[2]; 283 | result[1] = operand[1]; 284 | result[0] = operand[0]; 285 | } 286 | 287 | let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1); 288 | 289 | if bit_shift_amount != 0 { 290 | let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount; 291 | 292 | result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount); 293 | result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount); 294 | result[2] = result[2] >> bit_shift_amount; 295 | } 296 | 297 | result 298 | } 299 | 300 | fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 { 301 | *result = operand1.wrapping_add(operand2); 302 | (*result < operand1) as u8 303 | } 304 | 305 | fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 { 306 | let operand1 = operand1.wrapping_add(operand2); 307 | *result = operand1.wrapping_add(carry as u64); 308 | ((operand1 < operand2) || (!operand1 < (carry as u64))) as u8 309 | } 310 | 311 | fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 { 312 | *result = operand1.wrapping_sub(operand2); 313 | (operand2 > operand1) as u8 314 | } 315 | 316 | fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 { 317 | let diff = operand1.wrapping_sub(operand2); 318 | *result = 
diff.wrapping_sub((borrow != 0) as u64); 319 | ((diff > operand1) || (diff < (borrow as u64))) as u8 320 | } 321 | 322 | pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 { 323 | let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]); 324 | 325 | for i in 0..uint64_count - 1 { 326 | let mut temp_result = 0u64; 327 | borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result); 328 | result[1 + i] = temp_result; 329 | } 330 | 331 | borrow 332 | } 333 | 334 | pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 { 335 | let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]); 336 | 337 | for i in 0..uint64_count - 1 { 338 | let mut temp_result = 0u64; 339 | carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result); 340 | result[1 + i] = temp_result; 341 | } 342 | 343 | carry 344 | } 345 | 346 | pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) { 347 | let mut numerator_bits = get_significant_bit_count(&numerator); 348 | let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator)); 349 | 350 | let mut quotient = [0u64; 3]; 351 | 352 | if numerator_bits < denominator_bits { 353 | return (numerator, quotient); 354 | } 355 | 356 | let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64); 357 | 358 | if uint64_count == 1 { 359 | quotient[0] = numerator[0] / denominator; 360 | numerator[0] -= quotient[0] * denominator; 361 | return (numerator, quotient); 362 | } 363 | 364 | let mut shifted_denominator = [0u64; 3]; 365 | shifted_denominator[0] = denominator; 366 | 367 | let mut difference = [0u64; 3]; 368 | 369 | let denominator_shift = numerator_bits - denominator_bits; 370 | 371 | let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift); 372 | denominator_bits += denominator_shift; 373 | 374 | let mut remaining_shifts = denominator_shift; 375 | while numerator_bits == denominator_bits { 376 | if (sub_uint( 377 | &numerator, 378 | &shifted_denominator, 379 | uint64_count, 380 | &mut difference, 381 | )) != 0 382 | { 383 | if remaining_shifts == 0 { 384 | break; 385 | } 386 | 387 | add_uint( 388 | &difference.clone(), 389 | &numerator, 390 | uint64_count, 391 | &mut difference, 392 | ); 393 | 394 | quotient = left_shift_uint192(quotient, 1); 395 | remaining_shifts -= 1; 396 | } 397 | 398 | quotient[0] |= 1; 399 | 400 | numerator_bits = get_significant_bit_count(&difference); 401 | let mut numerator_shift = denominator_bits - numerator_bits; 402 | if numerator_shift > remaining_shifts { 403 | numerator_shift = remaining_shifts; 404 | } 405 | 406 | if numerator_bits > 0 { 407 | numerator = left_shift_uint192(difference, numerator_shift); 408 | numerator_bits += numerator_shift; 409 | } else { 410 | for w in 0..uint64_count { 411 | numerator[w] = 0; 412 | } 413 | } 414 | 415 | quotient = left_shift_uint192(quotient, numerator_shift); 416 | remaining_shifts -= numerator_shift; 417 | } 418 | 419 | if numerator_bits > 0 { 420 | numerator = right_shift_uint192(numerator, denominator_shift); 421 | } 422 | 423 | (numerator, quotient) 424 | } 425 | 426 | pub fn recenter_mod(val: u64, small_modulus: u64, large_modulus: u64) -> u64 { 427 | assert!(val < small_modulus); 428 | let mut val_i64 = val as i64; 429 | let small_modulus_i64 = small_modulus as i64; 430 | let large_modulus_i64 = large_modulus as i64; 431 | if val_i64 > 
small_modulus_i64 / 2 { 432 | val_i64 -= small_modulus_i64; 433 | } 434 | if val_i64 < 0 { 435 | val_i64 += large_modulus_i64; 436 | } 437 | val_i64 as u64 438 | } 439 | 440 | pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 { 441 | let inp_mod_i64 = inp_mod as i64; 442 | let out_mod_i128 = out_mod as i128; 443 | let mut inp_val = (a % inp_mod) as i64; 444 | if inp_val >= (inp_mod_i64 / 2) { 445 | inp_val -= inp_mod_i64; 446 | } 447 | let sign: i64 = if inp_val >= 0 { 1 } else { -1 }; 448 | let val = (inp_val as i128) * (out_mod as i128); 449 | let mut result = (val + (sign * (inp_mod_i64 / 2)) as i128) / (inp_mod as i128); 450 | result = (result + ((inp_mod / out_mod) * out_mod) as i128 + (2 * out_mod_i128)) % out_mod_i128; 451 | 452 | assert!(result >= 0); 453 | 454 | ((result + out_mod_i128) % out_mod_i128) as u64 455 | } 456 | 457 | #[cfg(test)] 458 | mod test { 459 | use super::*; 460 | use crate::util::get_seeded_rng; 461 | use rand::Rng; 462 | 463 | fn combine(lo: u64, hi: u64) -> u128 { 464 | (lo as u128) & ((hi as u128) << 64) 465 | } 466 | 467 | #[test] 468 | fn div2_uint_mod_correct() { 469 | assert_eq!(div2_uint_mod(3, 7), 5); 470 | } 471 | 472 | #[test] 473 | fn divide_uint192_inplace_correct() { 474 | assert_eq!( 475 | divide_uint192_inplace([35, 0, 0], 7), 476 | ([0, 0, 0], [5, 0, 0]) 477 | ); 478 | assert_eq!( 479 | divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000), 480 | ( 481 | [0x10, 0, 0], 482 | [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F] 483 | ) 484 | ); 485 | } 486 | 487 | #[test] 488 | fn get_barrett_crs_correct() { 489 | assert_eq!( 490 | get_barrett_crs(268369921u64), 491 | (16144578669088582089u64, 68736257792u64) 492 | ); 493 | assert_eq!( 494 | get_barrett_crs(249561089u64), 495 | (10966983149909726427u64, 73916747789u64) 496 | ); 497 | assert_eq!( 498 | get_barrett_crs(66974689739603969u64), 499 | (7906011006380390721u64, 275u64) 500 | ); 501 | } 502 | 503 | #[test] 504 | fn barrett_reduction_u128_raw_correct() { 505 | let modulus = 66974689739603969u64; 506 | let modulus_u128 = modulus as u128; 507 | let exec = |val| { 508 | barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val) 509 | }; 510 | assert_eq!(exec(modulus_u128), 0); 511 | assert_eq!(exec(modulus_u128 + 1), 1); 512 | assert_eq!(exec(modulus_u128 * 7 + 5), 5); 513 | 514 | let mut rng = get_seeded_rng(); 515 | for _ in 0..100 { 516 | let val = combine(rng.gen(), rng.gen()); 517 | assert_eq!(exec(val), (val % modulus_u128) as u64); 518 | } 519 | } 520 | 521 | #[test] 522 | fn barrett_raw_u64_correct() { 523 | let modulus = 66974689739603969u64; 524 | let cr1 = 275u64; 525 | 526 | let mut rng = get_seeded_rng(); 527 | for _ in 0..100 { 528 | let val = rng.gen(); 529 | assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus); 530 | } 531 | } 532 | } 533 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*, 3 | }; 4 | use rand::seq::SliceRandom; 5 | use rand::{Rng, SeedableRng}; 6 | use rand_chacha::ChaCha20Rng; 7 | use std::{iter::once, mem::size_of}; 8 | use subtle::ConditionallySelectable; 9 | use subtle::ConstantTimeEq; 10 | 11 | pub type Seed = ::Seed; 12 | pub const SEED_LENGTH: usize = 32; 13 | pub const HAMMING_WEIGHT: usize = 256; 14 | 15 | pub static mut CLIENT_TEST: Option<(PolyMatrixRaw, 
PolyMatrixRaw)> = None; 16 | 17 | pub const DEFAULT_PARAMS: &'static str = r#" 18 | {"n": 2, 19 | "nu_1": 10, 20 | "nu_2": 6, 21 | "p": 512, 22 | "q2_bits": 21, 23 | "s_e": 85.83255142749422, 24 | "t_gsw": 10, 25 | "t_conv": 4, 26 | "t_exp_left": 16, 27 | "t_exp_right": 56, 28 | "instances": 11, 29 | "db_item_size": 100000 } 30 | "#; 31 | 32 | const UUID_V4_LEN: usize = 36; 33 | 34 | fn new_vec_raw<'a>( 35 | params: &'a Params, 36 | num: usize, 37 | rows: usize, 38 | cols: usize, 39 | ) -> Vec> { 40 | let mut v = Vec::with_capacity(num); 41 | for _ in 0..num { 42 | v.push(PolyMatrixRaw::zero(params, rows, cols)); 43 | } 44 | v 45 | } 46 | 47 | fn get_inv_from_rng(params: &Params, rng: &mut ChaCha20Rng) -> u64 { 48 | params.modulus - (rng.gen::() % params.modulus) 49 | } 50 | 51 | fn mat_sz_bytes_excl_first_row(a: &PolyMatrixRaw) -> usize { 52 | (a.rows - 1) * a.cols * a.params.poly_len * size_of::() 53 | } 54 | 55 | fn serialize_polymatrix_for_rng(vec: &mut Vec, a: &PolyMatrixRaw) { 56 | let offs = a.cols * a.params.poly_len; // skip the first row 57 | for i in 0..(a.rows - 1) * a.cols * a.params.poly_len { 58 | vec.extend_from_slice(&u64::to_ne_bytes(a.data[offs + i])); 59 | } 60 | } 61 | 62 | fn serialize_vec_polymatrix_for_rng(vec: &mut Vec, a: &Vec) { 63 | for i in 0..a.len() { 64 | serialize_polymatrix_for_rng(vec, &a[i]); 65 | } 66 | } 67 | 68 | fn deserialize_polymatrix_rng(a: &mut PolyMatrixRaw, data: &[u8], rng: &mut ChaCha20Rng) -> usize { 69 | let (first_row, rest) = a 70 | .data 71 | .as_mut_slice() 72 | .split_at_mut(a.cols * a.params.poly_len); 73 | for i in 0..first_row.len() { 74 | first_row[i] = get_inv_from_rng(a.params, rng); 75 | } 76 | for (i, chunk) in data.chunks(size_of::()).enumerate() { 77 | rest[i] = u64::from_ne_bytes(chunk.try_into().unwrap()); 78 | } 79 | mat_sz_bytes_excl_first_row(a) 80 | } 81 | 82 | fn deserialize_vec_polymatrix_rng( 83 | a: &mut Vec, 84 | data: &[u8], 85 | rng: &mut ChaCha20Rng, 86 | ) -> usize { 87 | let mut chunks = data.chunks(mat_sz_bytes_excl_first_row(&a[0])); 88 | let mut bytes_read = 0; 89 | for i in 0..a.len() { 90 | bytes_read += deserialize_polymatrix_rng(&mut a[i], chunks.next().unwrap(), rng); 91 | } 92 | bytes_read 93 | } 94 | 95 | fn extract_excl_rng_data(v_buf: &[u64]) -> Vec { 96 | let mut out = Vec::new(); 97 | for i in 0..v_buf.len() { 98 | if i % 2 == 1 { 99 | out.push(v_buf[i]); 100 | } 101 | } 102 | out 103 | } 104 | 105 | fn _interleave_rng_data(params: &Params, v_buf: &[u64], rng: &mut ChaCha20Rng) -> Vec { 106 | let mut out = Vec::new(); 107 | 108 | let mut reg_cts = Vec::new(); 109 | for _ in 0..params.num_expanded() { 110 | let mut sigma = PolyMatrixRaw::zero(¶ms, 2, 1); 111 | for z in 0..params.poly_len { 112 | sigma.data[z] = get_inv_from_rng(params, rng); 113 | } 114 | reg_cts.push(sigma.ntt()); 115 | } 116 | // reorient into server's preferred indexing 117 | let reg_cts_buf_words = params.num_expanded() * 2 * params.poly_len; 118 | let mut reg_cts_buf = vec![0u64; reg_cts_buf_words]; 119 | reorient_reg_ciphertexts(params, reg_cts_buf.as_mut_slice(), ®_cts); 120 | 121 | assert_eq!(reg_cts_buf_words, 2 * v_buf.len()); 122 | 123 | for i in 0..v_buf.len() { 124 | out.push(reg_cts_buf[2 * i]); 125 | out.push(v_buf[i]); 126 | } 127 | out 128 | } 129 | 130 | fn gen_ternary_mat(mat: &mut PolyMatrixRaw, hamming: usize, rng: &mut ChaCha20Rng) { 131 | let modulus = mat.params.modulus; 132 | for r in 0..mat.rows { 133 | for c in 0..mat.cols { 134 | let pol = mat.get_poly_mut(r, c); 135 | for i in 0..hamming { 136 | 
pol[i] = 1; 137 | } 138 | for i in hamming..2 * hamming { 139 | pol[i] = modulus - 1; 140 | } 141 | pol.shuffle(rng); 142 | } 143 | } 144 | } 145 | 146 | /// The maximum number of expansion rounds supported for a single ciphertext. 147 | pub const MAX_EXP_DIM: usize = 8; 148 | 149 | #[derive(Clone)] 150 | pub struct PublicParameters<'a> { 151 | pub v_packing: Vec>, // Ws 152 | pub v_expansion_left: Option>>, 153 | pub v_expansion_right: Option>>, 154 | pub v_conversion: Option>>, // V 155 | pub seed: Option, 156 | } 157 | 158 | impl<'a> PublicParameters<'a> { 159 | pub fn init(params: &'a Params) -> Self { 160 | if params.expand_queries { 161 | PublicParameters { 162 | v_packing: Vec::new(), 163 | v_expansion_left: Some(Vec::new()), 164 | v_expansion_right: Some(Vec::new()), 165 | v_conversion: Some(Vec::new()), 166 | seed: None, 167 | } 168 | } else { 169 | PublicParameters { 170 | v_packing: Vec::new(), 171 | v_expansion_left: None, 172 | v_expansion_right: None, 173 | v_conversion: None, 174 | seed: None, 175 | } 176 | } 177 | } 178 | 179 | fn from_ntt_alloc_vec(v: &Vec>) -> Option>> { 180 | Some(v.iter().map(from_ntt_alloc).collect()) 181 | } 182 | 183 | fn from_ntt_alloc_opt_vec( 184 | v: &Option>>, 185 | ) -> Option>> { 186 | Some(v.as_ref()?.iter().map(from_ntt_alloc).collect()) 187 | } 188 | 189 | fn to_ntt_alloc_vec(v: &Vec>) -> Option>> { 190 | Some(v.iter().map(to_ntt_alloc).collect()) 191 | } 192 | 193 | pub fn to_raw(&self) -> Vec>> { 194 | vec![ 195 | Self::from_ntt_alloc_vec(&self.v_packing), 196 | Self::from_ntt_alloc_opt_vec(&self.v_expansion_left), 197 | Self::from_ntt_alloc_opt_vec(&self.v_expansion_right), 198 | Self::from_ntt_alloc_opt_vec(&self.v_conversion), 199 | ] 200 | } 201 | 202 | pub fn serialize(&self) -> Vec { 203 | let mut data = Vec::new(); 204 | if self.seed.is_some() { 205 | let seed = self.seed.as_ref().unwrap(); 206 | data.extend(seed); 207 | } 208 | for v in self.to_raw().iter() { 209 | if v.is_some() { 210 | serialize_vec_polymatrix_for_rng(&mut data, v.as_ref().unwrap()); 211 | } 212 | } 213 | data 214 | } 215 | 216 | pub fn deserialize(params: &'a Params, data: &[u8]) -> Self { 217 | assert_eq!(params.setup_bytes(), data.len()); 218 | 219 | let mut idx = 0; 220 | 221 | let seed = data[0..SEED_LENGTH].try_into().unwrap(); 222 | let mut rng = ChaCha20Rng::from_seed(seed); 223 | idx += SEED_LENGTH; 224 | 225 | let mut v_packing = new_vec_raw(params, params.n, params.n + 1, params.t_conv); 226 | idx += deserialize_vec_polymatrix_rng(&mut v_packing, &data[idx..], &mut rng); 227 | 228 | if params.expand_queries { 229 | let mut v_expansion_left = new_vec_raw(params, params.g(), 2, params.t_exp_left); 230 | idx += deserialize_vec_polymatrix_rng(&mut v_expansion_left, &data[idx..], &mut rng); 231 | 232 | let mut v_expansion_right = v_expansion_left.clone(); 233 | if params.version == 0 || params.t_exp_right != params.t_exp_left { 234 | let mut v_expansion_right_tmp = 235 | new_vec_raw(params, params.stop_round() + 1, 2, params.t_exp_right); 236 | idx += deserialize_vec_polymatrix_rng( 237 | &mut v_expansion_right_tmp, 238 | &data[idx..], 239 | &mut rng, 240 | ); 241 | v_expansion_right = v_expansion_right_tmp; 242 | } 243 | 244 | let mut v_conversion = new_vec_raw(params, 1, 2, 2 * params.t_conv); 245 | _ = deserialize_vec_polymatrix_rng(&mut v_conversion, &data[idx..], &mut rng); 246 | 247 | Self { 248 | v_packing: Self::to_ntt_alloc_vec(&v_packing).unwrap(), 249 | v_expansion_left: Self::to_ntt_alloc_vec(&v_expansion_left), 250 | v_expansion_right: 
Self::to_ntt_alloc_vec(&v_expansion_right), 251 | v_conversion: Self::to_ntt_alloc_vec(&v_conversion), 252 | seed: Some(seed), 253 | } 254 | } else { 255 | Self { 256 | v_packing: Self::to_ntt_alloc_vec(&v_packing).unwrap(), 257 | v_expansion_left: None, 258 | v_expansion_right: None, 259 | v_conversion: None, 260 | seed: Some(seed), 261 | } 262 | } 263 | } 264 | } 265 | 266 | #[derive(Clone)] 267 | pub struct Query<'a> { 268 | pub ct: Option>, 269 | pub v_buf: Option>, 270 | pub v_ct: Option>>, 271 | pub seed: Option, 272 | } 273 | 274 | impl<'a> Query<'a> { 275 | pub fn empty() -> Self { 276 | Query { 277 | ct: None, 278 | v_ct: None, 279 | v_buf: None, 280 | seed: None, 281 | } 282 | } 283 | 284 | pub fn serialize(&self) -> Vec { 285 | let mut data = Vec::new(); 286 | if self.seed.is_some() { 287 | let seed = self.seed.as_ref().unwrap(); 288 | data.extend(seed); 289 | } 290 | if self.ct.is_some() { 291 | let ct = self.ct.as_ref().unwrap(); 292 | serialize_polymatrix_for_rng(&mut data, &ct); 293 | } 294 | if self.v_buf.is_some() { 295 | let v_buf = self.v_buf.as_ref().unwrap(); 296 | let v_buf_extracted = extract_excl_rng_data(&v_buf); 297 | data.extend(v_buf_extracted.iter().map(|x| x.to_ne_bytes()).flatten()); 298 | } 299 | if self.v_ct.is_some() { 300 | let v_ct = self.v_ct.as_ref().unwrap(); 301 | for x in v_ct { 302 | serialize_polymatrix_for_rng(&mut data, x); 303 | } 304 | } 305 | data 306 | } 307 | 308 | pub fn deserialize(params: &'a Params, mut data: &[u8]) -> Self { 309 | assert_eq!(params.query_bytes(), data.len()); 310 | 311 | let mut out = Query::empty(); 312 | let seed = data[0..SEED_LENGTH].try_into().unwrap(); 313 | out.seed = Some(seed); 314 | let mut rng = ChaCha20Rng::from_seed(seed); 315 | data = &data[SEED_LENGTH..]; 316 | if params.expand_queries { 317 | let mut ct = PolyMatrixRaw::zero(params, 2, 1); 318 | deserialize_polymatrix_rng(&mut ct, data, &mut rng); 319 | out.ct = Some(ct); 320 | } else { 321 | // let v_buf_bytes = params.query_v_buf_bytes(); 322 | // let v_buf: Vec = (&data[..v_buf_bytes]) 323 | // .chunks(size_of::()) 324 | // .map(|x| u64::from_ne_bytes(x.try_into().unwrap())) 325 | // .collect(); 326 | // let v_buf_interleaved = interleave_rng_data(params, &v_buf, &mut rng); 327 | // out.v_buf = Some(v_buf_interleaved); 328 | 329 | let mut v_ct = new_vec_raw(params, params.db_dim_1, 2, 2 * params.t_gsw); 330 | deserialize_vec_polymatrix_rng(&mut v_ct, &data, &mut rng); 331 | out.v_ct = Some(v_ct); 332 | } 333 | out 334 | } 335 | } 336 | 337 | pub fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> { 338 | assert_eq!(p.cols, 1); 339 | let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1); 340 | r.copy_into(p, 0, 0); 341 | r.copy_into(&PolyMatrixRaw::identity(p.params, p.rows, p.rows), 0, 1); 342 | r 343 | } 344 | 345 | fn params_with_moduli(params: &Params, moduli: &Vec) -> Params { 346 | Params::init( 347 | params.poly_len, 348 | moduli, 349 | params.noise_width, 350 | params.n, 351 | params.pt_modulus, 352 | params.q2_bits, 353 | params.t_conv, 354 | params.t_exp_left, 355 | params.t_exp_right, 356 | params.t_gsw, 357 | params.expand_queries, 358 | params.db_dim_1, 359 | params.db_dim_2, 360 | params.instances, 361 | params.db_item_size, 362 | params.version, 363 | ) 364 | } 365 | 366 | pub struct Client<'a> { 367 | params: &'a Params, 368 | sk_gsw: PolyMatrixRaw<'a>, 369 | sk_reg: PolyMatrixRaw<'a>, 370 | sk_gsw_full: PolyMatrixRaw<'a>, 371 | sk_reg_full: PolyMatrixRaw<'a>, 372 | dg: DiscreteGaussian, 373 | } 374 | 375 | 
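// Usage sketch for the `Client` API defined below (illustrative only; the
// target index and response ciphertext count are placeholders, and
// `get_short_keygen_params` comes from `util`, as in the tests at the bottom
// of this file):
//
//     let params = get_short_keygen_params();
//     let mut client = Client::init(&params);
//     let pub_params_buf = client.generate_keys().serialize(); // uploaded once
//     let query_buf = client.generate_query(7).serialize();    // sent per query
//     // ...the server answers with `response` bytes...
//     // let plaintext = client.decode_response(&response, 1);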
impl<'a> Client<'a> { 376 | pub fn init(params: &'a Params) -> Self { 377 | let sk_gsw_dims = params.get_sk_gsw(); 378 | let sk_reg_dims = params.get_sk_reg(); 379 | let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1); 380 | let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1); 381 | let sk_gsw_full = matrix_with_identity(&sk_gsw); 382 | let sk_reg_full = matrix_with_identity(&sk_reg); 383 | 384 | let dg = DiscreteGaussian::init(params.noise_width); 385 | 386 | Self { 387 | params, 388 | sk_gsw, 389 | sk_reg, 390 | sk_gsw_full, 391 | sk_reg_full, 392 | dg, 393 | } 394 | } 395 | 396 | pub fn get_sk_reg(&self) -> &PolyMatrixRaw<'a> { 397 | &self.sk_reg 398 | } 399 | 400 | pub fn get_sk_gsw(&self) -> &PolyMatrixRaw<'a> { 401 | &self.sk_gsw 402 | } 403 | 404 | pub fn get_sk_gsw_full(&self) -> &PolyMatrixRaw<'a> { 405 | &self.sk_gsw_full 406 | } 407 | 408 | fn get_fresh_gsw_public_key( 409 | &self, 410 | m: usize, 411 | rng: &mut ChaCha20Rng, 412 | rng_pub: &mut ChaCha20Rng, 413 | ) -> PolyMatrixRaw<'a> { 414 | let params = self.params; 415 | let n = params.n; 416 | 417 | let a = PolyMatrixRaw::random_rng(params, 1, m, rng_pub); 418 | let e = PolyMatrixRaw::noise(params, n, m, &self.dg, rng); 419 | let a_inv = -&a; 420 | let b_p = &self.sk_gsw.ntt() * &a.ntt(); 421 | let b = &e.ntt() + &b_p; 422 | let p = stack(&a_inv, &b.raw()); 423 | p 424 | } 425 | 426 | fn get_regev_sample( 427 | &self, 428 | rng: &mut ChaCha20Rng, 429 | rng_pub: &mut ChaCha20Rng, 430 | ) -> PolyMatrixNTT<'a> { 431 | let params = self.params; 432 | let a = PolyMatrixRaw::random_rng(params, 1, 1, rng_pub); 433 | let e = PolyMatrixRaw::fast_noise(params, 1, 1, &self.dg, rng); 434 | let b_p = &self.sk_reg.ntt() * &a.ntt(); 435 | let b = &e.ntt() + &b_p; 436 | let mut p = PolyMatrixNTT::zero(params, 2, 1); 437 | p.copy_into(&(-&a).ntt(), 0, 0); 438 | p.copy_into(&b, 1, 0); 439 | p 440 | } 441 | 442 | fn get_scaled_regev_sample( 443 | &self, 444 | rng: &mut ChaCha20Rng, 445 | rng_pub: &mut ChaCha20Rng, 446 | scale: u64, 447 | ) -> PolyMatrixNTT<'a> { 448 | let params = self.params; 449 | let a = PolyMatrixRaw::random_rng(params, 1, 1, rng_pub); 450 | let mut e = PolyMatrixRaw::fast_noise(params, 1, 1, &self.dg, rng); 451 | for i in 0..params.poly_len { 452 | e.data[i] = multiply_uint_mod(e.data[i], scale, params.modulus); 453 | } 454 | let b_p = &self.sk_reg.ntt() * &a.ntt(); 455 | let b = &e.ntt() + &b_p; 456 | let mut p = PolyMatrixNTT::zero(params, 2, 1); 457 | p.copy_into(&(-&a).ntt(), 0, 0); 458 | p.copy_into(&b, 1, 0); 459 | p 460 | } 461 | 462 | fn get_fresh_reg_public_key( 463 | &self, 464 | m: usize, 465 | rng: &mut ChaCha20Rng, 466 | rng_pub: &mut ChaCha20Rng, 467 | ) -> PolyMatrixNTT<'a> { 468 | let params = self.params; 469 | 470 | let mut p = PolyMatrixNTT::zero(params, 2, m); 471 | 472 | for i in 0..m { 473 | p.copy_into(&self.get_regev_sample(rng, rng_pub), 0, i); 474 | } 475 | p 476 | } 477 | 478 | fn get_fresh_scaled_reg_public_key( 479 | &self, 480 | m: usize, 481 | rng: &mut ChaCha20Rng, 482 | rng_pub: &mut ChaCha20Rng, 483 | scale: u64, 484 | ) -> PolyMatrixNTT<'a> { 485 | let params = self.params; 486 | 487 | let mut p = PolyMatrixNTT::zero(params, 2, m); 488 | 489 | for i in 0..m { 490 | p.copy_into(&self.get_scaled_regev_sample(rng, rng_pub, scale), 0, i); 491 | } 492 | p 493 | } 494 | 495 | pub fn encrypt_matrix_gsw( 496 | &self, 497 | ag: &PolyMatrixNTT<'a>, 498 | rng: &mut ChaCha20Rng, 499 | rng_pub: &mut ChaCha20Rng, 500 | ) -> PolyMatrixNTT<'a> { 501 | let mx = 
ag.cols; 502 | let p = self.get_fresh_gsw_public_key(mx, rng, rng_pub); 503 | let res = &(p.ntt()) + &(ag.pad_top(1)); 504 | res 505 | } 506 | 507 | pub fn encrypt_matrix_reg( 508 | &self, 509 | a: &PolyMatrixNTT<'a>, 510 | rng: &mut ChaCha20Rng, 511 | rng_pub: &mut ChaCha20Rng, 512 | ) -> PolyMatrixNTT<'a> { 513 | let m = a.cols; 514 | let p = self.get_fresh_reg_public_key(m, rng, rng_pub); 515 | &p + &a.pad_top(1) 516 | } 517 | 518 | pub fn encrypt_matrix_scaled_reg( 519 | &self, 520 | a: &PolyMatrixNTT<'a>, 521 | rng: &mut ChaCha20Rng, 522 | rng_pub: &mut ChaCha20Rng, 523 | scale: u64, 524 | ) -> PolyMatrixNTT<'a> { 525 | let m = a.cols; 526 | let p = self.get_fresh_scaled_reg_public_key(m, rng, rng_pub, scale); 527 | &p + &a.pad_top(1) 528 | } 529 | 530 | pub fn decrypt_matrix_reg(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> { 531 | &self.sk_reg_full.ntt() * a 532 | } 533 | 534 | pub fn decrypt_matrix_gsw(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> { 535 | &self.sk_gsw_full.ntt() * a 536 | } 537 | 538 | pub fn generate_expansion_params( 539 | &self, 540 | num_exp: usize, 541 | m_exp: usize, 542 | rng: &mut ChaCha20Rng, 543 | rng_pub: &mut ChaCha20Rng, 544 | ) -> Vec> { 545 | let params = self.params; 546 | let g_exp = build_gadget(params, 1, m_exp); 547 | let g_exp_ntt = g_exp.ntt(); 548 | let mut res = Vec::new(); 549 | 550 | for i in 0..num_exp { 551 | let t = (params.poly_len / (1 << i)) + 1; 552 | let tau_sk_reg = automorph_alloc(&self.sk_reg, t); 553 | let prod = &tau_sk_reg.ntt() * &g_exp_ntt; 554 | let w_exp_i = self.encrypt_matrix_reg(&prod, rng, rng_pub); 555 | res.push(w_exp_i); 556 | } 557 | res 558 | } 559 | 560 | pub fn negacyclic_sk_reg(&self) -> PolyMatrixRaw<'a> { 561 | let mut sk_reg = self.get_sk_reg().clone(); 562 | 563 | (&mut sk_reg.as_mut_slice()[1..]).reverse(); 564 | for z in 1..sk_reg.data.len() { 565 | sk_reg.data[z] = self.params.modulus - sk_reg.data[z]; 566 | } 567 | 568 | sk_reg 569 | } 570 | 571 | pub fn generate_keys_from_seed(&mut self, seed: Seed) -> PublicParameters<'a> { 572 | self.generate_keys_impl(&mut ChaCha20Rng::from_seed(seed)) 573 | } 574 | 575 | pub fn generate_keys(&mut self) -> PublicParameters<'a> { 576 | self.generate_keys_impl(&mut ChaCha20Rng::from_entropy()) 577 | } 578 | 579 | pub fn generate_secret_keys_from_seed(&mut self, seed: Seed) { 580 | self.generate_secret_keys_impl(&mut ChaCha20Rng::from_seed(seed)) 581 | } 582 | 583 | pub fn generate_secret_keys(&mut self) { 584 | self.generate_secret_keys_impl(&mut ChaCha20Rng::from_entropy()) 585 | } 586 | 587 | pub fn generate_keys_optional( 588 | &mut self, 589 | seed: Seed, 590 | generate_pub_params: bool, 591 | ) -> Option> { 592 | if generate_pub_params { 593 | Some(self.generate_keys_from_seed(seed).serialize()) 594 | } else { 595 | self.generate_secret_keys_from_seed(seed); 596 | None 597 | } 598 | } 599 | 600 | fn generate_secret_keys_impl(&mut self, rng: &mut ChaCha20Rng) { 601 | gen_ternary_mat(&mut self.sk_gsw, HAMMING_WEIGHT, rng); 602 | gen_ternary_mat(&mut self.sk_reg, HAMMING_WEIGHT, rng); 603 | self.sk_gsw_full = matrix_with_identity(&self.sk_gsw); 604 | self.sk_reg_full = matrix_with_identity(&self.sk_reg); 605 | } 606 | 607 | fn generate_keys_impl(&mut self, rng: &mut ChaCha20Rng) -> PublicParameters<'a> { 608 | let params = self.params; 609 | 610 | self.generate_secret_keys_impl(rng); 611 | let sk_reg_ntt = to_ntt_alloc(&self.sk_reg); 612 | let sk_gsw_ntt = to_ntt_alloc(&self.sk_gsw); 613 | 614 | let mut rng = ChaCha20Rng::from_entropy(); 615 | let mut pp 
= PublicParameters::init(params); 616 | let pp_seed = rng.gen(); 617 | pp.seed = Some(pp_seed); 618 | let mut rng_pub = ChaCha20Rng::from_seed(pp_seed); 619 | 620 | // Params for packing 621 | let gadget_conv = build_gadget(params, 1, params.t_conv); 622 | let gadget_conv_ntt = to_ntt_alloc(&gadget_conv); 623 | let num_packing_mats = if params.version == 0 { params.n } else { 1 }; 624 | for i in 0..num_packing_mats { 625 | let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt); 626 | let mut ag = PolyMatrixNTT::zero(params, params.n, params.t_conv); 627 | ag.copy_into(&scaled, i, 0); 628 | let w = self.encrypt_matrix_gsw(&ag, &mut rng, &mut rng_pub); 629 | pp.v_packing.push(w); 630 | } 631 | 632 | if params.version > 0 { 633 | let scaled = &sk_gsw_ntt * &gadget_conv_ntt; 634 | let scaled_rotated = shift_rows_by_one(&scaled); 635 | let w = self.encrypt_matrix_gsw(&scaled_rotated, &mut rng, &mut rng_pub); 636 | pp.v_packing.push(w); 637 | } 638 | 639 | if params.expand_queries { 640 | // Params for expansion 641 | pp.v_expansion_left = Some(self.generate_expansion_params( 642 | params.g().min(MAX_EXP_DIM), 643 | params.t_exp_left, 644 | &mut rng, 645 | &mut rng_pub, 646 | )); 647 | 648 | if params.version == 0 || params.t_exp_right != params.t_exp_left { 649 | pp.v_expansion_right = Some(self.generate_expansion_params( 650 | params.stop_round() + 1, 651 | params.t_exp_right, 652 | &mut rng, 653 | &mut rng_pub, 654 | )); 655 | } else { 656 | pp.v_expansion_right = None; 657 | } 658 | 659 | // Params for converison 660 | let g_conv = build_gadget(params, 2, 2 * params.t_conv); 661 | let sk_reg_ntt = self.sk_reg.ntt(); 662 | let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt; 663 | pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero( 664 | params, 665 | 2, 666 | 2 * params.t_conv, 667 | )))); 668 | for i in 0..2 * params.t_conv { 669 | let sigma; 670 | if i % 2 == 0 { 671 | let val = g_conv.get_poly(0, i)[0]; 672 | sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt(); 673 | } else { 674 | let val = g_conv.get_poly(1, i)[0]; 675 | sigma = &sk_reg_ntt * &single_poly(params, val).ntt(); 676 | } 677 | let ct = self.encrypt_matrix_reg(&sigma, &mut rng, &mut rng_pub); 678 | pp.v_conversion.as_mut().unwrap()[0].copy_into(&ct, 0, i); 679 | } 680 | } 681 | 682 | pp 683 | } 684 | 685 | pub fn generate_query(&self, idx_target: usize) -> Query<'a> { 686 | let params = self.params; 687 | let further_dims = params.db_dim_2; 688 | let idx_dim0 = idx_target / (1 << further_dims); 689 | let idx_further = idx_target % (1 << further_dims); 690 | let scale_k = params.modulus / params.pt_modulus; 691 | let bits_per = get_bits_per(params, params.t_gsw); 692 | 693 | let mut rng = ChaCha20Rng::from_entropy(); 694 | 695 | let mut query = Query::empty(); 696 | let query_seed = ChaCha20Rng::from_entropy().gen(); 697 | query.seed = Some(query_seed); 698 | let mut rng_pub = ChaCha20Rng::from_seed(query_seed); 699 | if params.expand_queries { 700 | // pack query into (possibly several) ciphertexts 701 | 702 | let num_query_cts = if params.db_dim_1 > MAX_EXP_DIM { 703 | 1 << (params.db_dim_1 - MAX_EXP_DIM) 704 | } else { 705 | 1 706 | }; 707 | 708 | println!("num_query_cts: {}", num_query_cts); 709 | 710 | let mut cts = Vec::new(); 711 | for idx_query in 0..num_query_cts { 712 | let num_exps = params.db_dim_1.min(MAX_EXP_DIM); 713 | 714 | let innner_idx_dim0 = if idx_query == (idx_dim0 / (1 << num_exps)) { 715 | (idx_dim0 % (1 << num_exps)) as u64 716 | } else { 717 | u64::MAX 718 | }; 719 | 
720 | let mut sigma = PolyMatrixRaw::zero(params, 1, 1); 721 | let inv_2_g_first = invert_uint_mod(1 << num_exps, params.modulus).unwrap(); 722 | let inv_2_g_rest = 723 | invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap(); 724 | 725 | if params.db_dim_2 == 0 || params.version == 3 { 726 | for i in 0..(1 << num_exps) { 727 | sigma.data[i].conditional_assign( 728 | &scale_k, 729 | (i as u64).ct_eq(&(innner_idx_dim0 as u64)), 730 | ) 731 | } 732 | 733 | for i in 0..params.poly_len { 734 | sigma.data[i] = 735 | multiply_uint_mod(sigma.data[i], inv_2_g_first, params.modulus); 736 | } 737 | } else { 738 | for i in 0..(1 << params.db_dim_1) { 739 | sigma.data[2 * i].conditional_assign( 740 | &scale_k, 741 | (i as u64).ct_eq(&(innner_idx_dim0 as u64)), 742 | ) 743 | } 744 | 745 | for i in 0..further_dims as u64 { 746 | let mask = 1 << i; 747 | let bit = ((idx_further as u64) & mask).ct_eq(&mask); 748 | for j in 0..params.t_gsw { 749 | let val = u64::conditional_select(&0, &(1u64 << (bits_per * j)), bit); 750 | let idx = (i as usize) * params.t_gsw + (j as usize); 751 | sigma.data[2 * idx + 1] = val; 752 | } 753 | } 754 | 755 | for i in 0..params.poly_len / 2 { 756 | sigma.data[2 * i] = 757 | multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus); 758 | sigma.data[2 * i + 1] = 759 | multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus); 760 | } 761 | } 762 | 763 | cts.push( 764 | self.encrypt_matrix_reg(&sigma.ntt(), &mut rng, &mut rng_pub) 765 | .raw(), 766 | ); 767 | } 768 | query.v_ct = Some(cts); 769 | } else { 770 | // Upload only GSW ciphertexts 771 | assert_eq!(further_dims, 0); 772 | 773 | let mut sigma_v = Vec::::new(); 774 | 775 | // generate GSW ciphertexts 776 | for i in 0..params.db_dim_1 { 777 | let rev_i = (params.db_dim_1 as u64) - 1 - (i as u64); 778 | let bit = ((idx_dim0 as u64) & (1 << rev_i)) >> rev_i; 779 | let mut ct_gsw = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw); 780 | 781 | for j in 0..params.t_gsw { 782 | let value = (1u64 << (bits_per * j)) * bit; 783 | let sigma = PolyMatrixRaw::single_value(¶ms, value); 784 | let sigma_ntt = sigma.ntt(); 785 | let ct = self.encrypt_matrix_reg( 786 | &sigma_ntt, 787 | &mut ChaCha20Rng::from_entropy(), 788 | &mut ChaCha20Rng::from_entropy(), 789 | ); 790 | ct_gsw.copy_into(&ct, 0, 2 * j + 1); 791 | let prod = &self.get_sk_reg().ntt() * &sigma_ntt; 792 | let ct = &self.encrypt_matrix_reg( 793 | &prod, 794 | &mut ChaCha20Rng::from_entropy(), 795 | &mut ChaCha20Rng::from_entropy(), 796 | ); 797 | ct_gsw.copy_into(&ct, 0, 2 * j); 798 | } 799 | 800 | sigma_v.push(ct_gsw); 801 | 802 | // for j in 0..params.t_gsw { 803 | // let value = (1u64 << (bits_per * j)) * bit; 804 | // let sigma = PolyMatrixRaw::single_value(¶ms, value); 805 | // let sigma_ntt = to_ntt_alloc(&sigma); 806 | 807 | // // important to rng in the right order here 808 | // let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt; 809 | // let ct = &self.encrypt_matrix_reg(&prod, &mut rng, &mut rng_pub); 810 | // ct_gsw.copy_into(ct, 0, 2 * j); 811 | 812 | // let ct = &self.encrypt_matrix_reg(&sigma_ntt, &mut rng, &mut rng_pub); 813 | // ct_gsw.copy_into(ct, 0, 2 * j + 1); 814 | // } 815 | // sigma_v.push(ct_gsw); 816 | } 817 | 818 | query.v_ct = Some(sigma_v.iter().map(|x| from_ntt_alloc(x)).collect()); 819 | 820 | assert_eq!(query.v_ct.as_ref().unwrap().len(), params.db_dim_1); 821 | } 822 | query 823 | } 824 | 825 | pub fn generate_full_query(&self, id: &str, idx_target: usize) -> Vec { 826 | assert_eq!(id.len(), 
UUID_V4_LEN); 827 | let query = self.generate_query(idx_target); 828 | let mut query_buf = query.serialize(); 829 | let mut full_query_buf = id.as_bytes().to_vec(); 830 | full_query_buf.append(&mut query_buf); 831 | full_query_buf 832 | } 833 | 834 | pub fn decode_response(&self, data: &[u8], num_cts: usize) -> Vec { 835 | /* 836 | 0. NTT over q2 the secret key 837 | 838 | 1. read first row in q2_bit chunks 839 | 2. read rest in q1_bit chunks 840 | 3. NTT over q2 the first row 841 | 4. Multiply the results of (0) and (3) 842 | 5. Divide and round correctly 843 | */ 844 | let params = self.params; 845 | let p = params.pt_modulus; 846 | let p_bits = log2_ceil(params.pt_modulus); 847 | let q1 = 4 * params.pt_modulus; 848 | let q1_bits = log2_ceil(q1) as usize; 849 | let q2 = Q2_VALUES[params.q2_bits as usize]; 850 | let q2_bits = params.q2_bits as usize; 851 | 852 | let q2_params = params_with_moduli(params, &vec![q2]); 853 | 854 | // this only needs to be done during keygen 855 | let mut sk_q2 = PolyMatrixRaw::zero(&q2_params, params.n, 1); 856 | let key_to_use = if params.db_dim_2 > 0 && params.version != 3 { 857 | &self.sk_gsw 858 | } else { 859 | &self.sk_reg 860 | }; 861 | for i in 0..params.poly_len * params.n { 862 | sk_q2.data[i] = recenter(key_to_use.data[i], params.modulus, q2); 863 | } 864 | let mut sk_q2_ntt = PolyMatrixNTT::zero(&q2_params, params.n, 1); 865 | to_ntt(&mut sk_q2_ntt, &sk_q2); 866 | 867 | let mut result = PolyMatrixRaw::zero(¶ms, num_cts * params.n, params.n); 868 | 869 | let mut bit_offs = 0; 870 | for instance in 0..num_cts { 871 | // this must be done during decoding 872 | let mut first_row = PolyMatrixRaw::zero(&q2_params, 1, params.n); 873 | let mut rest_rows = PolyMatrixRaw::zero(¶ms, params.n, params.n); 874 | for i in 0..params.n * params.poly_len { 875 | first_row.data[i] = read_arbitrary_bits(data, bit_offs, q2_bits); 876 | bit_offs += q2_bits; 877 | } 878 | for i in 0..params.n * params.n * params.poly_len { 879 | rest_rows.data[i] = read_arbitrary_bits(data, bit_offs, q1_bits); 880 | bit_offs += q1_bits; 881 | } 882 | 883 | let mut first_row_q2 = PolyMatrixNTT::zero(&q2_params, 1, params.n); 884 | to_ntt(&mut first_row_q2, &first_row); 885 | 886 | let sk_prod = (&sk_q2_ntt * &first_row_q2).raw(); 887 | 888 | let q1_i64 = q1 as i64; 889 | let q2_i64 = q2 as i64; 890 | let p_i128 = p as i128; 891 | for i in 0..params.n * params.n * params.poly_len { 892 | let mut val_first = sk_prod.data[i] as i64; 893 | if val_first >= q2_i64 / 2 { 894 | val_first -= q2_i64; 895 | } 896 | let mut val_rest = rest_rows.data[i] as i64; 897 | if val_rest >= q1_i64 / 2 { 898 | val_rest -= q1_i64; 899 | } 900 | 901 | let denom = (q2 * (q1 / p)) as i64; 902 | 903 | let mut r = val_first * q1_i64; 904 | r += val_rest * q2_i64; 905 | 906 | // divide r by q2, rounding 907 | let sign: i64 = if r >= 0 { 1 } else { -1 }; 908 | let mut res = ((r + sign * (denom / 2)) as i128) / (denom as i128); 909 | res = (res + (denom as i128 / p_i128) * (p_i128) + 2 * (p_i128)) % (p_i128); 910 | let idx = instance * params.n * params.n * params.poly_len + i; 911 | result.data[idx] = res as u64; 912 | } 913 | } 914 | 915 | // println!("{:?}", result.data.as_slice().to_vec()); 916 | // result.to_vec(p_bits as usize, params.modp_words_per_chunk()) 917 | result.to_vec(p_bits as usize, params.poly_len) 918 | } 919 | } 920 | 921 | #[cfg(test)] 922 | mod test { 923 | use super::*; 924 | 925 | fn get_params() -> Params { 926 | get_short_keygen_params() 927 | } 928 | 929 | #[test] 930 | fn init_is_correct() 
{ 931 | let params = get_params(); 932 | let client = Client::init(¶ms); 933 | 934 | assert_eq!(*client.params, params); 935 | } 936 | 937 | #[test] 938 | fn keygen_is_correct() { 939 | let params = get_params(); 940 | let mut client = Client::init(¶ms); 941 | 942 | _ = client.generate_keys(); 943 | 944 | let threshold = (10.0 * params.noise_width) as u64; 945 | 946 | for i in 0..client.sk_gsw.data.len() { 947 | let val = client.sk_gsw.data[i]; 948 | assert!((val < threshold) || ((params.modulus - val) < threshold)); 949 | } 950 | } 951 | 952 | fn get_vec(v: &Vec) -> Vec { 953 | v.iter().map(|d| d.as_slice().to_vec()).flatten().collect() 954 | } 955 | 956 | fn public_parameters_serialization_is_correct_for_params(params: Params) { 957 | let mut client = Client::init(¶ms); 958 | let pub_params = client.generate_keys(); 959 | 960 | let serialized1 = pub_params.serialize(); 961 | let deserialized1 = PublicParameters::deserialize(¶ms, &serialized1); 962 | 963 | assert_eq!( 964 | get_vec(&pub_params.v_packing), 965 | get_vec(&deserialized1.v_packing) 966 | ); 967 | 968 | println!( 969 | "packing mats (bytes) {}", 970 | get_vec(&pub_params.v_packing).len() * 8 971 | ); 972 | println!("total size (bytes) {}", serialized1.len()); 973 | if pub_params.v_conversion.is_some() { 974 | let l1 = get_vec(&pub_params.v_conversion.unwrap()); 975 | assert_eq!(l1, get_vec(&deserialized1.v_conversion.unwrap())); 976 | println!("conv mats (bytes) {}", l1.len() * 8); 977 | } 978 | if pub_params.v_expansion_left.is_some() { 979 | let l1 = get_vec(&pub_params.v_expansion_left.unwrap()); 980 | assert_eq!(l1, get_vec(&deserialized1.v_expansion_left.unwrap())); 981 | println!("exp left (bytes) {}", l1.len() * 8); 982 | } 983 | if pub_params.v_expansion_right.is_some() { 984 | let l1 = get_vec(&pub_params.v_expansion_right.unwrap()); 985 | assert_eq!(l1, get_vec(&deserialized1.v_expansion_right.unwrap())); 986 | println!("exp right (bytes) {}", l1.len() * 8); 987 | } 988 | } 989 | 990 | #[test] 991 | fn public_parameters_serialization_is_correct() { 992 | public_parameters_serialization_is_correct_for_params(get_params()) 993 | } 994 | 995 | #[test] 996 | fn real_public_parameters_serialization_is_correct() { 997 | let cfg_expand = r#" 998 | {'n': 2, 999 | 'nu_1': 10, 1000 | 'nu_2': 6, 1001 | 'p': 512, 1002 | 'q2_bits': 21, 1003 | 's_e': 85.83255142749422, 1004 | 't_gsw': 10, 1005 | 't_conv': 4, 1006 | 't_exp_left': 16, 1007 | 't_exp_right': 56, 1008 | 'instances': 11, 1009 | 'db_item_size': 100000 } 1010 | "#; 1011 | let cfg = cfg_expand.replace("'", "\""); 1012 | let params = params_from_json(&cfg); 1013 | public_parameters_serialization_is_correct_for_params(params) 1014 | } 1015 | 1016 | #[test] 1017 | fn real_public_parameters_2_serialization_is_correct() { 1018 | let cfg = r#" 1019 | { "n": 4, 1020 | "nu_1": 9, 1021 | "nu_2": 5, 1022 | "p": 256, 1023 | "q2_bits": 20, 1024 | "t_gsw": 8, 1025 | "t_conv": 4, 1026 | "t_exp_left": 8, 1027 | "t_exp_right": 56, 1028 | "instances": 2, 1029 | "db_item_size": 65536 } 1030 | "#; 1031 | let params = params_from_json(&cfg); 1032 | public_parameters_serialization_is_correct_for_params(params) 1033 | } 1034 | 1035 | #[test] 1036 | fn no_expansion_public_parameters_serialization_is_correct() { 1037 | public_parameters_serialization_is_correct_for_params(get_no_expansion_testing_params()) 1038 | } 1039 | 1040 | fn query_serialization_is_correct_for_params(params: Params) { 1041 | let mut client = Client::init(¶ms); 1042 | _ = client.generate_keys(); 1043 | let query = 
client.generate_query(1); 1044 | 1045 | let serialized1 = query.serialize(); 1046 | let deserialized1 = Query::deserialize(¶ms, &serialized1); 1047 | let serialized2 = deserialized1.serialize(); 1048 | 1049 | assert_eq!(serialized1.len(), serialized2.len()); 1050 | for i in 0..serialized1.len() { 1051 | assert_eq!(serialized1[i], serialized2[i], "at {}", i); 1052 | } 1053 | } 1054 | 1055 | #[test] 1056 | fn query_serialization_is_correct() { 1057 | query_serialization_is_correct_for_params(get_params()) 1058 | } 1059 | 1060 | #[test] 1061 | fn no_expansion_query_serialization_is_correct() { 1062 | query_serialization_is_correct_for_params(get_no_expansion_testing_params()) 1063 | } 1064 | } 1065 | -------------------------------------------------------------------------------- /src/discrete_gaussian.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::WeightedIndex; 2 | use rand::prelude::Distribution; 3 | use rand::Rng; 4 | use rand_chacha::ChaCha20Rng; 5 | use subtle::ConditionallySelectable; 6 | use subtle::ConstantTimeGreater; 7 | 8 | use crate::poly::*; 9 | use std::f64::consts::PI; 10 | 11 | pub const NUM_WIDTHS: usize = 4; 12 | 13 | /// Table of u64 values representing a Gaussian of width 6.4 14 | /// (standard deviation = 6.4/sqrt(2*pi)) 15 | /// 16 | /// This is the cumulative distribution function of this distribution, 17 | /// in the range [-26, 26], multiplied by 2^64. Values exactly equal to 2^64 have 18 | /// been replaced with 2^64-1, for representation as u64's. 19 | // const CDF_TABLE_GAUS_6_4: [u64; 53] = [ 20 | // 0, 21 | // 0, 22 | // 0, 23 | // 7, 24 | // 225, 25 | // 6114, 26 | // 142809, 27 | // 2864512, 28 | // 49349166, 29 | // 730367088, 30 | // 9288667698, 31 | // 101545086850, 32 | // 954617134063, 33 | // 7720973857474, 34 | // 53757667977838, 35 | // 322436486442815, 36 | // 1667499996257361, 37 | // 7443566871362048, 38 | // 28720140744863884, 39 | // 95948302954529081, 40 | // 278161926109627739, 41 | // 701795634139702303, 42 | // 1546646853635104741, 43 | // 2991920295851131431, 44 | // 5112721055115151939, 45 | // 7782220156096217088, 46 | // 10664523917613334528, 47 | // 13334023018594399677, 48 | // 15454823777858420185, 49 | // 16900097220074446875, 50 | // 17744948439569849313, 51 | // 18168582147599923877, 52 | // 18350795770755022535, 53 | // 18418023932964687732, 54 | // 18439300506838189568, 55 | // 18445076573713294255, 56 | // 18446421637223108801, 57 | // 18446690316041573778, 58 | // 18446736352735694142, 59 | // 18446743119092417553, 60 | // 18446743972164464766, 61 | // 18446744064420883918, 62 | // 18446744072979184528, 63 | // 18446744073660202450, 64 | // 18446744073706687104, 65 | // 18446744073709408807, 66 | // 18446744073709545502, 67 | // 18446744073709551391, 68 | // 18446744073709551609, 69 | // 18446744073709551615, 70 | // 18446744073709551615, 71 | // 18446744073709551615, 72 | // 18446744073709551615, 73 | // ]; 74 | 75 | pub struct DiscreteGaussian { 76 | pub weighted_index: WeightedIndex, 77 | pub cdf_table: Vec, 78 | pub max_val: i64, 79 | } 80 | 81 | impl DiscreteGaussian { 82 | pub fn init(noise_width: f64) -> Self { 83 | let max_val = (noise_width * (NUM_WIDTHS as f64)).ceil() as i64; 84 | let mut table = Vec::new(); 85 | let mut total = 0.0; 86 | 87 | // assign discrete probabilities to each possible integer output 88 | for i in -max_val..max_val + 1 { 89 | let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(noise_width, 2)); 90 | table.push(p_val); 
91 | total += p_val; 92 | } 93 | 94 | // build a CDF table for possible outputs 95 | let mut cdf_table = Vec::new(); 96 | let mut cum_prob = 0.0; 97 | 98 | for p_val in &table { 99 | cum_prob += p_val / total; 100 | let cum_prob_u64 = (cum_prob * (u64::MAX as f64)).round() as u64; 101 | cdf_table.push(cum_prob_u64); 102 | } 103 | 104 | Self { 105 | weighted_index: WeightedIndex::new(table).unwrap(), 106 | cdf_table, 107 | max_val, 108 | } 109 | } 110 | 111 | pub fn sample(&self, modulus: u64, rng: &mut ChaCha20Rng) -> u64 { 112 | let sampled_val = rng.gen::(); 113 | let len = (2 * self.max_val + 1) as usize; 114 | let mut to_output = 0; 115 | 116 | for i in (0..len).rev() { 117 | let mut out_val = (i as i64) - self.max_val; 118 | // this branch is ok: not secret-dependent 119 | if out_val < 0 { 120 | out_val += modulus as i64; 121 | } 122 | let out_val_u64 = out_val as u64; 123 | 124 | // let point = CDF_TABLE_GAUS_6_4[i]; 125 | let point = self.cdf_table[i]; 126 | 127 | // if sampled_val <= point, set to_output := out_val 128 | // (in constant time) 129 | let cmp = !(sampled_val.ct_gt(&point)); 130 | to_output.conditional_assign(&out_val_u64, cmp); 131 | } 132 | to_output 133 | } 134 | 135 | /// Sample from a discrete Gaussian distribution. THIS IS NOT CONSTANT TIME! 136 | pub fn fast_sample(&self, modulus: u64, rng: &mut ChaCha20Rng) -> u64 { 137 | let sampled_val = self.weighted_index.sample(rng); 138 | let mut val = (sampled_val as i64) - self.max_val; 139 | if val < 0 { 140 | val += modulus as i64; 141 | } 142 | val as u64 143 | } 144 | 145 | pub fn sample_matrix(&self, p: &mut PolyMatrixRaw, rng: &mut ChaCha20Rng) { 146 | let modulus = p.get_params().modulus; 147 | for r in 0..p.rows { 148 | for c in 0..p.cols { 149 | let poly = p.get_poly_mut(r, c); 150 | for z in 0..poly.len() { 151 | let s = self.sample(modulus, rng); 152 | poly[z] = s; 153 | } 154 | } 155 | } 156 | } 157 | } 158 | 159 | #[cfg(test)] 160 | mod test { 161 | use super::*; 162 | use crate::util::*; 163 | 164 | #[test] 165 | fn dg_seems_okay() { 166 | let params = get_test_params(); 167 | let dg = DiscreteGaussian::init(params.noise_width); 168 | let mut rng = get_chacha_rng(); 169 | let mut v = Vec::new(); 170 | let trials = 10000; 171 | let mut sum = 0; 172 | for _ in 0..trials { 173 | let val = dg.sample(params.modulus, &mut rng); 174 | let mut val_i64 = val as i64; 175 | if val_i64 >= (params.modulus as i64) / 2 { 176 | val_i64 -= params.modulus as i64; 177 | } 178 | v.push(val_i64); 179 | sum += val_i64; 180 | } 181 | let expected_mean = 0; 182 | let computed_mean = sum as f64 / trials as f64; 183 | let expected_std_dev = params.noise_width / f64::sqrt(2f64 * std::f64::consts::PI); 184 | let std_dev_of_mean = expected_std_dev / f64::sqrt(trials as f64); 185 | println!("mean:: expected: {}, got: {}", expected_mean, computed_mean); 186 | assert!(f64::abs(computed_mean) < std_dev_of_mean * 5f64); 187 | 188 | let computed_variance: f64 = v 189 | .iter() 190 | .map(|x| (computed_mean - (*x as f64)).powi(2)) 191 | .sum::() 192 | / (v.len() as f64); 193 | let computed_std_dev = computed_variance.sqrt(); 194 | println!( 195 | "std_dev:: expected: {}, got: {}", 196 | expected_std_dev, computed_std_dev 197 | ); 198 | assert!((computed_std_dev - expected_std_dev).abs() < (expected_std_dev * 0.1)); 199 | } 200 | 201 | #[test] 202 | fn cdf_table_seems_okay() { 203 | let dg = DiscreteGaussian::init(6.4); 204 | println!("{:?}", dg.cdf_table); 205 | } 206 | } 207 | 
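
The sampler above builds its CDF table once (probabilities proportional to exp(-π·i²/width²) over [-max_val, max_val], cumulative sums scaled into the u64 range) and then maps a uniform 64-bit draw through that table with a scan, done in constant time via `subtle`. The following standalone sketch mirrors the same construction with plain std + rand types only; the names `build_cdf` and `sample_signed` are illustrative and not part of the crate, and the ordinary branch stands in for the `conditional_assign` used in `DiscreteGaussian::sample`.

```rust
// Illustrative sketch only (not part of the crate): the CDF-table sampling idea
// used by DiscreteGaussian, written with plain std + rand types.
use rand::Rng;
use std::f64::consts::PI;

/// Probabilities p(i) ∝ exp(-pi * i^2 / width^2) for i in [-max_val, max_val],
/// accumulated and scaled into the u64 range, as in DiscreteGaussian::init.
fn build_cdf(width: f64, num_widths: f64) -> (Vec<u64>, i64) {
    let max_val = (width * num_widths).ceil() as i64;
    let probs: Vec<f64> = (-max_val..=max_val)
        .map(|i| (-PI * (i as f64).powi(2) / width.powi(2)).exp())
        .collect();
    let total: f64 = probs.iter().sum();
    let mut acc = 0.0;
    let cdf = probs
        .iter()
        .map(|p| {
            acc += p / total;
            (acc * u64::MAX as f64).round() as u64
        })
        .collect();
    (cdf, max_val)
}

/// Inverse-CDF lookup: the lowest table index whose entry is >= a uniform
/// 64-bit draw, shifted back to a signed value. The crate performs this same
/// scan in constant time with subtle's conditional_assign.
fn sample_signed(cdf: &[u64], max_val: i64, rng: &mut impl Rng) -> i64 {
    let u: u64 = rng.gen();
    let mut out = 0i64;
    for i in (0..cdf.len()).rev() {
        if u <= cdf[i] {
            out = i as i64 - max_val;
        }
    }
    out
}

fn main() {
    let (cdf, max_val) = build_cdf(6.4, 4.0); // width 6.4, NUM_WIDTHS = 4
    let mut rng = rand::thread_rng();
    let samples: Vec<i64> = (0..16).map(|_| sample_signed(&cdf, max_val, &mut rng)).collect();
    println!("{:?}", samples); // small signed values concentrated around 0
}
```

In the crate itself, `sample` maps the signed result into [0, modulus) before returning it, which is why negative draws come back as values just below the modulus.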
-------------------------------------------------------------------------------- /src/gadget.rs: -------------------------------------------------------------------------------- 1 | use crate::{params::*, poly::*}; 2 | 3 | pub fn get_bits_per(params: &Params, dim: usize) -> usize { 4 | let modulus_log2 = params.modulus_log2; 5 | if dim as u64 == modulus_log2 { 6 | return 1; 7 | } 8 | ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1 9 | } 10 | 11 | pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw { 12 | let mut g = PolyMatrixRaw::zero(params, rows, cols); 13 | let nx = g.rows; 14 | let m = g.cols; 15 | 16 | assert_eq!(m % nx, 0); 17 | 18 | let num_elems = m / nx; 19 | let params = g.params; 20 | let bits_per = get_bits_per(params, num_elems); 21 | 22 | for i in 0..nx { 23 | for j in 0..num_elems { 24 | if bits_per * j >= 64 { 25 | continue; 26 | } 27 | let poly = g.get_poly_mut(i, i + j * nx); 28 | poly[0] = 1u64 << (bits_per * j); 29 | } 30 | } 31 | g 32 | } 33 | 34 | pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize) { 35 | assert_eq!(out.cols, inp.cols); 36 | 37 | let params = inp.params; 38 | let mx = out.rows; 39 | let num_elems = mx / rdim; 40 | let bits_per = get_bits_per(params, num_elems); 41 | let mask = (1u64 << bits_per) - 1; 42 | 43 | for j in 0..rdim { 44 | for i in 0..inp.cols { 45 | for z in 0..params.poly_len { 46 | let val = inp.get_poly(j, i)[z]; 47 | for k in 0..num_elems { 48 | let bit_offs = k * bits_per; 49 | let piece = if bit_offs >= 64 { 50 | 0 51 | } else { 52 | (val >> bit_offs) & mask 53 | }; 54 | 55 | out.get_poly_mut(j + k * rdim, i)[z] = piece; 56 | } 57 | } 58 | } 59 | } 60 | } 61 | 62 | pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>) { 63 | gadget_invert_rdim(out, inp, inp.rows); 64 | } 65 | 66 | pub fn gadget_invert_alloc<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> { 67 | let mut out = PolyMatrixRaw::zero(inp.params, mx, inp.cols); 68 | gadget_invert(&mut out, inp); 69 | out 70 | } 71 | 72 | #[cfg(test)] 73 | mod test { 74 | use crate::util::get_test_params; 75 | 76 | use super::*; 77 | 78 | #[test] 79 | fn gadget_invert_is_correct() { 80 | let params = get_test_params(); 81 | let mut mat = PolyMatrixRaw::zero(¶ms, 2, 1); 82 | mat.get_poly_mut(0, 0)[37] = 3; 83 | mat.get_poly_mut(1, 0)[37] = 6; 84 | let log_q = params.modulus_log2 as usize; 85 | let result = gadget_invert_alloc(2 * log_q, &mat); 86 | 87 | assert_eq!(result.get_poly(0, 0)[37], 1); 88 | assert_eq!(result.get_poly(2, 0)[37], 1); 89 | assert_eq!(result.get_poly(4, 0)[37], 0); // binary for '3' 90 | 91 | assert_eq!(result.get_poly(1, 0)[37], 0); 92 | assert_eq!(result.get_poly(3, 0)[37], 1); 93 | assert_eq!(result.get_poly(5, 0)[37], 1); 94 | assert_eq!(result.get_poly(7, 0)[37], 0); // binary for '6' 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/key_value.rs: -------------------------------------------------------------------------------- 1 | use crate::params::Params; 2 | use sha2::{Digest, Sha256}; 3 | 4 | const VARINT_MAX_BYTES: usize = 8; 5 | const MAX_VARINT_BITS: u64 = 63; 6 | 7 | pub fn varint_decode(data: &[u8]) -> (usize, usize) { 8 | let mut shift = 0u64; 9 | let mut result = 0u64; 10 | let mut j = 0; 11 | 12 | while shift < MAX_VARINT_BITS { 13 | let i = data[j] as u64; 14 | j += 1; 15 | result |= (i & 0x7f) << shift; 16 | shift += 7; 17 | if i & 0x80 == 0 { 18 | break; 19 | } 20 | } 21 | 22 
| (result as usize, j) 23 | } 24 | 25 | pub fn row_from_key(params: &Params, key: &str) -> usize { 26 | let num_items = params.num_items(); 27 | let buckets_log2 = (num_items as f64).log2().ceil() as usize; 28 | 29 | let hash = Sha256::digest(key.as_bytes()); 30 | 31 | // let idx = read_arbitrary_bits(&hash, 0, buckets_log2) as usize; 32 | let mut idx = 0; 33 | for i in 0..buckets_log2 { 34 | let cond = hash[i / 8] & (1 << (7 - (i % 8))); 35 | if cond != 0 { 36 | idx += 1 << (buckets_log2 - i - 1); 37 | } 38 | } 39 | idx 40 | } 41 | 42 | pub fn extract_result_impl(key: &str, result: &[u8]) -> Result, &'static str> { 43 | let hash_bytes = result[0] as usize; 44 | let hash = Sha256::digest(key.as_bytes()); 45 | let target = &hash[(hash.len() - hash_bytes)..]; 46 | let mut i = 1; 47 | while i < result.len() { 48 | // read key 49 | let key_hash = &result[i..i + hash_bytes]; 50 | i += hash_bytes; 51 | 52 | // read len 53 | let (value_len, value_len_len) = varint_decode(&result[i..i + VARINT_MAX_BYTES]); 54 | i += value_len_len; 55 | 56 | // read value 57 | let value = &result[i..i + value_len]; 58 | i += value_len; 59 | 60 | if key_hash == target { 61 | return Ok(value.to_vec()); 62 | } 63 | } 64 | 65 | Err("key not found") 66 | } 67 | 68 | #[cfg(test)] 69 | mod test { 70 | use super::*; 71 | 72 | use crate::util::*; 73 | 74 | fn get_params() -> Params { 75 | params_from_json( 76 | r#"{ 77 | "n": 4, 78 | "nu_1": 9, 79 | "nu_2": 5, 80 | "p": 256, 81 | "q2_bits": 20, 82 | "t_gsw": 8, 83 | "t_conv": 4, 84 | "t_exp_left": 8, 85 | "t_exp_right": 56, 86 | "instances": 2, 87 | "db_item_size": 65536 88 | }"#, 89 | ) 90 | } 91 | 92 | #[test] 93 | fn row_from_key_is_correct() { 94 | let params = get_params(); 95 | assert_eq!(row_from_key(¶ms, "CA"), 4825); 96 | assert_eq!(row_from_key(¶ms, "OR"), 8359); 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(stdarch_x86_avx512)] 2 | 3 | pub mod aligned_memory; 4 | pub mod arith; 5 | pub mod discrete_gaussian; 6 | pub mod noise_estimate; 7 | pub mod number_theory; 8 | pub mod util; 9 | 10 | pub mod gadget; 11 | pub mod ntt; 12 | pub mod params; 13 | pub mod poly; 14 | 15 | pub mod client; 16 | pub mod key_value; 17 | 18 | #[cfg(feature = "server")] 19 | pub mod server; 20 | -------------------------------------------------------------------------------- /src/noise_estimate.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::{E, PI}; 2 | 3 | use crate::{ 4 | client::HAMMING_WEIGHT, 5 | params::{Params, Q2_VALUES}, 6 | }; 7 | 8 | // This a simplified subset of a Params instance 9 | pub struct Paramset { 10 | pub n: usize, 11 | pub d: usize, 12 | pub p: u64, 13 | pub q: u64, 14 | pub sigma: f64, 15 | pub t_conv: usize, 16 | pub t_exp_left: usize, 17 | pub t_exp_right: usize, 18 | pub t_gsw: usize, 19 | pub db_dim_1: usize, 20 | pub db_dim_2: usize, 21 | pub expand_queries: bool, 22 | } 23 | 24 | pub fn extract_paramset(params: &Params) -> Paramset { 25 | Paramset { 26 | n: params.n, 27 | d: params.poly_len, 28 | p: params.pt_modulus, 29 | q: params.modulus, 30 | sigma: params.noise_width, 31 | t_conv: params.t_conv, 32 | t_exp_left: params.t_exp_left, 33 | t_exp_right: params.t_exp_right, 34 | t_gsw: params.t_gsw, 35 | db_dim_1: params.db_dim_1, 36 | db_dim_2: params.db_dim_2, 37 | expand_queries: params.expand_queries, 38 | } 39 | } 40 | 41 | fn 
get_base(t: usize, q: u64) -> f64 { 42 | // f64::ceil((q as f64).powf(1. / t as f64)) 43 | let q_f = q as f64; 44 | let t_f = t as f64; 45 | let q_bits = f64::ceil(f64::log2(q_f)); 46 | 2f64.powf((q_bits / t_f).ceil()) 47 | } 48 | 49 | fn gadget_exp_factor(s: &Paramset, t: usize, z: f64) -> f64 { 50 | (t * s.d) as f64 * s.sigma.powi(2) * z.powi(2) / 4f64 51 | } 52 | 53 | pub fn get_noise_from_paramset(s: &Paramset) -> f64 { 54 | let nu1 = s.db_dim_1 as i32; 55 | let nu2 = s.db_dim_2 as i32; 56 | 57 | let n_used = 1; 58 | 59 | let z_gsw = get_base(s.t_gsw, s.q); 60 | let m_gsw = (n_used + 1) * s.t_gsw; 61 | let z_conv = get_base(s.t_conv, s.q); 62 | let z_exp_left = get_base(s.t_exp_left, s.q); 63 | let z_exp_right = get_base(s.t_exp_right, s.q); 64 | 65 | let num_exp_reg = s.db_dim_1 + 1; 66 | 67 | let mut sigma_reg_2 = s.sigma.powi(2); 68 | let mut sigma_gsw_2 = s.sigma.powi(2); 69 | 70 | if s.expand_queries { 71 | sigma_reg_2 = 4f64.powf(num_exp_reg as f64) 72 | * s.sigma.powi(2) 73 | // * (1.0 + ((s.d * s.t_exp_left) as f64 * z_exp_left.powi(2) / 3.)); 74 | * (1.0 + ((s.t_exp_left) as f64 * z_exp_left.powi(2) / 3.)); 75 | // NB: above, we exclude a factor of s.d; this is bad according to the paper, but 76 | // in practice, it seems to model the noise accurately 77 | 78 | let num_exp_gsw = f64::ceil(f64::log2((s.t_gsw as f64) * (nu2 as f64))) as i32 + 1; 79 | sigma_gsw_2 = 4f64.powi(num_exp_gsw) 80 | * s.sigma.powi(2) 81 | * (1.0 + ((s.t_exp_right) as f64 * z_exp_right.powi(2) / 3.)); 82 | sigma_gsw_2 = sigma_gsw_2 * 2. * (HAMMING_WEIGHT as f64) 83 | + 2. * gadget_exp_factor(s, s.t_conv, z_conv); 84 | } 85 | 86 | let sigma_0_2 = (2f64.powi(nu1)) 87 | * (n_used as f64) 88 | * (s.d as f64) 89 | * ((s.p as f64) / 2.).powi(2) 90 | * (sigma_reg_2); 91 | let sigma_rest = 92 | (nu2 as f64) * (s.d as f64) * (m_gsw as f64) * z_gsw.powi(2) / 2. * (sigma_gsw_2); 93 | let sigma_r_2 = sigma_0_2 + sigma_rest; 94 | 95 | let sigma_packing_2 = ((s.d * s.n * s.t_conv) as f64) * s.sigma.powi(2) * z_conv.powi(2) / 4.; 96 | 97 | sigma_r_2 + sigma_packing_2 98 | } 99 | 100 | pub fn get_p_err(s: &Paramset, s_e: f64, q_prime: u64) -> f64 { 101 | let p_f = s.p as f64; 102 | let q_prime_f = q_prime as f64; 103 | let q_f = s.q as f64; 104 | 105 | let q_mod_p = 1; 106 | let modswitch_adj = (1. / 8.) * ((4. * p_f) * (q_mod_p as f64) / q_f); 107 | let thresh = (1. / 4.) - modswitch_adj; 108 | assert!((thresh > 0.) && (thresh < (1. / 4.))); 109 | 110 | let s_round_2 = s.sigma.powi(2) * (s.d as f64) / 4.; 111 | let numer = -PI * thresh.powi(2); 112 | let denom = s_e * (p_f / q_f).powi(2) + (s_round_2) * (p_f / q_prime_f).powi(2); 113 | 114 | let p_single_err_log = f64::ln(2.) 
+ (numer / denom); 115 | let p_err_log = p_single_err_log + f64::ln((s.n * s.n * s.d) as f64); 116 | let p_err = p_err_log * f64::log(E, 2.); 117 | p_err 118 | } 119 | 120 | pub trait NoiseEstimator { 121 | fn estimate_noise(&self) -> f64; 122 | fn estimate_log2_err_prob(&self) -> f64; 123 | } 124 | 125 | impl NoiseEstimator for Params { 126 | fn estimate_noise(&self) -> f64 { 127 | get_noise_from_paramset(&extract_paramset(self)) 128 | } 129 | 130 | fn estimate_log2_err_prob(&self) -> f64 { 131 | let q2 = Q2_VALUES[self.q2_bits as usize]; 132 | let paramset = extract_paramset(self); 133 | let s_e = self.estimate_noise(); 134 | get_p_err(¶mset, s_e, q2) 135 | } 136 | } 137 | 138 | #[cfg(test)] 139 | mod test { 140 | use crate::util::*; 141 | 142 | use super::*; 143 | 144 | #[test] 145 | fn get_noise_from_paramset_correct() { 146 | let cfg_expand = r#" 147 | { 148 | "n": 2, 149 | "nu_1": 9, 150 | "nu_2": 5, 151 | "p": 256, 152 | "q2_bits": 22, 153 | "t_gsw": 7, 154 | "t_conv": 3, 155 | "t_exp_left": 5, 156 | "t_exp_right": 5, 157 | "instances": 4, 158 | "db_item_size": 32768 159 | } 160 | "#; 161 | let params = params_from_json(cfg_expand); 162 | let s_e = params.estimate_noise(); 163 | let p_err = params.estimate_log2_err_prob(); 164 | let noise_log2 = f64::log2(s_e); 165 | println!("noise_log2: {}", noise_log2); 166 | println!("p_err: {}", p_err); 167 | println!("setup bytes: {}", params.setup_bytes()); 168 | // assert!(noise_log2 < 87.0); 169 | assert!(p_err <= -40.0); 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /src/ntt.rs: -------------------------------------------------------------------------------- 1 | use std::arch::x86_64::*; 2 | 3 | use crate::{arith::*, number_theory::*, params::*}; 4 | 5 | pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec { 6 | let poly_len = 1usize << poly_len_log2; 7 | let mut root_powers = vec![0u64; poly_len]; 8 | let mut power = root; 9 | for i in 1..poly_len { 10 | let idx = reverse_bits(i as u64, poly_len_log2) as usize; 11 | root_powers[idx] = power; 12 | power = multiply_uint_mod(power, root, modulus); 13 | } 14 | root_powers[0] = 1; 15 | root_powers 16 | } 17 | 18 | pub fn scale_powers_u64(modulus: u64, poly_len: usize, inp: &[u64]) -> Vec { 19 | let mut scaled_powers = vec![0; poly_len]; 20 | for i in 0..poly_len { 21 | let wide_val = (inp[i] as u128) << 64u128; 22 | let quotient = wide_val / (modulus as u128); 23 | scaled_powers[i] = quotient as u64; 24 | } 25 | scaled_powers 26 | } 27 | 28 | pub fn scale_powers_u32(modulus: u32, poly_len: usize, inp: &[u64]) -> Vec { 29 | let mut scaled_powers = vec![0; poly_len]; 30 | for i in 0..poly_len { 31 | let wide_val = inp[i] << 32; 32 | let quotient = wide_val / (modulus as u64); 33 | scaled_powers[i] = (quotient as u32) as u64; 34 | } 35 | scaled_powers 36 | } 37 | 38 | pub fn build_ntt_tables_alt( 39 | poly_len: usize, 40 | moduli: &[u64], 41 | opt_roots: Option<&[u64]>, 42 | ) -> Vec>> { 43 | let poly_len_log2 = log2(poly_len as u64) as usize; 44 | let mut output: Vec>> = vec![Vec::new(); moduli.len()]; 45 | for coeff_mod in 0..moduli.len() { 46 | let modulus = moduli[coeff_mod]; 47 | let root = if let Some(roots) = opt_roots { 48 | roots[coeff_mod] 49 | } else { 50 | get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap() 51 | }; 52 | let inv_root = invert_uint_mod(root, modulus).unwrap(); 53 | 54 | let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2); 55 | let 
scaled_root_powers = scale_powers_u64(modulus, poly_len, root_powers.as_slice()); 56 | let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2); 57 | for i in 0..poly_len { 58 | inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus); 59 | } 60 | let scaled_inv_root_powers = 61 | scale_powers_u64(modulus, poly_len, inv_root_powers.as_slice()); 62 | 63 | output[coeff_mod] = vec![ 64 | root_powers, 65 | scaled_root_powers, 66 | inv_root_powers, 67 | scaled_inv_root_powers, 68 | ]; 69 | } 70 | output 71 | } 72 | 73 | pub fn build_ntt_tables( 74 | poly_len: usize, 75 | moduli: &[u64], 76 | opt_roots: Option<&[u64]>, 77 | ) -> Vec>> { 78 | let poly_len_log2 = log2(poly_len as u64) as usize; 79 | let mut output: Vec>> = vec![Vec::new(); moduli.len()]; 80 | for coeff_mod in 0..moduli.len() { 81 | let modulus = moduli[coeff_mod]; 82 | let modulus_as_u32 = modulus.try_into().unwrap(); 83 | let root = if let Some(roots) = opt_roots { 84 | roots[coeff_mod] 85 | } else { 86 | get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap() 87 | }; 88 | let inv_root = invert_uint_mod(root, modulus).unwrap(); 89 | 90 | let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2); 91 | let scaled_root_powers = scale_powers_u32(modulus_as_u32, poly_len, root_powers.as_slice()); 92 | let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2); 93 | for i in 0..poly_len { 94 | inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus); 95 | } 96 | let scaled_inv_root_powers = 97 | scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice()); 98 | 99 | output[coeff_mod] = vec![ 100 | root_powers, 101 | scaled_root_powers, 102 | inv_root_powers, 103 | scaled_inv_root_powers, 104 | ]; 105 | } 106 | output 107 | } 108 | 109 | #[cfg(not(target_feature = "avx2"))] 110 | pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) { 111 | if params.crt_count == 1 { 112 | ntt_forward_alt(params, operand_overall); 113 | return; 114 | } 115 | let log_n = params.poly_len_log2; 116 | let n = 1 << log_n; 117 | 118 | for coeff_mod in 0..params.crt_count { 119 | let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 120 | 121 | let forward_table = params.get_ntt_forward_table(coeff_mod); 122 | let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod); 123 | let modulus_small = params.moduli[coeff_mod] as u32; 124 | let two_times_modulus_small: u32 = 2 * modulus_small; 125 | 126 | for mm in 0..log_n { 127 | let m = 1 << mm; 128 | let t = n >> (mm + 1); 129 | 130 | let mut it = operand.chunks_exact_mut(2 * t); 131 | 132 | for i in 0..m { 133 | let w = forward_table[m + i]; 134 | let w_prime = forward_table_prime[m + i]; 135 | 136 | let op = it.next().unwrap(); 137 | 138 | for j in 0..t { 139 | let x: u32 = op[j] as u32; 140 | let y: u32 = op[t + j] as u32; 141 | 142 | let curr_x: u32 = 143 | x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32)); 144 | let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64; 145 | let q_new = w * (y as u64) - q_tmp * (modulus_small as u64); 146 | 147 | op[j] = curr_x as u64 + q_new; 148 | op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new); 149 | } 150 | } 151 | } 152 | 153 | for i in 0..n { 154 | operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) 155 | * two_times_modulus_small as u64; 156 | operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64; 157 | } 158 | } 159 | } 160 | 161 | pub 
fn ntt_forward_alt(params: &Params, operand_overall: &mut [u64]) { 162 | let log_n = params.poly_len_log2; 163 | let n = 1 << log_n; 164 | 165 | for coeff_mod in 0..params.crt_count { 166 | let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 167 | 168 | let forward_table = params.get_ntt_forward_table(coeff_mod); 169 | let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod); 170 | let modulus_small = params.moduli[coeff_mod]; 171 | let two_times_modulus_small = 2 * modulus_small; 172 | 173 | for mm in 0..log_n { 174 | let m = 1 << mm; 175 | let t = n >> (mm + 1); 176 | 177 | let mut it = operand.chunks_exact_mut(2 * t); 178 | 179 | for i in 0..m { 180 | let w = forward_table[m + i]; 181 | let w_prime = forward_table_prime[m + i]; 182 | 183 | let op = it.next().unwrap(); 184 | 185 | for j in 0..t { 186 | let x: u64 = op[j] as u64; 187 | let y: u64 = op[t + j] as u64; 188 | 189 | let curr_x: u64 = 190 | x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u64)); 191 | let q_tmp = ((y as u128) * (w_prime as u128)) >> 64u64; 192 | let q_new = (w as u128) * (y as u128) - q_tmp * (modulus_small as u128); 193 | let q_new = (q_new % (modulus_small as u128)) as u64; 194 | 195 | op[j] = curr_x as u64 + q_new; 196 | op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new); 197 | } 198 | } 199 | } 200 | 201 | for i in 0..n { 202 | operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) 203 | * two_times_modulus_small as u64; 204 | operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64; 205 | } 206 | } 207 | } 208 | 209 | #[cfg(target_feature = "avx2")] 210 | pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) { 211 | if params.crt_count == 1 { 212 | ntt_forward_alt(params, operand_overall); 213 | return; 214 | } 215 | let log_n = params.poly_len_log2; 216 | let n = 1 << log_n; 217 | 218 | for coeff_mod in 0..params.crt_count { 219 | let operand = unsafe { 220 | std::slice::from_raw_parts_mut(operand_overall.as_mut_ptr().add(coeff_mod * n), n) 221 | }; 222 | 223 | let forward_table = params.get_ntt_forward_table(coeff_mod); 224 | let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod); 225 | let modulus_small = params.moduli[coeff_mod] as u32; 226 | let two_times_modulus_small: u32 = 2 * modulus_small; 227 | 228 | for mm in 0..log_n { 229 | let m = 1 << mm; 230 | let t = n >> (mm + 1); 231 | 232 | for i in 0..m { 233 | let w = unsafe { *forward_table.get_unchecked(m + i) }; 234 | let w_prime = unsafe { *forward_table_prime.get_unchecked(m + i) }; 235 | 236 | let op = unsafe { 237 | std::slice::from_raw_parts_mut(operand.as_mut_ptr().add(2 * t * i), 2 * t) 238 | }; 239 | 240 | if t < 4 || log_n <= 10 { 241 | for j in 0..t { 242 | let x: u32 = unsafe { *op.get_unchecked(j) as u32 }; 243 | let y: u32 = unsafe { *op.get_unchecked(t + j) as u32 }; 244 | 245 | let curr_x: u32 = 246 | x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32)); 247 | let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64; 248 | let q_new = w * (y as u64) - q_tmp * (modulus_small as u64); 249 | 250 | unsafe { 251 | *op.get_unchecked_mut(j) = curr_x as u64 + q_new; 252 | *op.get_unchecked_mut(t + j) = 253 | curr_x as u64 + ((two_times_modulus_small as u64) - q_new); 254 | } 255 | } 256 | } else if t == 4 { 257 | unsafe { 258 | for j in (0..t).step_by(4) { 259 | // Use AVX2 here 260 | let p_x = op.get_unchecked_mut(j) as *mut u64; 261 | let p_y = op.get_unchecked_mut(j + t) as *mut 
u64; 262 | let x = _mm256_load_si256(p_x as *const __m256i); 263 | let y = _mm256_load_si256(p_y as *const __m256i); 264 | 265 | let cmp_val = _mm256_set1_epi64x(two_times_modulus_small as i64); 266 | let gt_mask = _mm256_cmpgt_epi64(x, cmp_val); 267 | 268 | let to_subtract = _mm256_and_si256(gt_mask, cmp_val); 269 | let curr_x = _mm256_sub_epi64(x, to_subtract); 270 | 271 | // uint32_t q_val = ((y) * (uint64_t)(Wprime)) >> 32; 272 | let w_prime_vec = _mm256_set1_epi64x(w_prime as i64); 273 | let product = _mm256_mul_epu32(y, w_prime_vec); 274 | let q_val = _mm256_srli_epi64(product, 32); 275 | 276 | // q_val = W * y - q_val * modulus_small; 277 | let w_vec = _mm256_set1_epi64x(w as i64); 278 | let w_times_y = _mm256_mul_epu32(y, w_vec); 279 | let modulus_small_vec = _mm256_set1_epi64x(modulus_small as i64); 280 | let q_scaled = _mm256_mul_epu32(q_val, modulus_small_vec); 281 | let q_final = _mm256_sub_epi64(w_times_y, q_scaled); 282 | 283 | let new_x = _mm256_add_epi64(curr_x, q_final); 284 | let q_final_inverted = _mm256_sub_epi64(cmp_val, q_final); 285 | let new_y = _mm256_add_epi64(curr_x, q_final_inverted); 286 | 287 | _mm256_store_si256(p_x as *mut __m256i, new_x); 288 | _mm256_store_si256(p_y as *mut __m256i, new_y); 289 | } 290 | } 291 | } else { 292 | unsafe { 293 | for j in (0..t).step_by(8) { 294 | let p_x = op.get_unchecked_mut(j) as *mut u64; 295 | let p_y = op.get_unchecked_mut(j + t) as *mut u64; 296 | let x = _mm512_load_si512(p_x as *const _); 297 | let y = _mm512_load_si512(p_y as *const _); 298 | 299 | let cmp_val = _mm512_set1_epi64(two_times_modulus_small as i64); 300 | let gt_mask = _mm512_cmpgt_epu64_mask(x, cmp_val); 301 | 302 | // let to_subtract = _mm512_and_si512(gt_mask, cmp_val); 303 | let curr_x = _mm512_mask_sub_epi64(x, gt_mask, x, cmp_val); 304 | 305 | // uint32_t q_val = ((y) * (uint64_t)(Wprime)) >> 32; 306 | let w_prime_vec = _mm512_set1_epi64(w_prime as i64); 307 | let product = _mm512_mul_epu32(y, w_prime_vec); 308 | let q_val = _mm512_srli_epi64(product, 32); 309 | 310 | // q_val = W * y - q_val * modulus_small; 311 | let w_vec = _mm512_set1_epi64(w as i64); 312 | let w_times_y = _mm512_mul_epu32(y, w_vec); 313 | let modulus_small_vec = _mm512_set1_epi64(modulus_small as i64); 314 | let q_scaled = _mm512_mul_epu32(q_val, modulus_small_vec); 315 | let q_final = _mm512_sub_epi64(w_times_y, q_scaled); 316 | 317 | let new_x = _mm512_add_epi64(curr_x, q_final); 318 | let q_final_inverted = _mm512_sub_epi64(cmp_val, q_final); 319 | let new_y = _mm512_add_epi64(curr_x, q_final_inverted); 320 | 321 | _mm512_store_si512(p_x as *mut _, new_x); 322 | _mm512_store_si512(p_y as *mut _, new_y); 323 | } 324 | } 325 | } 326 | } 327 | } 328 | 329 | if log_n <= 10 { 330 | for i in 0..n { 331 | operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) 332 | * two_times_modulus_small as u64; 333 | operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64; 334 | } 335 | continue; 336 | } 337 | 338 | for i in (0..n).step_by(8) { 339 | unsafe { 340 | let p_x = operand.get_unchecked_mut(i) as *mut u64; 341 | 342 | let cmp_val1 = _mm512_set1_epi64(two_times_modulus_small as i64); 343 | let mut x = _mm512_load_si512(p_x as *const _); 344 | let mut gt_mask = _mm512_cmpgt_epu64_mask(x, cmp_val1); 345 | // let mut to_subtract = _mm512_and_si512(gt_mask, cmp_val1); 346 | x = _mm512_mask_sub_epi64(x, gt_mask, x, cmp_val1); 347 | 348 | let cmp_val2 = _mm512_set1_epi64(modulus_small as i64); 349 | gt_mask = _mm512_cmpgt_epu64_mask(x, cmp_val2); 
350 | // to_subtract = _mm512_and_si512(gt_mask, cmp_val2); 351 | x = _mm512_mask_sub_epi64(x, gt_mask, x, cmp_val2); 352 | _mm512_store_si512(p_x as *mut _, x); 353 | } 354 | } 355 | } 356 | } 357 | 358 | // #[cfg(not(target_feature = "avx2"))] 359 | // pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) { 360 | // if params.crt_count == 1 { 361 | // ntt_inverse_alt(params, operand_overall); 362 | // return; 363 | // } 364 | // for coeff_mod in 0..params.crt_count { 365 | // let n = params.poly_len; 366 | 367 | // let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 368 | 369 | // let inverse_table = params.get_ntt_inverse_table(coeff_mod); 370 | // let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod); 371 | // let modulus = params.moduli[coeff_mod]; 372 | // let two_times_modulus: u64 = 2 * modulus; 373 | 374 | // for mm in (0..params.poly_len_log2).rev() { 375 | // let h = 1 << mm; 376 | // let t = n >> (mm + 1); 377 | 378 | // let mut it = operand.chunks_exact_mut(2 * t); 379 | 380 | // for i in 0..h { 381 | // let w = inverse_table[h + i]; 382 | // let w_prime = inverse_table_prime[h + i]; 383 | 384 | // let op = it.next().unwrap(); 385 | 386 | // for j in 0..t { 387 | // let x = op[j]; 388 | // let y = op[t + j]; 389 | 390 | // let t_tmp = two_times_modulus - y + x; 391 | // let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64)); 392 | // let h_tmp = (t_tmp * w_prime) >> 32; 393 | 394 | // let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1; 395 | // let res_y = w * t_tmp - h_tmp * modulus; 396 | 397 | // op[j] = res_x; 398 | // op[t + j] = res_y; 399 | // } 400 | // } 401 | // } 402 | 403 | // for i in 0..n { 404 | // operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus; 405 | // operand[i] -= ((operand[i] >= modulus) as u64) * modulus; 406 | // } 407 | // } 408 | // } 409 | 410 | pub fn ntt_inverse_alt(params: &Params, operand_overall: &mut [u64]) { 411 | for coeff_mod in 0..params.crt_count { 412 | let n = params.poly_len; 413 | 414 | let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 415 | 416 | let inverse_table = params.get_ntt_inverse_table(coeff_mod); 417 | let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod); 418 | let modulus = params.moduli[coeff_mod]; 419 | let two_times_modulus: u64 = 2 * modulus; 420 | 421 | for mm in (0..params.poly_len_log2).rev() { 422 | let h = 1 << mm; 423 | let t = n >> (mm + 1); 424 | 425 | let mut it = operand.chunks_exact_mut(2 * t); 426 | 427 | for i in 0..h { 428 | let w = inverse_table[h + i]; 429 | let w_prime = inverse_table_prime[h + i]; 430 | 431 | let op = it.next().unwrap(); 432 | 433 | for j in 0..t { 434 | let x = op[j]; 435 | let y = op[t + j]; 436 | 437 | let t_tmp = two_times_modulus - y + x; 438 | let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64)); 439 | let h_tmp = ((t_tmp as u128) * (w_prime as u128)) >> 64; 440 | 441 | let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1; 442 | let res_y = ((w as u128) * (t_tmp as u128)) - (h_tmp * modulus as u128); 443 | 444 | op[j] = res_x; 445 | op[t + j] = (res_y % (modulus as u128)) as u64; 446 | } 447 | } 448 | } 449 | 450 | for i in 0..n { 451 | operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus; 452 | operand[i] -= ((operand[i] >= modulus) as u64) * modulus; 453 | } 454 | } 455 | } 456 | 457 | pub fn ntt_inverse_256(params: &Params, operand_overall: &mut [u64]) { 458 | if params.crt_count == 1 { 
459 | ntt_inverse_alt(params, operand_overall); 460 | return; 461 | } 462 | for coeff_mod in 0..params.crt_count { 463 | let n = params.poly_len; 464 | 465 | let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 466 | 467 | let inverse_table = params.get_ntt_inverse_table(coeff_mod); 468 | let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod); 469 | let modulus = params.moduli[coeff_mod]; 470 | let two_times_modulus: u64 = 2 * modulus; 471 | for mm in (0..params.poly_len_log2).rev() { 472 | let h = 1 << mm; 473 | let t = n >> (mm + 1); 474 | 475 | let mut it = operand.chunks_exact_mut(2 * t); 476 | 477 | for i in 0..h { 478 | let w = inverse_table[h + i]; 479 | let w_prime = inverse_table_prime[h + i]; 480 | 481 | let op = it.next().unwrap(); 482 | 483 | if t < 4 { 484 | for j in 0..t { 485 | let x = op[j]; 486 | let y = op[t + j]; 487 | 488 | let t_tmp = two_times_modulus - y + x; 489 | let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64)); 490 | let h_tmp = (t_tmp * w_prime) >> 32; 491 | 492 | let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1; 493 | let res_y = w * t_tmp - h_tmp * modulus; 494 | 495 | op[j] = res_x; 496 | op[t + j] = res_y; 497 | } 498 | } else { 499 | unsafe { 500 | for j in (0..t).step_by(4) { 501 | // Use AVX2 here 502 | let p_x = &mut op[j] as *mut u64; 503 | let p_y = &mut op[j + t] as *mut u64; 504 | let x = _mm256_load_si256(p_x as *const __m256i); 505 | let y = _mm256_load_si256(p_y as *const __m256i); 506 | 507 | let modulus_vec = _mm256_set1_epi64x(modulus as i64); 508 | let two_times_modulus_vec = 509 | _mm256_set1_epi64x(two_times_modulus as i64); 510 | let mut t_tmp = _mm256_set1_epi64x(two_times_modulus as i64); 511 | t_tmp = _mm256_sub_epi64(t_tmp, y); 512 | t_tmp = _mm256_add_epi64(t_tmp, x); 513 | let gt_mask = _mm256_cmpgt_epi64(_mm256_slli_epi64(x, 1), t_tmp); 514 | let to_subtract = _mm256_and_si256(gt_mask, two_times_modulus_vec); 515 | let mut curr_x = _mm256_add_epi64(x, y); 516 | curr_x = _mm256_sub_epi64(curr_x, to_subtract); 517 | 518 | let w_prime_vec = _mm256_set1_epi64x(w_prime as i64); 519 | let mut h_tmp = _mm256_mul_epu32(t_tmp, w_prime_vec); 520 | h_tmp = _mm256_srli_epi64(h_tmp, 32); 521 | 522 | let and_mask = _mm256_set_epi64x(1, 1, 1, 1); 523 | let eq_mask = 524 | _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, and_mask), and_mask); 525 | let to_add = _mm256_and_si256(eq_mask, modulus_vec); 526 | 527 | let new_x = _mm256_srli_epi64(_mm256_add_epi64(curr_x, to_add), 1); 528 | 529 | let w_vec = _mm256_set1_epi64x(w as i64); 530 | let w_times_t_tmp = _mm256_mul_epu32(t_tmp, w_vec); 531 | let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec); 532 | let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus); 533 | 534 | _mm256_store_si256(p_x as *mut __m256i, new_x); 535 | _mm256_store_si256(p_y as *mut __m256i, new_y); 536 | } 537 | } 538 | } 539 | } 540 | } 541 | 542 | // for i in 0..n { 543 | // operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus; 544 | // operand[i] -= ((operand[i] >= modulus) as u64) * modulus; 545 | // } 546 | 547 | for i in (0..n).step_by(4) { 548 | unsafe { 549 | let p_x = &mut operand[i] as *mut u64; 550 | 551 | let cmp_val1 = _mm256_set1_epi64x(two_times_modulus as i64); 552 | let mut x = _mm256_load_si256(p_x as *const __m256i); 553 | let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1); 554 | let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1); 555 | x = _mm256_sub_epi64(x, to_subtract); 556 | 557 | let cmp_val2 
= _mm256_set1_epi64x(modulus as i64); 558 | gt_mask = _mm256_cmpgt_epi64(x, cmp_val2); 559 | to_subtract = _mm256_and_si256(gt_mask, cmp_val2); 560 | x = _mm256_sub_epi64(x, to_subtract); 561 | _mm256_store_si256(p_x as *mut __m256i, x); 562 | } 563 | } 564 | } 565 | } 566 | 567 | pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) { 568 | if params.crt_count == 1 { 569 | ntt_inverse_alt(params, operand_overall); 570 | return; 571 | } 572 | for coeff_mod in 0..params.crt_count { 573 | let n = params.poly_len; 574 | 575 | let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n]; 576 | 577 | let inverse_table = params.get_ntt_inverse_table(coeff_mod); 578 | let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod); 579 | let modulus = params.moduli[coeff_mod]; 580 | let two_times_modulus: u64 = 2 * modulus; 581 | for mm in (0..params.poly_len_log2).rev() { 582 | let h = 1 << mm; 583 | let t = n >> (mm + 1); 584 | 585 | let mut it = operand.chunks_exact_mut(2 * t); 586 | 587 | for i in 0..h { 588 | let w = inverse_table[h + i]; 589 | let w_prime = inverse_table_prime[h + i]; 590 | 591 | let op = it.next().unwrap(); 592 | 593 | if t < 4 { 594 | for j in 0..t { 595 | let x = op[j]; 596 | let y = op[t + j]; 597 | 598 | let t_tmp = two_times_modulus - y + x; 599 | let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64)); 600 | let h_tmp = (t_tmp * w_prime) >> 32; 601 | 602 | let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1; 603 | let res_y = w * t_tmp - h_tmp * modulus; 604 | 605 | op[j] = res_x; 606 | op[t + j] = res_y; 607 | } 608 | } else if t < 8 { 609 | unsafe { 610 | for j in (0..t).step_by(4) { 611 | // Use AVX2 here 612 | let p_x = &mut op[j] as *mut u64; 613 | let p_y = &mut op[j + t] as *mut u64; 614 | let x = _mm256_load_si256(p_x as *const __m256i); 615 | let y = _mm256_load_si256(p_y as *const __m256i); 616 | 617 | let modulus_vec = _mm256_set1_epi64x(modulus as i64); 618 | let two_times_modulus_vec = 619 | _mm256_set1_epi64x(two_times_modulus as i64); 620 | let mut t_tmp = _mm256_set1_epi64x(two_times_modulus as i64); 621 | t_tmp = _mm256_sub_epi64(t_tmp, y); 622 | t_tmp = _mm256_add_epi64(t_tmp, x); 623 | let gt_mask = _mm256_cmpgt_epi64(_mm256_slli_epi64(x, 1), t_tmp); 624 | let to_subtract = _mm256_and_si256(gt_mask, two_times_modulus_vec); 625 | let mut curr_x = _mm256_add_epi64(x, y); 626 | curr_x = _mm256_sub_epi64(curr_x, to_subtract); 627 | 628 | let w_prime_vec = _mm256_set1_epi64x(w_prime as i64); 629 | let mut h_tmp = _mm256_mul_epu32(t_tmp, w_prime_vec); 630 | h_tmp = _mm256_srli_epi64(h_tmp, 32); 631 | 632 | let and_mask = _mm256_set_epi64x(1, 1, 1, 1); 633 | let eq_mask = 634 | _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, and_mask), and_mask); 635 | let to_add = _mm256_and_si256(eq_mask, modulus_vec); 636 | 637 | let new_x = _mm256_srli_epi64(_mm256_add_epi64(curr_x, to_add), 1); 638 | 639 | let w_vec = _mm256_set1_epi64x(w as i64); 640 | let w_times_t_tmp = _mm256_mul_epu32(t_tmp, w_vec); 641 | let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec); 642 | let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus); 643 | 644 | _mm256_store_si256(p_x as *mut __m256i, new_x); 645 | _mm256_store_si256(p_y as *mut __m256i, new_y); 646 | } 647 | } 648 | } else { 649 | unsafe { 650 | for j in (0..t).step_by(8) { 651 | // Use AVX2 here 652 | let p_x = &mut op[j] as *mut u64; 653 | let p_y = &mut op[j + t] as *mut u64; 654 | let x = _mm512_load_si512(p_x as *const _); 655 | let y = 
_mm512_load_si512(p_y as *const _); 656 | 657 | let modulus_vec = _mm512_set1_epi64(modulus as i64); 658 | let two_times_modulus_vec = _mm512_set1_epi64(two_times_modulus as i64); 659 | let mut t_tmp = _mm512_set1_epi64(two_times_modulus as i64); 660 | t_tmp = _mm512_sub_epi64(t_tmp, y); 661 | t_tmp = _mm512_add_epi64(t_tmp, x); 662 | // let gt_mask = _mm512_cmpgt_epi64(_mm512_slli_epi64(x, 1), t_tmp); 663 | let gt_mask = _mm512_cmpgt_epu64_mask(_mm512_slli_epi64(x, 1), t_tmp); 664 | // let to_subtract = _mm512_and_si512(gt_mask, two_times_modulus_vec); 665 | let mut curr_x = _mm512_add_epi64(x, y); 666 | curr_x = _mm512_mask_sub_epi64( 667 | curr_x, 668 | gt_mask, 669 | curr_x, 670 | two_times_modulus_vec, 671 | ); 672 | 673 | let w_prime_vec = _mm512_set1_epi64(w_prime as i64); 674 | let mut h_tmp = _mm512_mul_epu32(t_tmp, w_prime_vec); 675 | h_tmp = _mm512_srli_epi64(h_tmp, 32); 676 | 677 | let and_mask = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1); 678 | let eq_mask = _mm512_cmpeq_epi64_mask( 679 | _mm512_and_si512(t_tmp, and_mask), 680 | and_mask, 681 | ); 682 | // let to_add = _mm512_and_si512(eq_mask, modulus_vec); 683 | 684 | let new_x = _mm512_srli_epi64( 685 | _mm512_mask_add_epi64(curr_x, eq_mask, curr_x, modulus_vec), 686 | 1, 687 | ); 688 | 689 | let w_vec = _mm512_set1_epi64(w as i64); 690 | let w_times_t_tmp = _mm512_mul_epu32(t_tmp, w_vec); 691 | let h_tmp_times_modulus = _mm512_mul_epu32(h_tmp, modulus_vec); 692 | let new_y = _mm512_sub_epi64(w_times_t_tmp, h_tmp_times_modulus); 693 | 694 | _mm512_store_si512(p_x as *mut _, new_x); 695 | _mm512_store_si512(p_y as *mut _, new_y); 696 | } 697 | } 698 | } 699 | } 700 | } 701 | 702 | for i in (0..n).step_by(8) { 703 | unsafe { 704 | let p_x = &mut operand[i] as *mut u64; 705 | 706 | let cmp_val1 = _mm512_set1_epi64(two_times_modulus as i64); 707 | let mut x = _mm512_load_si512(p_x as *const _); 708 | let mut gt_mask = _mm512_cmpgt_epu64_mask(x, cmp_val1); 709 | // let mut to_subtract = _mm512_and_si512(gt_mask, cmp_val1); 710 | x = _mm512_mask_sub_epi64(x, gt_mask, x, cmp_val1); 711 | 712 | let cmp_val2 = _mm512_set1_epi64(modulus as i64); 713 | gt_mask = _mm512_cmpgt_epu64_mask(x, cmp_val2); 714 | // to_subtract = _mm512_and_si512(gt_mask, cmp_val2); 715 | x = _mm512_mask_sub_epi64(x, gt_mask, x, cmp_val2); 716 | _mm512_store_si512(p_x as *mut _, x); 717 | } 718 | } 719 | } 720 | } 721 | 722 | #[cfg(test)] 723 | mod test { 724 | use std::time::Instant; 725 | 726 | use super::*; 727 | use crate::{aligned_memory::AlignedMemory64, util::*}; 728 | use rand::Rng; 729 | 730 | fn get_params() -> Params { 731 | get_test_params() 732 | } 733 | 734 | const REF_VAL: u64 = 519370102; 735 | 736 | #[test] 737 | fn build_ntt_tables_correct() { 738 | let moduli = [268369921u64, 249561089u64]; 739 | let poly_len = 2048usize; 740 | let res = build_ntt_tables(poly_len, moduli.as_slice(), None); 741 | assert_eq!(res.len(), 2); 742 | assert_eq!(res[0].len(), 4); 743 | assert_eq!(res[0][0].len(), poly_len); 744 | assert_eq!(res[0][2][0], 134184961u64); 745 | assert_eq!(res[0][2][1], 96647580u64); 746 | let mut x1 = 0u64; 747 | for i in 0..res.len() { 748 | for j in 0..res[0].len() { 749 | for k in 0..res[0][0].len() { 750 | x1 ^= res[i][j][k]; 751 | } 752 | } 753 | } 754 | assert_eq!(x1, REF_VAL); 755 | } 756 | 757 | #[test] 758 | fn ntt_forward_correct() { 759 | let params = get_params(); 760 | let mut v1 = AlignedMemory64::new(2 * 2048); 761 | v1[0] = 100; 762 | v1[2048] = 100; 763 | ntt_forward(¶ms, v1.as_mut_slice()); 764 | assert_eq!(v1[50], 
100); 765 | assert_eq!(v1[2048 + 50], 100); 766 | } 767 | 768 | #[test] 769 | fn ntt_inverse_correct() { 770 | let params = get_params(); 771 | let mut v1 = AlignedMemory64::new(2 * 2048); 772 | for i in 0..v1.len() { 773 | v1[i] = 100; 774 | } 775 | ntt_inverse(¶ms, v1.as_mut_slice()); 776 | assert_eq!(v1[0], 100); 777 | assert_eq!(v1[2048], 100); 778 | assert_eq!(v1[50], 0); 779 | assert_eq!(v1[2048 + 50], 0); 780 | } 781 | 782 | #[test] 783 | fn ntt_correct() { 784 | let params = get_params(); 785 | let trials = 1000; 786 | let mut v1 = AlignedMemory64::new(trials * params.crt_count * params.poly_len); 787 | let mut rng = rand::thread_rng(); 788 | for trial in 0..trials { 789 | for i in 0..params.crt_count { 790 | for j in 0..params.poly_len { 791 | let idx = 792 | calc_index(&[trial, i, j], &[trials, params.crt_count, params.poly_len]); 793 | let val: u64 = rng.gen(); 794 | v1[idx] = val % params.moduli[i]; 795 | } 796 | } 797 | } 798 | let mut v2 = v1.clone(); 799 | for chunk in v2 800 | .as_mut_slice() 801 | .chunks_exact_mut(params.crt_count * params.poly_len) 802 | { 803 | ntt_forward(¶ms, chunk); 804 | } 805 | 806 | let now = Instant::now(); 807 | for chunk in v2 808 | .as_mut_slice() 809 | .chunks_exact_mut(params.crt_count * params.poly_len) 810 | { 811 | ntt_inverse(¶ms, chunk); 812 | } 813 | println!("ntt 512 taken: {:?}", now.elapsed()); 814 | 815 | // for i in 0..params.crt_count * params.poly_len { 816 | // assert_eq!(v1[i], v2[i]); 817 | // } 818 | 819 | // let now = Instant::now(); 820 | // ntt_inverse_256(¶ms, v2.as_mut_slice()); 821 | // println!("ntt 256 taken: {:?}", now.elapsed()); 822 | 823 | let now = Instant::now(); 824 | for chunk in v2 825 | .as_mut_slice() 826 | .chunks_exact_mut(params.crt_count * params.poly_len) 827 | { 828 | ntt_forward(¶ms, chunk); 829 | } 830 | println!("ntt for taken: {:?}", now.elapsed()); 831 | 832 | let now = Instant::now(); 833 | for chunk in v2 834 | .as_mut_slice() 835 | .chunks_exact_mut(params.crt_count * params.poly_len) 836 | { 837 | ntt_inverse_256(¶ms, chunk); 838 | } 839 | println!("ntt 256 taken: {:?}", now.elapsed()); 840 | 841 | // let now = Instant::now(); 842 | // ntt_inverse_256(¶ms, v2.as_mut_slice()); 843 | // println!("ntt 256 taken: {:?}", now.elapsed()); 844 | // for i in 0..params.crt_count * params.poly_len { 845 | // assert_eq!(v1[i], v2[i]); 846 | // } 847 | } 848 | 849 | fn get_alt_params() -> Params { 850 | Params::init( 851 | 2048, 852 | &vec![180143985094819841u64], 853 | 6.4, 854 | 2, 855 | 256, 856 | 20, 857 | 4, 858 | 8, 859 | 56, 860 | 8, 861 | true, 862 | 9, 863 | 6, 864 | 1, 865 | 2048, 866 | 0, 867 | ) 868 | } 869 | 870 | #[test] 871 | fn alt_ntt_correct() { 872 | let params = get_alt_params(); 873 | let mut v1 = AlignedMemory64::new(params.crt_count * params.poly_len); 874 | let mut rng = rand::thread_rng(); 875 | for i in 0..params.crt_count { 876 | for j in 0..params.poly_len { 877 | let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]); 878 | let val: u64 = rng.gen(); 879 | v1[idx] = val % params.moduli[i]; 880 | } 881 | } 882 | let mut v2 = v1.clone(); 883 | ntt_forward(¶ms, v2.as_mut_slice()); 884 | ntt_inverse(¶ms, v2.as_mut_slice()); 885 | for i in 0..params.crt_count * params.poly_len { 886 | assert_eq!(v1[i], v2[i]); 887 | } 888 | } 889 | 890 | #[test] 891 | fn calc_index_correct() { 892 | assert_eq!(calc_index(&[2, 3, 4], &[10, 10, 100]), 2304); 893 | assert_eq!(calc_index(&[2, 3, 4], &[3, 5, 7]), 95); 894 | } 895 | } 896 | 
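
The forward/inverse pair above works in place on a flat slice of `crt_count * poly_len` coefficients; as the `ntt_forward_correct` and `ntt_inverse_correct` tests indicate, the forward direction applies no scaling and the 1/n factor is folded into the inverse tables. Negacyclic multiplication in Z_q[X]/(X^n + 1) is therefore: forward-transform both operands, multiply residue-wise under each CRT modulus, and inverse-transform the product. A test-style sketch of that flow follows, reusing `get_test_params`, `AlignedMemory64`, and `multiply_uint_mod` exactly as the tests in this file do; the test name and the tiny `(1 + X) * 2` example are mine, not the crate's.

```rust
// Sketch of negacyclic multiplication via the routines above, in the style of
// this file's test module; not part of the crate itself.
use crate::{
    aligned_memory::AlignedMemory64,
    arith::multiply_uint_mod,
    ntt::{ntt_forward, ntt_inverse},
    util::get_test_params,
};

#[test]
fn negacyclic_product_sketch() {
    let params = get_test_params();
    let n = params.poly_len;

    // a(X) = 1 + X and b(X) = 2, replicated across every CRT residue.
    let mut a = AlignedMemory64::new(params.crt_count * n);
    let mut b = AlignedMemory64::new(params.crt_count * n);
    for c in 0..params.crt_count {
        a[c * n] = 1;
        a[c * n + 1] = 1;
        b[c * n] = 2;
    }

    ntt_forward(&params, a.as_mut_slice());
    ntt_forward(&params, b.as_mut_slice());

    // Coefficient-wise product under each CRT modulus; forward outputs are
    // already reduced below the corresponding modulus.
    let mut prod = AlignedMemory64::new(params.crt_count * n);
    for c in 0..params.crt_count {
        let q = params.moduli[c];
        for i in 0..n {
            prod[c * n + i] = multiply_uint_mod(a[c * n + i], b[c * n + i], q);
        }
    }

    // The inverse transform includes the 1/n factor, so this recovers the
    // product 2 + 2X directly, as residues modulo each CRT modulus.
    ntt_inverse(&params, prod.as_mut_slice());
    for c in 0..params.crt_count {
        assert_eq!(prod[c * n], 2);
        assert_eq!(prod[c * n + 1], 2);
        assert_eq!(prod[c * n + 2], 0);
    }
}
```

This is the same pattern the higher-level `PolyMatrixNTT` multiplication relies on: keep operands in the NTT domain, do cheap pointwise arithmetic there, and only invert when coefficients are needed.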
-------------------------------------------------------------------------------- /src/number_theory.rs: -------------------------------------------------------------------------------- 1 | use crate::arith::*; 2 | use rand::Rng; 3 | 4 | const ATTEMPT_MAX: usize = 100; 5 | 6 | pub fn is_primitive_root(root: u64, degree: u64, modulus: u64) -> bool { 7 | if root == 0 { 8 | return false; 9 | } 10 | 11 | exponentiate_uint_mod(root, degree >> 1, modulus) == modulus - 1 12 | } 13 | 14 | pub fn get_primitive_root(degree: u64, modulus: u64) -> Option { 15 | assert!(modulus > 1); 16 | assert!(degree >= 2); 17 | let size_entire_group = modulus - 1; 18 | let size_quotient_group = size_entire_group / degree; 19 | if size_entire_group - size_quotient_group * degree != 0 { 20 | return None; 21 | } 22 | 23 | let mut root = 0u64; 24 | for trial in 0..ATTEMPT_MAX { 25 | let mut rng = rand::thread_rng(); 26 | let r1: u64 = rng.gen(); 27 | let r2: u64 = rng.gen(); 28 | let r3 = ((r1 << 32) | r2) % modulus; 29 | root = exponentiate_uint_mod(r3, size_quotient_group, modulus); 30 | if is_primitive_root(root, degree, modulus) { 31 | break; 32 | } 33 | if trial == ATTEMPT_MAX - 1 { 34 | return None; 35 | } 36 | } 37 | 38 | Some(root) 39 | } 40 | 41 | pub fn get_minimal_primitive_root(degree: u64, modulus: u64) -> Option { 42 | let mut root = get_primitive_root(degree, modulus)?; 43 | let generator_sq = multiply_uint_mod(root, root, modulus); 44 | let mut current_generator = root; 45 | 46 | for _ in 0..degree { 47 | if current_generator < root { 48 | root = current_generator; 49 | } 50 | 51 | current_generator = multiply_uint_mod(current_generator, generator_sq, modulus); 52 | } 53 | 54 | Some(root) 55 | } 56 | 57 | pub fn extended_gcd(mut x: u64, mut y: u64) -> (u64, i64, i64) { 58 | assert!(x != 0); 59 | assert!(y != 0); 60 | 61 | let mut prev_a = 1; 62 | let mut a = 0; 63 | let mut prev_b = 0; 64 | let mut b = 1; 65 | 66 | while y != 0 { 67 | let q: i64 = (x / y) as i64; 68 | let mut temp = (x % y) as i64; 69 | x = y; 70 | y = temp as u64; 71 | 72 | temp = a; 73 | a = prev_a - (q * a); 74 | prev_a = temp; 75 | 76 | temp = b; 77 | b = prev_b - (q * b); 78 | prev_b = temp; 79 | } 80 | 81 | (x, prev_a, prev_b) 82 | } 83 | 84 | pub fn invert_uint_mod(value: u64, modulus: u64) -> Option { 85 | if value == 0 { 86 | return None; 87 | } 88 | let gcd_tuple = extended_gcd(value, modulus); 89 | if gcd_tuple.0 != 1 { 90 | return None; 91 | } else if gcd_tuple.1 < 0 { 92 | return Some((gcd_tuple.1 as u64).overflowing_add(modulus).0); 93 | } else { 94 | return Some(gcd_tuple.1 as u64); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/params.rs: -------------------------------------------------------------------------------- 1 | use std::mem::size_of; 2 | 3 | use crate::{arith::*, client::SEED_LENGTH, ntt::*, number_theory::*, poly::*}; 4 | 5 | pub const MAX_MODULI: usize = 4; 6 | 7 | pub static MIN_Q2_BITS: u64 = 14; 8 | pub static Q2_VALUES: [u64; 37] = [ 9 | 0, 10 | 0, 11 | 0, 12 | 0, 13 | 0, 14 | 0, 15 | 0, 16 | 0, 17 | 0, 18 | 0, 19 | 0, 20 | 0, 21 | 0, 22 | 0, 23 | 12289, 24 | 12289, 25 | 61441, 26 | 65537, 27 | 65537, 28 | 520193, 29 | 786433, 30 | 786433, 31 | 3604481, 32 | 7340033, 33 | 16515073, 34 | 33292289, 35 | 67043329, 36 | 132120577, 37 | 268369921, 38 | 469762049, 39 | 1073479681, 40 | 2013265921, 41 | 4293918721, 42 | 8588886017, 43 | 17175674881, 44 | 34359214081, 45 | 68718428161, 46 | ]; 47 | 48 | #[derive(Debug, PartialEq, Clone)] 49 | pub struct 
Params { 50 | pub poly_len: usize, 51 | pub poly_len_log2: usize, 52 | pub ntt_tables: Vec>>, 53 | pub scratch: Vec, 54 | 55 | pub crt_count: usize, 56 | pub barrett_cr_0: [u64; MAX_MODULI], 57 | pub barrett_cr_1: [u64; MAX_MODULI], 58 | pub barrett_cr_0_modulus: u64, 59 | pub barrett_cr_1_modulus: u64, 60 | pub mod0_inv_mod1: u64, 61 | pub mod1_inv_mod0: u64, 62 | pub moduli: [u64; MAX_MODULI], 63 | pub modulus: u64, 64 | pub modulus_log2: u64, 65 | pub noise_width: f64, 66 | 67 | pub n: usize, 68 | pub pt_modulus: u64, 69 | pub q2_bits: u64, 70 | pub t_conv: usize, 71 | pub t_exp_left: usize, 72 | pub t_exp_right: usize, 73 | pub t_gsw: usize, 74 | 75 | pub expand_queries: bool, 76 | pub db_dim_1: usize, 77 | pub db_dim_2: usize, 78 | pub instances: usize, 79 | pub db_item_size: usize, 80 | 81 | pub version: usize, 82 | } 83 | 84 | impl Params { 85 | pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] { 86 | self.ntt_tables[i][0].as_slice() 87 | } 88 | pub fn get_ntt_forward_prime_table(&self, i: usize) -> &[u64] { 89 | self.ntt_tables[i][1].as_slice() 90 | } 91 | pub fn get_ntt_inverse_table(&self, i: usize) -> &[u64] { 92 | self.ntt_tables[i][2].as_slice() 93 | } 94 | pub fn get_ntt_inverse_prime_table(&self, i: usize) -> &[u64] { 95 | self.ntt_tables[i][3].as_slice() 96 | } 97 | 98 | pub fn get_v_neg1(&self) -> Vec { 99 | let mut v_neg1 = Vec::new(); 100 | for i in 0..self.poly_len_log2 { 101 | let idx = self.poly_len - (1 << i); 102 | let mut ng1 = PolyMatrixRaw::zero(&self, 1, 1); 103 | ng1.data[idx] = 1; 104 | v_neg1.push((-&ng1).ntt()); 105 | } 106 | v_neg1 107 | } 108 | 109 | pub fn get_sk_gsw(&self) -> (usize, usize) { 110 | (self.n, 1) 111 | } 112 | pub fn get_sk_reg(&self) -> (usize, usize) { 113 | (1, 1) 114 | } 115 | 116 | pub fn num_expanded(&self) -> usize { 117 | 1 << self.db_dim_1 118 | } 119 | 120 | pub fn num_items(&self) -> usize { 121 | (1 << self.db_dim_1) * (1 << self.db_dim_2) 122 | } 123 | 124 | pub fn item_size(&self) -> usize { 125 | let logp = log2(self.pt_modulus) as usize; 126 | self.instances * self.n * self.n * self.poly_len * logp / 8 127 | } 128 | 129 | pub fn g(&self) -> usize { 130 | let num_bits_to_gen = self.t_gsw * self.db_dim_2 + self.num_expanded(); 131 | log2_ceil_usize(num_bits_to_gen) 132 | } 133 | 134 | pub fn stop_round(&self) -> usize { 135 | log2_ceil_usize(self.t_gsw * self.db_dim_2) 136 | } 137 | 138 | pub fn factor_on_first_dim(&self) -> usize { 139 | if self.db_dim_2 == 0 { 140 | 1 141 | } else { 142 | 2 143 | } 144 | } 145 | 146 | pub fn setup_bytes(&self) -> usize { 147 | let mut sz_polys = 0; 148 | 149 | let num_packing_mats = if self.version == 0 { self.n } else { 2 }; 150 | let packing_sz = ((self.n + 1) - 1) * self.t_conv; 151 | sz_polys += num_packing_mats * packing_sz; 152 | 153 | if self.expand_queries { 154 | let expansion_left_sz = self.g() * self.t_exp_left; 155 | let mut expansion_right_sz = (self.stop_round() + 1) * self.t_exp_right; 156 | let conversion_sz = 2 * self.t_conv; 157 | 158 | if self.version > 0 && self.t_exp_left == self.t_exp_right { 159 | expansion_right_sz = 0; 160 | } 161 | 162 | sz_polys += expansion_left_sz + expansion_right_sz + conversion_sz; 163 | } 164 | 165 | let sz_bytes = sz_polys * self.poly_len * size_of::(); 166 | SEED_LENGTH + sz_bytes 167 | } 168 | 169 | pub fn query_bytes(&self) -> usize { 170 | let sz_polys; 171 | 172 | if self.expand_queries { 173 | sz_polys = 1; 174 | } else { 175 | let first_dimension_sz = self.num_expanded(); 176 | let further_dimension_sz = self.db_dim_2 * (2 * 
self.t_gsw); 177 | sz_polys = first_dimension_sz + further_dimension_sz; 178 | } 179 | 180 | let sz_bytes = sz_polys * self.poly_len * size_of::(); 181 | SEED_LENGTH + sz_bytes 182 | } 183 | 184 | pub fn query_v_buf_bytes(&self) -> usize { 185 | self.num_expanded() * self.poly_len * size_of::() 186 | } 187 | 188 | pub fn bytes_per_chunk(&self) -> usize { 189 | let trials = self.n * self.n; 190 | let chunks = self.instances * trials; 191 | let bytes_per_chunk = f64::ceil(self.db_item_size as f64 / chunks as f64) as usize; 192 | bytes_per_chunk 193 | } 194 | 195 | pub fn modp_words_per_chunk(&self) -> usize { 196 | let bytes_per_chunk = self.bytes_per_chunk(); 197 | let logp = log2(self.pt_modulus); 198 | let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize; 199 | modp_words_per_chunk 200 | } 201 | 202 | pub fn crt_compose_1(&self, x: u64) -> u64 { 203 | assert_eq!(self.crt_count, 1); 204 | x 205 | } 206 | 207 | pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 { 208 | assert_eq!(self.crt_count, 2); 209 | // assert!(self.moduli[0] > self.moduli[1]); 210 | // n m 211 | 212 | let mut val = (x as u128) * (self.mod1_inv_mod0 as u128); 213 | val += (y as u128) * (self.mod0_inv_mod1 as u128); 214 | 215 | // let mut val = y as u128; 216 | // val += self.mod1_inv_mod0 as u128 * (x + self.moduli[0] - y) as u128; 217 | 218 | barrett_reduction_u128(self, val) 219 | } 220 | 221 | pub fn crt_compose(&self, a: &[u64], idx: usize) -> u64 { 222 | if self.crt_count == 1 { 223 | self.crt_compose_1(a[idx]) 224 | } else { 225 | self.crt_compose_2(a[idx], a[idx + self.poly_len]) 226 | } 227 | } 228 | 229 | pub fn init( 230 | poly_len: usize, 231 | moduli: &[u64], 232 | noise_width: f64, 233 | n: usize, 234 | pt_modulus: u64, 235 | q2_bits: u64, 236 | t_conv: usize, 237 | t_exp_left: usize, 238 | t_exp_right: usize, 239 | t_gsw: usize, 240 | expand_queries: bool, 241 | db_dim_1: usize, 242 | db_dim_2: usize, 243 | instances: usize, 244 | db_item_size: usize, 245 | version: usize, 246 | ) -> Self { 247 | assert!(q2_bits >= MIN_Q2_BITS); 248 | 249 | let poly_len_log2 = log2(poly_len as u64) as usize; 250 | let crt_count = moduli.len(); 251 | assert!(crt_count <= MAX_MODULI); 252 | let mut moduli_array = [0; MAX_MODULI]; 253 | for i in 0..crt_count { 254 | moduli_array[i] = moduli[i]; 255 | } 256 | let ntt_tables = if crt_count > 1 { 257 | build_ntt_tables(poly_len, moduli, None) 258 | } else { 259 | build_ntt_tables_alt(poly_len, moduli, None) 260 | }; 261 | let scratch = vec![0u64; crt_count * poly_len]; 262 | let mut modulus = 1; 263 | for m in moduli { 264 | modulus *= m; 265 | } 266 | let modulus_log2 = log2_ceil(modulus); 267 | let (barrett_cr_0, barrett_cr_1) = get_barrett(moduli); 268 | let (barrett_cr_0_modulus, barrett_cr_1_modulus) = get_barrett_crs(modulus); 269 | let mut mod0_inv_mod1 = 0; 270 | let mut mod1_inv_mod0 = 0; 271 | if crt_count == 2 { 272 | mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap(); 273 | mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap(); 274 | } 275 | Self { 276 | poly_len, 277 | poly_len_log2, 278 | ntt_tables, 279 | scratch, 280 | crt_count, 281 | barrett_cr_0, 282 | barrett_cr_1, 283 | barrett_cr_0_modulus, 284 | barrett_cr_1_modulus, 285 | mod0_inv_mod1, 286 | mod1_inv_mod0, 287 | moduli: moduli_array, 288 | modulus, 289 | modulus_log2, 290 | noise_width, 291 | n, 292 | pt_modulus, 293 | q2_bits, 294 | t_conv, 295 | t_exp_left, 296 | t_exp_right, 297 | t_gsw, 298 | expand_queries, 299 | 
db_dim_1, 300 | db_dim_2, 301 | instances, 302 | db_item_size, 303 | version, 304 | } 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /src/poly.rs: -------------------------------------------------------------------------------- 1 | #[cfg(target_feature = "avx2")] 2 | use std::arch::x86_64::*; 3 | 4 | use rand::distributions::Standard; 5 | use rand::Rng; 6 | use rand_chacha::ChaCha20Rng; 7 | use std::cell::RefCell; 8 | use std::ops::{Add, Mul, Neg}; 9 | 10 | use crate::{aligned_memory::*, arith::*, discrete_gaussian::*, ntt::*, params::*, util::*}; 11 | 12 | const SCRATCH_SPACE: usize = 8192; 13 | thread_local!(static SCRATCH: RefCell = RefCell::new(AlignedMemory64::new(SCRATCH_SPACE))); 14 | 15 | pub trait PolyMatrix<'a> { 16 | fn is_ntt(&self) -> bool; 17 | fn get_rows(&self) -> usize; 18 | fn get_cols(&self) -> usize; 19 | fn get_params(&self) -> &Params; 20 | fn num_words(&self) -> usize; 21 | fn zero(params: &'a Params, rows: usize, cols: usize) -> Self; 22 | fn random(params: &'a Params, rows: usize, cols: usize) -> Self; 23 | fn random_rng(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self; 24 | fn as_slice(&self) -> &[u64]; 25 | fn as_mut_slice(&mut self) -> &mut [u64]; 26 | fn zero_out(&mut self) { 27 | for item in self.as_mut_slice() { 28 | *item = 0; 29 | } 30 | } 31 | fn get_poly(&self, row: usize, col: usize) -> &[u64] { 32 | let num_words = self.num_words(); 33 | let start = (row * self.get_cols() + col) * num_words; 34 | // &self.as_slice()[start..start + num_words] 35 | unsafe { self.as_slice().get_unchecked(start..start + num_words) } 36 | // unsafe { 37 | // let num_words = self.num_words(); 38 | // let ptr = self 39 | // .as_slice() 40 | // .as_ptr() 41 | // .add((row * self.get_cols() + col) * num_words); 42 | // std::slice::from_raw_parts(ptr, num_words) 43 | // } 44 | } 45 | fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] { 46 | let num_words = self.num_words(); 47 | let start = (row * self.get_cols() + col) * num_words; 48 | // &mut self.as_mut_slice()[start..start + num_words] 49 | unsafe { 50 | self.as_mut_slice() 51 | .get_unchecked_mut(start..start + num_words) 52 | } 53 | // unsafe { 54 | // self.as_mut_slice().get_unchecked_mut(index) 55 | // let num_words = self.num_words(); 56 | // let ptr = self 57 | // .as_mut_slice() 58 | // .as_mut_ptr() 59 | // .add((row * self.get_cols() + col) * num_words); 60 | // std::slice::from_raw_parts_mut(ptr, num_words) 61 | // } 62 | } 63 | fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) { 64 | assert!(target_row < self.get_rows()); 65 | assert!(target_col < self.get_cols()); 66 | assert!(target_row + p.get_rows() <= self.get_rows()); 67 | assert!(target_col + p.get_cols() <= self.get_cols()); 68 | for r in 0..p.get_rows() { 69 | for c in 0..p.get_cols() { 70 | let pol_src = p.get_poly(r, c); 71 | let pol_dst = self.get_poly_mut(target_row + r, target_col + c); 72 | pol_dst.copy_from_slice(pol_src); 73 | } 74 | } 75 | } 76 | 77 | fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self; 78 | fn pad_top(&self, pad_rows: usize) -> Self; 79 | } 80 | 81 | pub struct PolyMatrixRaw<'a> { 82 | pub params: &'a Params, 83 | pub rows: usize, 84 | pub cols: usize, 85 | pub data: AlignedMemory64, 86 | } 87 | 88 | pub struct PolyMatrixNTT<'a> { 89 | pub params: &'a Params, 90 | pub rows: usize, 91 | pub cols: usize, 92 | pub data: AlignedMemory64, 93 | } 94 | 95 | impl<'a> PolyMatrix<'a> for 
PolyMatrixRaw<'a> { 96 | fn is_ntt(&self) -> bool { 97 | false 98 | } 99 | fn get_rows(&self) -> usize { 100 | self.rows 101 | } 102 | fn get_cols(&self) -> usize { 103 | self.cols 104 | } 105 | fn get_params(&self) -> &Params { 106 | &self.params 107 | } 108 | fn as_slice(&self) -> &[u64] { 109 | self.data.as_slice() 110 | } 111 | fn as_mut_slice(&mut self) -> &mut [u64] { 112 | self.data.as_mut_slice() 113 | } 114 | fn num_words(&self) -> usize { 115 | self.params.poly_len 116 | } 117 | fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> { 118 | let num_coeffs = rows * cols * params.poly_len; 119 | let data = AlignedMemory64::new(num_coeffs); 120 | PolyMatrixRaw { 121 | params, 122 | rows, 123 | cols, 124 | data, 125 | } 126 | } 127 | fn random_rng(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self { 128 | let mut iter = rng.sample_iter(&Standard); 129 | let mut out = PolyMatrixRaw::zero(params, rows, cols); 130 | for r in 0..rows { 131 | for c in 0..cols { 132 | for i in 0..params.poly_len { 133 | let val: u64 = iter.next().unwrap(); 134 | out.get_poly_mut(r, c)[i] = val % params.modulus; 135 | } 136 | } 137 | } 138 | out 139 | } 140 | fn random(params: &'a Params, rows: usize, cols: usize) -> Self { 141 | let mut rng = rand::thread_rng(); 142 | Self::random_rng(params, rows, cols, &mut rng) 143 | } 144 | fn pad_top(&self, pad_rows: usize) -> Self { 145 | let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols); 146 | padded.copy_into(&self, pad_rows, 0); 147 | padded 148 | } 149 | fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self { 150 | let mut m = Self::zero(self.params, rows, cols); 151 | assert!(target_row < self.rows); 152 | assert!(target_col < self.cols); 153 | assert!(target_row + rows <= self.rows); 154 | assert!(target_col + cols <= self.cols); 155 | for r in 0..rows { 156 | for c in 0..cols { 157 | let pol_src = self.get_poly(target_row + r, target_col + c); 158 | let pol_dst = m.get_poly_mut(r, c); 159 | pol_dst.copy_from_slice(pol_src); 160 | } 161 | } 162 | m 163 | } 164 | } 165 | 166 | impl<'a> Clone for PolyMatrixRaw<'a> { 167 | fn clone(&self) -> Self { 168 | let mut data_clone = AlignedMemory64::new(self.data.len()); 169 | data_clone 170 | .as_mut_slice() 171 | .copy_from_slice(self.data.as_slice()); 172 | PolyMatrixRaw { 173 | params: self.params, 174 | rows: self.rows, 175 | cols: self.cols, 176 | data: data_clone, 177 | } 178 | } 179 | } 180 | 181 | impl<'a> PolyMatrixRaw<'a> { 182 | pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> { 183 | let num_coeffs = rows * cols * params.poly_len; 184 | let mut data = AlignedMemory::new(num_coeffs); 185 | for r in 0..rows { 186 | let c = r; 187 | let idx = r * cols * params.poly_len + c * params.poly_len; 188 | data[idx] = 1; 189 | } 190 | PolyMatrixRaw { 191 | params, 192 | rows, 193 | cols, 194 | data, 195 | } 196 | } 197 | 198 | pub fn noise( 199 | params: &'a Params, 200 | rows: usize, 201 | cols: usize, 202 | dg: &DiscreteGaussian, 203 | rng: &mut ChaCha20Rng, 204 | ) -> Self { 205 | let mut out = PolyMatrixRaw::zero(params, rows, cols); 206 | dg.sample_matrix(&mut out, rng); 207 | out 208 | } 209 | 210 | pub fn fast_noise( 211 | params: &'a Params, 212 | rows: usize, 213 | cols: usize, 214 | dg: &DiscreteGaussian, 215 | rng: &mut ChaCha20Rng, 216 | ) -> Self { 217 | let mut out = PolyMatrixRaw::zero(params, rows, cols); 218 | let modulus = params.modulus; 219 | for r in 0..out.rows { 220 | 
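// Fill every coefficient in place with dg.fast_sample (in contrast to noise()
// above, which goes through DiscreteGaussian::sample_matrix); the sample is
// stored as-is, so fast_sample is presumably expected to return a value
// already reduced mod `modulus`.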
for c in 0..out.cols { 221 | let poly = out.get_poly_mut(r, c); 222 | for z in 0..poly.len() { 223 | let s = dg.fast_sample(modulus, rng); 224 | poly[z] = s; 225 | } 226 | } 227 | } 228 | out 229 | } 230 | 231 | pub fn ntt(&self) -> PolyMatrixNTT<'a> { 232 | to_ntt_alloc(&self) 233 | } 234 | 235 | pub fn reduce_mod(&mut self, modulus: u64) { 236 | for r in 0..self.rows { 237 | for c in 0..self.cols { 238 | for z in 0..self.params.poly_len { 239 | self.get_poly_mut(r, c)[z] %= modulus; 240 | } 241 | } 242 | } 243 | } 244 | 245 | pub fn apply_func u64>(&mut self, func: F) { 246 | for r in 0..self.rows { 247 | for c in 0..self.cols { 248 | let pol_mut = self.get_poly_mut(r, c); 249 | for el in pol_mut { 250 | *el = func(*el); 251 | } 252 | } 253 | } 254 | } 255 | 256 | pub fn to_vec(&self, modulus_bits: usize, num_coeffs: usize) -> Vec { 257 | let sz_bits = self.rows * self.cols * num_coeffs * modulus_bits; 258 | let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize + 32; 259 | let sz_bytes_roundup_16 = ((sz_bytes + 15) / 16) * 16; 260 | let mut data = vec![0u8; sz_bytes_roundup_16]; 261 | let mut bit_offs = 0; 262 | for r in 0..self.rows { 263 | for c in 0..self.cols { 264 | for z in 0..num_coeffs { 265 | let val = self.get_poly(r, c)[z]; 266 | write_arbitrary_bits(data.as_mut_slice(), val, bit_offs, modulus_bits); 267 | // assert_eq!( 268 | // read_arbitrary_bits(data.as_slice(), bit_offs, modulus_bits), 269 | // val 270 | // ); 271 | 272 | bit_offs += modulus_bits; 273 | } 274 | // round bit_offs down to nearest byte boundary 275 | bit_offs = (bit_offs / 8) * 8 276 | } 277 | } 278 | data 279 | } 280 | 281 | pub fn single_value(params: &'a Params, value: u64) -> PolyMatrixRaw<'a> { 282 | let mut out = Self::zero(params, 1, 1); 283 | out.data[0] = value; 284 | out 285 | } 286 | } 287 | 288 | impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> { 289 | fn is_ntt(&self) -> bool { 290 | true 291 | } 292 | fn get_rows(&self) -> usize { 293 | self.rows 294 | } 295 | fn get_cols(&self) -> usize { 296 | self.cols 297 | } 298 | fn get_params(&self) -> &Params { 299 | &self.params 300 | } 301 | fn as_slice(&self) -> &[u64] { 302 | self.data.as_slice() 303 | } 304 | fn as_mut_slice(&mut self) -> &mut [u64] { 305 | self.data.as_mut_slice() 306 | } 307 | fn num_words(&self) -> usize { 308 | self.params.poly_len * self.params.crt_count 309 | } 310 | fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> { 311 | let num_coeffs = rows * cols * params.poly_len * params.crt_count; 312 | let data = AlignedMemory::new(num_coeffs); 313 | PolyMatrixNTT { 314 | params, 315 | rows, 316 | cols, 317 | data, 318 | } 319 | } 320 | fn random_rng(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self { 321 | let mut iter = rng.sample_iter(&Standard); 322 | let mut out = PolyMatrixNTT::zero(params, rows, cols); 323 | for r in 0..rows { 324 | for c in 0..cols { 325 | for i in 0..params.crt_count { 326 | for j in 0..params.poly_len { 327 | let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]); 328 | let val: u64 = iter.next().unwrap(); 329 | out.get_poly_mut(r, c)[idx] = val % params.moduli[i]; 330 | } 331 | } 332 | } 333 | } 334 | out 335 | } 336 | fn random(params: &'a Params, rows: usize, cols: usize) -> Self { 337 | let mut rng = rand::thread_rng(); 338 | Self::random_rng(params, rows, cols, &mut rng) 339 | } 340 | fn pad_top(&self, pad_rows: usize) -> Self { 341 | let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols); 342 | padded.copy_into(&self, 
pad_rows, 0); 343 | padded 344 | } 345 | 346 | fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self { 347 | let mut m = Self::zero(self.params, rows, cols); 348 | assert!(target_row < self.rows); 349 | assert!(target_col < self.cols); 350 | assert!(target_row + rows <= self.rows); 351 | assert!(target_col + cols <= self.cols); 352 | for r in 0..rows { 353 | for c in 0..cols { 354 | let pol_src = self.get_poly(target_row + r, target_col + c); 355 | let pol_dst = m.get_poly_mut(r, c); 356 | pol_dst.copy_from_slice(pol_src); 357 | } 358 | } 359 | m 360 | } 361 | } 362 | 363 | impl<'a> Clone for PolyMatrixNTT<'a> { 364 | fn clone(&self) -> Self { 365 | let mut data_clone = AlignedMemory64::new(self.data.len()); 366 | data_clone 367 | .as_mut_slice() 368 | .copy_from_slice(self.data.as_slice()); 369 | PolyMatrixNTT { 370 | params: self.params, 371 | rows: self.rows, 372 | cols: self.cols, 373 | data: data_clone, 374 | } 375 | } 376 | } 377 | 378 | impl<'a> PolyMatrixNTT<'a> { 379 | pub fn raw(&self) -> PolyMatrixRaw<'a> { 380 | from_ntt_alloc(&self) 381 | } 382 | } 383 | 384 | pub fn shift_rows_by_one<'a>(inp: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> { 385 | if inp.rows == 1 { 386 | return inp.clone(); 387 | } 388 | 389 | let all_but_last_row = inp.submatrix(0, 0, inp.rows - 1, inp.cols); 390 | let last_row = inp.submatrix(inp.rows - 1, 0, 1, inp.cols); 391 | let out = stack_ntt(&last_row, &all_but_last_row); 392 | out 393 | } 394 | 395 | pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 396 | for c in 0..params.crt_count { 397 | for i in 0..params.poly_len { 398 | let idx = c * params.poly_len + i; 399 | res[idx] = multiply_modular(params, a[idx], b[idx], c); 400 | } 401 | } 402 | } 403 | 404 | pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 405 | for c in 0..params.crt_count { 406 | for i in 0..params.poly_len { 407 | let idx = c * params.poly_len + i; 408 | res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c); 409 | } 410 | } 411 | } 412 | 413 | pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 414 | for c in 0..params.crt_count { 415 | for i in 0..params.poly_len { 416 | let idx = c * params.poly_len + i; 417 | res[idx] = add_modular(params, a[idx], b[idx], c); 418 | } 419 | } 420 | } 421 | 422 | pub fn add_poly_raw(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 423 | for i in 0..params.poly_len { 424 | res[i] = a[i] + b[i]; 425 | } 426 | } 427 | 428 | pub fn add_poly_into(params: &Params, res: &mut [u64], a: &[u64]) { 429 | for c in 0..params.crt_count { 430 | for i in 0..params.poly_len { 431 | let idx = c * params.poly_len + i; 432 | res[idx] = add_modular(params, res[idx], a[idx], c); 433 | } 434 | } 435 | } 436 | 437 | pub fn sub_poly_into(params: &Params, res: &mut [u64], a: &[u64]) { 438 | for c in 0..params.crt_count { 439 | for i in 0..params.poly_len { 440 | let idx = c * params.poly_len + i; 441 | res[idx] = sub_modular(params, res[idx], a[idx], c); 442 | } 443 | } 444 | } 445 | 446 | pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) { 447 | for i in 0..params.poly_len { 448 | res[i] = params.modulus - a[i]; 449 | } 450 | } 451 | 452 | pub fn invert_poly_ntt(params: &Params, res: &mut [u64], a: &[u64]) { 453 | for c in 0..params.crt_count { 454 | for i in 0..params.poly_len { 455 | let idx = c * params.poly_len + i; 456 | res[idx] = invert_modular(params, a[idx], c); 457 | } 458 | } 459 | } 460 | 461 | pub fn 
automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) { 462 | let poly_len = params.poly_len; 463 | for i in 0..poly_len { 464 | let num = (i * t) / poly_len; 465 | let rem = (i * t) % poly_len; 466 | 467 | if num % 2 == 0 { 468 | res[rem] = a[i]; 469 | } else { 470 | res[rem] = params.modulus - a[i]; 471 | } 472 | } 473 | } 474 | 475 | pub fn automorph_poly_uncrtd(params: &Params, res: &mut [u64], a: &[u64], t: usize) { 476 | let poly_len = params.poly_len; 477 | for m in 0..params.crt_count { 478 | let a_chunk = &a[m * poly_len..(m + 1) * poly_len]; 479 | let res_chunk = &mut res[m * poly_len..(m + 1) * poly_len]; 480 | for i in 0..poly_len { 481 | let num = (i * t) / poly_len; 482 | let rem = (i * t) % poly_len; 483 | 484 | if num % 2 == 0 { 485 | res_chunk[rem] = a_chunk[i]; 486 | } else { 487 | res_chunk[rem] = params.moduli[m] - a_chunk[i]; 488 | } 489 | } 490 | } 491 | } 492 | 493 | #[cfg(target_feature = "avx2")] 494 | pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 495 | for c in 0..params.crt_count { 496 | for i in (0..params.poly_len).step_by(4) { 497 | unsafe { 498 | let p_x = &a[c * params.poly_len + i] as *const u64; 499 | let p_y = &b[c * params.poly_len + i] as *const u64; 500 | let p_z = &mut res[c * params.poly_len + i] as *mut u64; 501 | let x = _mm256_load_si256(p_x as *const __m256i); 502 | let y = _mm256_load_si256(p_y as *const __m256i); 503 | let z = _mm256_load_si256(p_z as *const __m256i); 504 | 505 | let product = _mm256_mul_epu32(x, y); 506 | let out = _mm256_add_epi64(z, product); 507 | 508 | _mm256_store_si256(p_z as *mut __m256i, out); 509 | } 510 | } 511 | } 512 | } 513 | 514 | #[cfg(target_feature = "avx2")] 515 | pub fn multiply_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) { 516 | for c in 0..params.crt_count { 517 | for i in (0..params.poly_len).step_by(4) { 518 | unsafe { 519 | let p_x = &a[c * params.poly_len + i] as *const u64; 520 | let p_y = &b[c * params.poly_len + i] as *const u64; 521 | let p_z = &mut res[c * params.poly_len + i] as *mut u64; 522 | let x = _mm256_load_si256(p_x as *const __m256i); 523 | let y = _mm256_load_si256(p_y as *const __m256i); 524 | 525 | let product = _mm256_mul_epu32(x, y); 526 | 527 | _mm256_store_si256(p_z as *mut __m256i, product); 528 | } 529 | } 530 | } 531 | } 532 | 533 | pub fn modular_reduce(params: &Params, res: &mut [u64]) { 534 | for c in 0..params.crt_count { 535 | for i in 0..params.poly_len { 536 | let idx = c * params.poly_len + i; 537 | res[idx] = barrett_coeff_u64(params, res[idx], c); 538 | } 539 | } 540 | } 541 | 542 | #[cfg(not(target_feature = "avx2"))] 543 | pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) { 544 | assert!(res.rows == a.rows); 545 | assert!(res.cols == b.cols); 546 | assert!(a.cols == b.rows); 547 | 548 | let params = res.params; 549 | for i in 0..a.rows { 550 | for j in 0..b.cols { 551 | for z in 0..params.poly_len * params.crt_count { 552 | res.get_poly_mut(i, j)[z] = 0; 553 | } 554 | for k in 0..a.cols { 555 | let params = res.params; 556 | let res_poly = res.get_poly_mut(i, j); 557 | let pol1 = a.get_poly(i, k); 558 | let pol2 = b.get_poly(k, j); 559 | multiply_add_poly(params, res_poly, pol1, pol2); 560 | } 561 | } 562 | } 563 | } 564 | 565 | #[cfg(target_feature = "avx2")] 566 | pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) { 567 | assert_eq!(res.rows, a.rows); 568 | assert_eq!(res.cols, b.cols); 569 | assert_eq!(a.cols, b.rows); 570 | 571 | 
let params = res.params; 572 | for i in 0..a.rows { 573 | for j in 0..b.cols { 574 | for z in 0..params.poly_len * params.crt_count { 575 | res.get_poly_mut(i, j)[z] = 0; 576 | } 577 | let res_poly = res.get_poly_mut(i, j); 578 | for k in 0..a.cols { 579 | let pol1 = a.get_poly(i, k); 580 | let pol2 = b.get_poly(k, j); 581 | multiply_add_poly_avx(params, res_poly, pol1, pol2); 582 | } 583 | modular_reduce(params, res_poly); 584 | } 585 | } 586 | } 587 | 588 | #[cfg(target_feature = "avx2")] 589 | pub fn multiply_no_reduce( 590 | res: &mut PolyMatrixNTT, 591 | a: &PolyMatrixNTT, 592 | b: &PolyMatrixNTT, 593 | start_inner_dim: usize, 594 | ) { 595 | assert_eq!(res.rows, a.rows); 596 | assert_eq!(res.cols, b.cols); 597 | assert_eq!(a.cols, b.rows); 598 | 599 | let params = res.params; 600 | for i in 0..a.rows { 601 | for j in 0..b.cols { 602 | let res_poly = res.get_poly_mut(i, j); 603 | for k in start_inner_dim..a.cols { 604 | let pol1 = a.get_poly(i, k); 605 | let pol2 = b.get_poly(k, j); 606 | multiply_add_poly_avx(params, res_poly, pol1, pol2); 607 | } 608 | } 609 | } 610 | } 611 | 612 | pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) { 613 | assert!(res.rows == a.rows); 614 | assert!(res.cols == a.cols); 615 | assert!(a.rows == b.rows); 616 | assert!(a.cols == b.cols); 617 | 618 | let params = res.params; 619 | for i in 0..a.rows { 620 | for j in 0..a.cols { 621 | let res_poly = res.get_poly_mut(i, j); 622 | let pol1 = a.get_poly(i, j); 623 | let pol2 = b.get_poly(i, j); 624 | add_poly(params, res_poly, pol1, pol2); 625 | } 626 | } 627 | } 628 | 629 | pub fn add_raw(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw, b: &PolyMatrixRaw) { 630 | assert_eq!(res.rows, a.rows); 631 | assert_eq!(res.cols, a.cols); 632 | assert_eq!(a.rows, b.rows); 633 | assert_eq!(a.cols, b.cols); 634 | 635 | let params = res.params; 636 | for i in 0..a.rows { 637 | for j in 0..a.cols { 638 | let res_poly = res.get_poly_mut(i, j); 639 | let pol1 = a.get_poly(i, j); 640 | let pol2 = b.get_poly(i, j); 641 | add_poly_raw(params, res_poly, pol1, pol2); 642 | } 643 | } 644 | } 645 | 646 | pub fn add_into(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) { 647 | assert!(res.rows == a.rows); 648 | assert!(res.cols == a.cols); 649 | 650 | let params = res.params; 651 | for i in 0..res.rows { 652 | for j in 0..res.cols { 653 | let res_poly = res.get_poly_mut(i, j); 654 | let pol2 = a.get_poly(i, j); 655 | add_poly_into(params, res_poly, pol2); 656 | } 657 | } 658 | } 659 | 660 | pub fn sub_into(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) { 661 | assert!(res.rows == a.rows); 662 | assert!(res.cols == a.cols); 663 | 664 | let params = res.params; 665 | for i in 0..res.rows { 666 | for j in 0..res.cols { 667 | let res_poly = res.get_poly_mut(i, j); 668 | let pol2 = a.get_poly(i, j); 669 | sub_poly_into(params, res_poly, pol2); 670 | } 671 | } 672 | } 673 | 674 | pub fn add_into_at(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, t_row: usize, t_col: usize) { 675 | let params = res.params; 676 | for i in 0..a.rows { 677 | for j in 0..a.cols { 678 | let res_poly = res.get_poly_mut(t_row + i, t_col + j); 679 | let pol2 = a.get_poly(i, j); 680 | add_poly_into(params, res_poly, pol2); 681 | } 682 | } 683 | } 684 | 685 | pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) { 686 | assert!(res.rows == a.rows); 687 | assert!(res.cols == a.cols); 688 | 689 | let params = res.params; 690 | for i in 0..a.rows { 691 | for j in 0..a.cols { 692 | let res_poly = res.get_poly_mut(i, j); 693 | let pol1 = a.get_poly(i, j); 694 
| invert_poly(params, res_poly, pol1); 695 | } 696 | } 697 | } 698 | 699 | pub fn invert_ntt(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) { 700 | assert!(res.rows == a.rows); 701 | assert!(res.cols == a.cols); 702 | 703 | let params = res.params; 704 | for i in 0..a.rows { 705 | for j in 0..a.cols { 706 | let res_poly = res.get_poly_mut(i, j); 707 | let pol1 = a.get_poly(i, j); 708 | invert_poly_ntt(params, res_poly, pol1); 709 | } 710 | } 711 | } 712 | 713 | pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) { 714 | assert!(res.rows == a.rows); 715 | assert!(res.cols == a.cols); 716 | 717 | let params = res.params; 718 | for i in 0..a.rows { 719 | for j in 0..a.cols { 720 | let res_poly = res.get_poly_mut(i, j); 721 | let pol1 = a.get_poly(i, j); 722 | automorph_poly(params, res_poly, pol1, t); 723 | } 724 | } 725 | } 726 | 727 | pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> { 728 | let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols); 729 | automorph(&mut res, a, t); 730 | res 731 | } 732 | 733 | pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> { 734 | assert_eq!(a.cols, b.cols); 735 | let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols); 736 | c.copy_into(a, 0, 0); 737 | c.copy_into(b, a.rows, 0); 738 | c 739 | } 740 | 741 | pub fn stack_ntt<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> { 742 | assert_eq!(a.cols, b.cols); 743 | let mut c = PolyMatrixNTT::zero(a.params, a.rows + b.rows, a.cols); 744 | c.copy_into(a, 0, 0); 745 | c.copy_into(b, a.rows, 0); 746 | c 747 | } 748 | 749 | pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) { 750 | assert_eq!(a.rows, 1); 751 | assert_eq!(a.cols, 1); 752 | 753 | let params = res.params; 754 | let pol2 = a.get_poly(0, 0); 755 | for i in 0..b.rows { 756 | for j in 0..b.cols { 757 | let res_poly = res.get_poly_mut(i, j); 758 | let pol1 = b.get_poly(i, j); 759 | multiply_poly(params, res_poly, pol1, pol2); 760 | } 761 | } 762 | } 763 | 764 | pub fn scalar_multiply_alloc<'a>( 765 | a: &PolyMatrixNTT<'a>, 766 | b: &PolyMatrixNTT<'a>, 767 | ) -> PolyMatrixNTT<'a> { 768 | let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols); 769 | scalar_multiply(&mut res, a, b); 770 | res 771 | } 772 | 773 | pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> { 774 | let mut res = PolyMatrixRaw::zero(params, 1, 1); 775 | res.get_poly_mut(0, 0)[0] = val; 776 | res 777 | } 778 | 779 | fn reduce_copy(params: &Params, out: &mut [u64], inp: &[u64]) { 780 | for n in 0..params.crt_count { 781 | for z in 0..params.poly_len { 782 | out[n * params.poly_len + z] = barrett_coeff_u64(params, inp[z], n); 783 | } 784 | } 785 | } 786 | 787 | pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) { 788 | let params = a.params; 789 | for r in 0..a.rows { 790 | for c in 0..a.cols { 791 | let pol_src = b.get_poly(r, c); 792 | let pol_dst = a.get_poly_mut(r, c); 793 | reduce_copy(params, pol_dst, pol_src); 794 | ntt_forward(params, pol_dst); 795 | } 796 | } 797 | } 798 | 799 | pub fn to_ntt_no_reduce(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) { 800 | let params = a.params; 801 | for r in 0..a.rows { 802 | for c in 0..a.cols { 803 | let pol_src = b.get_poly(r, c); 804 | let pol_dst = a.get_poly_mut(r, c); 805 | for n in 0..params.crt_count { 806 | let idx = n * params.poly_len; 807 | pol_dst[idx..idx + params.poly_len].copy_from_slice(pol_src); 808 | } 809 | ntt_forward(params, 
pol_dst); 810 | } 811 | } 812 | } 813 | 814 | pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> { 815 | let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols); 816 | to_ntt(&mut a, b); 817 | a 818 | } 819 | 820 | pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) { 821 | let params = a.params; 822 | SCRATCH.with(|scratch_cell| { 823 | let scratch_vec = &mut *scratch_cell.borrow_mut(); 824 | let scratch = scratch_vec.as_mut_slice(); 825 | for r in 0..a.rows { 826 | for c in 0..a.cols { 827 | let pol_src = b.get_poly(r, c); 828 | let pol_dst = a.get_poly_mut(r, c); 829 | scratch[0..pol_src.len()].copy_from_slice(pol_src); 830 | ntt_inverse(params, scratch); 831 | for z in 0..params.poly_len { 832 | pol_dst[z] = params.crt_compose(scratch, z); 833 | } 834 | } 835 | } 836 | }); 837 | } 838 | 839 | pub fn from_ntt_scratch(a: &mut PolyMatrixRaw, scratch: &mut [u64], b: &PolyMatrixNTT) { 840 | assert_eq!(b.rows, 2); 841 | assert_eq!(b.cols, 1); 842 | 843 | let params = b.params; 844 | for r in 0..b.rows { 845 | for c in 0..b.cols { 846 | let pol_src = b.get_poly(r, c); 847 | scratch[0..pol_src.len()].copy_from_slice(pol_src); 848 | ntt_inverse(params, scratch); 849 | if r == 0 { 850 | let pol_dst = a.get_poly_mut(r, c); 851 | for z in 0..params.poly_len { 852 | pol_dst[z] = params.crt_compose(scratch, z); 853 | } 854 | } 855 | } 856 | } 857 | } 858 | 859 | pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> { 860 | let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols); 861 | from_ntt(&mut a, b); 862 | a 863 | } 864 | 865 | impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> { 866 | type Output = PolyMatrixRaw<'a>; 867 | 868 | fn neg(self) -> Self::Output { 869 | let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols); 870 | invert(&mut out, self); 871 | out 872 | } 873 | } 874 | 875 | impl<'a, 'b> Neg for &'b PolyMatrixNTT<'a> { 876 | type Output = PolyMatrixNTT<'a>; 877 | 878 | fn neg(self) -> Self::Output { 879 | let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols); 880 | invert_ntt(&mut out, self); 881 | out 882 | } 883 | } 884 | 885 | impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> { 886 | type Output = PolyMatrixNTT<'a>; 887 | 888 | fn mul(self, rhs: Self) -> Self::Output { 889 | let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols); 890 | multiply(&mut out, self, rhs); 891 | out 892 | } 893 | } 894 | 895 | impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> { 896 | type Output = PolyMatrixNTT<'a>; 897 | 898 | fn add(self, rhs: Self) -> Self::Output { 899 | let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols); 900 | add(&mut out, self, rhs); 901 | out 902 | } 903 | } 904 | 905 | impl<'a, 'b> Add for &'b PolyMatrixRaw<'a> { 906 | type Output = PolyMatrixRaw<'a>; 907 | 908 | fn add(self, rhs: Self) -> Self::Output { 909 | let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols); 910 | add_raw(&mut out, self, rhs); 911 | out 912 | } 913 | } 914 | 915 | #[cfg(test)] 916 | mod test { 917 | use super::*; 918 | 919 | fn get_params() -> Params { 920 | get_test_params() 921 | } 922 | 923 | fn assert_all_zero(a: &[u64]) { 924 | for i in a { 925 | assert_eq!(*i, 0); 926 | } 927 | } 928 | 929 | #[test] 930 | fn sets_all_zeros() { 931 | let params = get_params(); 932 | let m1 = PolyMatrixNTT::zero(¶ms, 2, 1); 933 | assert_all_zero(m1.as_slice()); 934 | } 935 | 936 | #[test] 937 | fn multiply_correctness() { 938 | let params = get_params(); 939 | let m1 = PolyMatrixNTT::zero(¶ms, 2, 1); 940 | let m2 = 
PolyMatrixNTT::zero(¶ms, 3, 2); 941 | let m3 = &m2 * &m1; 942 | assert_all_zero(m3.as_slice()); 943 | } 944 | 945 | #[test] 946 | fn full_multiply_correctness() { 947 | let params = get_params(); 948 | let mut m1 = PolyMatrixRaw::zero(¶ms, 1, 1); 949 | let mut m2 = PolyMatrixRaw::zero(¶ms, 1, 1); 950 | m1.get_poly_mut(0, 0)[1] = 100; 951 | m2.get_poly_mut(0, 0)[1] = 7; 952 | let m1_ntt = to_ntt_alloc(&m1); 953 | let m2_ntt = to_ntt_alloc(&m2); 954 | let m3_ntt = &m1_ntt * &m2_ntt; 955 | let m3 = from_ntt_alloc(&m3_ntt); 956 | assert_eq!(m3.get_poly(0, 0)[2], 700); 957 | } 958 | 959 | fn get_alt_params() -> Params { 960 | Params::init( 961 | 2048, 962 | &vec![180143985094819841u64], 963 | 6.4, 964 | 2, 965 | 256, 966 | 20, 967 | 4, 968 | 8, 969 | 56, 970 | 8, 971 | true, 972 | 9, 973 | 6, 974 | 1, 975 | 2048, 976 | 0, 977 | ) 978 | } 979 | 980 | #[test] 981 | fn alt_full_multiply_correctness() { 982 | let params = get_alt_params(); 983 | let mut m1 = PolyMatrixRaw::zero(¶ms, 1, 1); 984 | let mut m2 = PolyMatrixRaw::zero(¶ms, 1, 1); 985 | m1.get_poly_mut(0, 0)[1] = 100; 986 | m2.get_poly_mut(0, 0)[1] = 7; 987 | let m1_ntt = to_ntt_alloc(&m1); 988 | let m2_ntt = to_ntt_alloc(&m2); 989 | let m3_ntt = &m1_ntt * &m2_ntt; 990 | let m3 = from_ntt_alloc(&m3_ntt); 991 | assert_eq!(m3.get_poly(0, 0)[2], 700); 992 | } 993 | 994 | #[test] 995 | fn to_vec_correctness() { 996 | let params = get_params(); 997 | let mut m1 = PolyMatrixRaw::zero(¶ms, 1, 1); 998 | for i in 0..params.poly_len { 999 | m1.data[i] = 1; 1000 | } 1001 | let modulus_bits = 9; 1002 | let v = m1.to_vec(modulus_bits, params.poly_len); 1003 | for i in 0..v.len() { 1004 | println!("{:?}", v[i]); 1005 | } 1006 | let mut bit_offs = 0; 1007 | for i in 0..params.poly_len { 1008 | let val = read_arbitrary_bits(v.as_slice(), bit_offs, modulus_bits); 1009 | assert_eq!(m1.data[i], val); 1010 | bit_offs += modulus_bits; 1011 | } 1012 | } 1013 | } 1014 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use crate::{arith::*, client::Seed, params::*, poly::*}; 2 | use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng}; 3 | use rand_chacha::ChaCha20Rng; 4 | use serde_json::Value; 5 | use std::fs; 6 | 7 | pub const CFG_20_256: &'static str = r#" 8 | {'n': 2, 9 | 'nu_1': 9, 10 | 'nu_2': 6, 11 | 'p': 256, 12 | 'q2_bits': 20, 13 | 's_e': 87.62938774292914, 14 | 't_gsw': 8, 15 | 't_conv': 4, 16 | 't_exp_left': 8, 17 | 't_exp_right': 56, 18 | 'instances': 1, 19 | 'db_item_size': 8192 } 20 | "#; 21 | pub const CFG_16_100000: &'static str = r#" 22 | {'n': 2, 23 | 'nu_1': 10, 24 | 'nu_2': 6, 25 | 'p': 512, 26 | 'q2_bits': 21, 27 | 's_e': 85.83255142749422, 28 | 't_gsw': 10, 29 | 't_conv': 4, 30 | 't_exp_left': 16, 31 | 't_exp_right': 56, 32 | 'instances': 11, 33 | 'db_item_size': 100000 } 34 | "#; 35 | 36 | pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize { 37 | let mut idx = 0usize; 38 | let mut prod = 1usize; 39 | for i in (0..indices.len()).rev() { 40 | idx += indices[i] * prod; 41 | prod *= lengths[i]; 42 | } 43 | idx 44 | } 45 | 46 | pub fn decompose_index(indices: &mut [usize], index: usize, lengths: &[usize]) { 47 | let mut cur = index; 48 | let mut prod = 1usize; 49 | for i in 1..lengths.len() { 50 | prod *= lengths[i]; 51 | } 52 | 53 | for i in 0..lengths.len() { 54 | let val = cur / prod; 55 | cur -= val * prod; 56 | indices[i] = val; 57 | if i < lengths.len() - 1 { 58 | prod /= lengths[i + 1]; 59 | 
} 60 | } 61 | } 62 | 63 | pub fn get_test_params() -> Params { 64 | Params::init( 65 | 2048, 66 | &vec![268369921u64, 249561089u64], 67 | 6.4, 68 | 2, 69 | 256, 70 | 20, 71 | 4, 72 | 8, 73 | 56, 74 | 8, 75 | true, 76 | 9, 77 | 6, 78 | 1, 79 | 2048, 80 | 0, 81 | ) 82 | } 83 | 84 | pub fn get_short_keygen_params() -> Params { 85 | Params::init( 86 | 2048, 87 | &vec![268369921u64, 249561089u64], 88 | 6.4, 89 | 2, 90 | 256, 91 | 20, 92 | 4, 93 | 4, 94 | 4, 95 | 4, 96 | true, 97 | 9, 98 | 6, 99 | 1, 100 | 2048, 101 | 0, 102 | ) 103 | } 104 | 105 | pub fn get_expansion_testing_params() -> Params { 106 | let cfg = r#" 107 | {'n': 2, 108 | 'nu_1': 9, 109 | 'nu_2': 6, 110 | 'p': 256, 111 | 'q2_bits': 20, 112 | 't_gsw': 8, 113 | 't_conv': 4, 114 | 't_exp_left': 8, 115 | 't_exp_right': 56, 116 | 'instances': 1, 117 | 'db_item_size': 8192 } 118 | "#; 119 | params_from_json(&cfg.replace("'", "\"")) 120 | } 121 | 122 | pub fn get_fast_expansion_testing_params() -> Params { 123 | let cfg = r#" 124 | {'n': 2, 125 | 'nu_1': 6, 126 | 'nu_2': 2, 127 | 'p': 256, 128 | 'q2_bits': 20, 129 | 't_gsw': 8, 130 | 't_conv': 4, 131 | 't_exp_left': 8, 132 | 't_exp_right': 8, 133 | 'instances': 1, 134 | 'db_item_size': 8192 } 135 | "#; 136 | params_from_json(&cfg.replace("'", "\"")) 137 | } 138 | 139 | pub fn get_no_expansion_testing_params() -> Params { 140 | let cfg = r#" 141 | {'direct_upload': 1, 142 | 'n': 5, 143 | 'nu_1': 6, 144 | 'nu_2': 3, 145 | 'p': 65536, 146 | 'q2_bits': 27, 147 | 't_gsw': 3, 148 | 't_conv': 56, 149 | 't_exp_left': 56, 150 | 't_exp_right': 56} 151 | "#; 152 | params_from_json(&cfg.replace("'", "\"")) 153 | } 154 | 155 | pub fn get_seed() -> u64 { 156 | thread_rng().gen::() 157 | } 158 | 159 | pub fn get_seeded_rng() -> SmallRng { 160 | SmallRng::seed_from_u64(get_seed()) 161 | } 162 | 163 | pub fn get_chacha_seed() -> Seed { 164 | thread_rng().gen::<[u8; 32]>() 165 | } 166 | 167 | pub fn get_chacha_rng() -> ChaCha20Rng { 168 | ChaCha20Rng::from_seed(get_chacha_seed()) 169 | } 170 | 171 | pub fn get_static_seed() -> u64 { 172 | 0x123456789 173 | } 174 | 175 | pub fn get_chacha_static_seed() -> Seed { 176 | [ 177 | 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x0, 0x1, 178 | 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 179 | ] 180 | } 181 | 182 | pub fn get_static_seeded_rng() -> SmallRng { 183 | SmallRng::seed_from_u64(get_static_seed()) 184 | } 185 | 186 | pub const fn get_empty_params() -> Params { 187 | Params { 188 | poly_len: 0, 189 | poly_len_log2: 0, 190 | ntt_tables: Vec::new(), 191 | scratch: Vec::new(), 192 | crt_count: 0, 193 | barrett_cr_0_modulus: 0, 194 | barrett_cr_1_modulus: 0, 195 | barrett_cr_0: [0u64; MAX_MODULI], 196 | barrett_cr_1: [0u64; MAX_MODULI], 197 | mod0_inv_mod1: 0, 198 | mod1_inv_mod0: 0, 199 | moduli: [0u64; MAX_MODULI], 200 | modulus: 0, 201 | modulus_log2: 0, 202 | noise_width: 0f64, 203 | n: 0, 204 | pt_modulus: 0, 205 | q2_bits: 0, 206 | t_conv: 0, 207 | t_exp_left: 0, 208 | t_exp_right: 0, 209 | t_gsw: 0, 210 | expand_queries: false, 211 | db_dim_1: 0, 212 | db_dim_2: 0, 213 | instances: 0, 214 | db_item_size: 0, 215 | version: 0, 216 | } 217 | } 218 | 219 | pub fn params_from_json(cfg: &str) -> Params { 220 | let v: Value = serde_json::from_str(cfg).unwrap(); 221 | params_from_json_obj(&v) 222 | } 223 | 224 | pub fn params_from_json_obj(v: &Value) -> Params { 225 | let n = v["n"].as_u64().unwrap() as usize; 226 | let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize; 227 | let db_dim_2 = 
v["nu_2"].as_u64().unwrap() as usize; 228 | let instances = v["instances"].as_u64().unwrap_or(1) as usize; 229 | let p = v["p"].as_u64().unwrap(); 230 | let q2_bits = u64::max(v["q2_bits"].as_u64().unwrap(), MIN_Q2_BITS); 231 | let t_gsw = v["t_gsw"].as_u64().unwrap() as usize; 232 | let t_conv = v["t_conv"].as_u64().unwrap() as usize; 233 | let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize; 234 | let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize; 235 | let do_expansion = v.get("direct_upload").is_none(); 236 | 237 | let mut db_item_size = v["db_item_size"].as_u64().unwrap_or(0) as usize; 238 | if db_item_size == 0 { 239 | db_item_size = instances * n * n; 240 | db_item_size = db_item_size * 2048 * log2_ceil(p) as usize / 8; 241 | } 242 | 243 | let version = v["version"].as_u64().unwrap_or(0) as usize; 244 | 245 | Params::init( 246 | 2048, 247 | &vec![268369921u64, 249561089u64], 248 | 6.4, 249 | n, 250 | p, 251 | q2_bits, 252 | t_conv, 253 | t_exp_left, 254 | t_exp_right, 255 | t_gsw, 256 | do_expansion, 257 | db_dim_1, 258 | db_dim_2, 259 | instances, 260 | db_item_size, 261 | version, 262 | ) 263 | } 264 | 265 | static ALL_PARAMS_STORE_FNAME: &str = "../params_store.json"; 266 | 267 | pub fn get_params_from_store(target_num_log2: usize, item_size: usize) -> Params { 268 | let params_store_str = fs::read_to_string(ALL_PARAMS_STORE_FNAME).unwrap(); 269 | let v: Value = serde_json::from_str(¶ms_store_str).unwrap(); 270 | let nearest_target_num = target_num_log2; 271 | let nearest_item_size = 1 << usize::max(log2_ceil_usize(item_size), 8); 272 | println!( 273 | "Starting with parameters for 2^{} x {} bytes...", 274 | nearest_target_num, nearest_item_size 275 | ); 276 | let target = v 277 | .as_array() 278 | .unwrap() 279 | .iter() 280 | .map(|x| x.as_object().unwrap()) 281 | .filter(|x| x.get("target_num").unwrap().as_u64().unwrap() == (nearest_target_num as u64)) 282 | .filter(|x| x.get("item_size").unwrap().as_u64().unwrap() == (nearest_item_size as u64)) 283 | .map(|x| x.get("params").unwrap()) 284 | .next() 285 | .unwrap(); 286 | params_from_json_obj(target) 287 | } 288 | 289 | pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 { 290 | let word_off = bit_offs / 64; 291 | let bit_off_within_word = bit_offs % 64; 292 | if (bit_off_within_word + num_bits) <= 64 { 293 | let idx = word_off * 8; 294 | let val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap()); 295 | (val >> bit_off_within_word) & ((1u64 << num_bits) - 1) 296 | } else { 297 | let idx = word_off * 8; 298 | let val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap()); 299 | ((val >> bit_off_within_word) & ((1u128 << num_bits) - 1)) as u64 300 | } 301 | } 302 | 303 | pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) { 304 | let word_off = bit_offs / 64; 305 | let bit_off_within_word = bit_offs % 64; 306 | val = val & ((1u64 << num_bits) - 1); 307 | if (bit_off_within_word + num_bits) <= 64 { 308 | let idx = word_off * 8; 309 | let mut cur_val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap()); 310 | cur_val &= !(((1u64 << num_bits) - 1) << bit_off_within_word); 311 | cur_val |= val << bit_off_within_word; 312 | data[idx..idx + 8].copy_from_slice(&u64::to_ne_bytes(cur_val)); 313 | } else { 314 | let idx = word_off * 8; 315 | let mut cur_val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap()); 316 | let mask = !(((1u128 << num_bits) - 1) << bit_off_within_word); 317 | cur_val &= mask; 318 | cur_val 
|= (val as u128) << bit_off_within_word; 319 | data[idx..idx + 16].copy_from_slice(&u128::to_ne_bytes(cur_val)); 320 | } 321 | } 322 | 323 | pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec) { 324 | let poly_len = params.poly_len; 325 | let crt_count = params.crt_count; 326 | 327 | assert_eq!(crt_count, 2); 328 | assert!(log2(params.moduli[0]) <= 32); 329 | 330 | let num_reg_expanded = 1 << params.db_dim_1; 331 | let ct_rows = v_reg[0].rows; 332 | let ct_cols = v_reg[0].cols; 333 | 334 | assert_eq!(ct_rows, 2); 335 | assert_eq!(ct_cols, 1); 336 | 337 | for j in 0..num_reg_expanded { 338 | for r in 0..ct_rows { 339 | for m in 0..ct_cols { 340 | for z in 0..params.poly_len { 341 | let idx_a_in = 342 | r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len); 343 | let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows) 344 | + j * (ct_cols * ct_rows) 345 | + m * (ct_rows) 346 | + r; 347 | let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0]; 348 | let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1]; 349 | 350 | out[idx_a_out] = val1 | (val2 << 32); 351 | } 352 | } 353 | } 354 | } 355 | } 356 | 357 | #[cfg(test)] 358 | mod test { 359 | use super::*; 360 | 361 | #[test] 362 | fn params_from_json_correct() { 363 | let cfg = r#" 364 | {'n': 2, 365 | 'nu_1': 9, 366 | 'nu_2': 6, 367 | 'p': 256, 368 | 'q2_bits': 20, 369 | 's_e': 87.62938774292914, 370 | 't_gsw': 8, 371 | 't_conv': 4, 372 | 't_exp_left': 8, 373 | 't_exp_right': 56, 374 | 'instances': 1, 375 | 'db_item_size': 2048 } 376 | "#; 377 | let cfg = cfg.replace("'", "\""); 378 | let b = params_from_json(&cfg); 379 | let c = Params::init( 380 | 2048, 381 | &vec![268369921u64, 249561089u64], 382 | 6.4, 383 | 2, 384 | 256, 385 | 20, 386 | 4, 387 | 8, 388 | 56, 389 | 8, 390 | true, 391 | 9, 392 | 6, 393 | 1, 394 | 2048, 395 | 0, 396 | ); 397 | assert_eq!(b, c); 398 | } 399 | 400 | #[test] 401 | fn test_decompose_calc_correct() { 402 | let lengths = [5, 4, 3]; 403 | let indices = [2, 1, 2]; 404 | let idx = calc_index(&indices, &lengths); 405 | let mut gues_indices = [0, 0, 0]; 406 | decompose_index(&mut gues_indices, idx, &lengths); 407 | assert_eq!(indices, gues_indices); 408 | } 409 | 410 | #[test] 411 | fn test_read_write_arbitrary_bits() { 412 | let len = 4096; 413 | let num_bits = 9; 414 | let mut data = vec![0u8; len]; 415 | let scaled_len = len * 8 / num_bits - 64; 416 | let mut bit_offs = 0; 417 | let get_from = |i: usize| -> u64 { ((i * 7 + 13) % (1 << num_bits)) as u64 }; 418 | for i in 0..scaled_len { 419 | write_arbitrary_bits(data.as_mut_slice(), get_from(i), bit_offs, num_bits); 420 | bit_offs += num_bits; 421 | } 422 | bit_offs = 0; 423 | for i in 0..scaled_len { 424 | let val = read_arbitrary_bits(data.as_slice(), bit_offs, num_bits); 425 | assert_eq!(val, get_from(i)); 426 | bit_offs += num_bits; 427 | } 428 | } 429 | } 430 | --------------------------------------------------------------------------------
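The write_arbitrary_bits / read_arbitrary_bits pair above packs fixed-width values back-to-back at arbitrary bit offsets by loading a 64-bit window around the target offset (or a 128-bit window when a value straddles a word boundary). The standalone sketch below re-implements that idea in miniature to make the round trip concrete; write_bits, read_bits, and the buffer sizing are illustrative only, and, like the crate's helpers, it assumes the caller leaves slack bytes past the last value so a full window can always be loaded (to_vec() in poly.rs over-allocates for the same reason).

use std::convert::TryInto;

// Standalone sketch of fixed-width bit packing at arbitrary bit offsets.
fn write_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
    let idx = (bit_offs / 64) * 8; // byte index of the containing 64-bit word
    let shift = bit_offs % 64;     // bit offset within that word
    val &= (1u64 << num_bits) - 1;
    if shift + num_bits <= 64 {
        let mut cur = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
        cur &= !(((1u64 << num_bits) - 1) << shift);
        cur |= val << shift;
        data[idx..idx + 8].copy_from_slice(&cur.to_ne_bytes());
    } else {
        // Value straddles a 64-bit boundary: go through a 128-bit window.
        let mut cur = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
        cur &= !(((1u128 << num_bits) - 1) << shift);
        cur |= (val as u128) << shift;
        data[idx..idx + 16].copy_from_slice(&cur.to_ne_bytes());
    }
}

fn read_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
    let idx = (bit_offs / 64) * 8;
    let shift = bit_offs % 64;
    if shift + num_bits <= 64 {
        let cur = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
        (cur >> shift) & ((1u64 << num_bits) - 1)
    } else {
        let cur = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
        ((cur >> shift) & ((1u128 << num_bits) - 1)) as u64
    }
}

fn main() {
    // Pack 9-bit values and read them back, mirroring test_read_write_arbitrary_bits.
    let num_bits = 9usize;
    let mut data = vec![0u8; 128];
    // Leave slack so the last read/write can still load a full window.
    let count = data.len() * 8 / num_bits - 16;
    for i in 0..count {
        write_bits(&mut data, ((i * 7 + 13) % (1 << num_bits)) as u64, i * num_bits, num_bits);
    }
    for i in 0..count {
        let val = read_bits(&data, i * num_bits, num_bits);
        assert_eq!(val, ((i * 7 + 13) % (1 << num_bits)) as u64);
    }
    println!("bit-packing round trip ok");
}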