├── .gitignore ├── LICENCE ├── Makefile ├── README.md ├── bench └── bench.cu ├── src ├── array │ ├── fixnum_array.cu │ └── fixnum_array.h ├── fixnum │ ├── internal │ │ └── primitives.cu │ ├── nail.cu │ ├── slot_layout.cu │ ├── warp_fixnum.cu │ └── word_fixnum.cu ├── functions │ ├── chinese.cu │ ├── divexact.cu │ ├── internal │ │ └── modexp_impl.cu │ ├── modexp.cu │ ├── modinv.cu │ ├── multi_modexp.cu │ ├── paillier_decrypt.cu │ ├── paillier_encrypt.cu │ ├── quorem.cu │ └── quorem_preinv.cu ├── modnum │ ├── internal │ │ └── monty.cu │ ├── modnum_monty_cios.cu │ └── modnum_monty_redc.cu └── util │ └── cuda_wrap.h └── tests ├── gentests.py └── test-suite.cu /.gitignore: -------------------------------------------------------------------------------- 1 | # Test input 2 | tests/add_* 3 | tests/sub_* 4 | tests/mul_* 5 | tests/modexp_* 6 | tests/paillier_* 7 | 8 | # Object files 9 | *.o 10 | *.ko 11 | *.obj 12 | *.elf 13 | 14 | # Precompiled Headers 15 | *.gch 16 | *.pch 17 | 18 | # Libraries 19 | *.lib 20 | *.a 21 | *.la 22 | *.lo 23 | 24 | # Shared objects (inc. Windows DLLs) 25 | *.dll 26 | *.so 27 | *.so.* 28 | *.dylib 29 | 30 | # Executables 31 | *.exe 32 | *.out 33 | *.app 34 | *.i*86 35 | *.x86_64 36 | *.hex 37 | 38 | # Debug files 39 | *.dSYM/ 40 | *.class 41 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | CSIRO Open Source Software Licence v1.0 2 | (Based on MIT/BSD Open Source Licence) 3 | 4 | IMPORTANT – PLEASE READ CAREFULLY 5 | 6 | This document contains the terms under which CSIRO agrees to licence its 7 | Software to you. This is a template and further information relevant to the 8 | licence is set out in the Supplementary Licence specific to the Software you are 9 | licensing from CSIRO. Both documents together form this agreement. 10 | 11 | The Software is copyright (c) Commonwealth Scientific and Industrial Research 12 | Organisation (CSIRO) ABN 41 687 119 230. 13 | 14 | Redistribution and use of this Software in source and binary forms, with or 15 | without modification, are permitted provided that the following conditions are 16 | met: 17 | 18 | Redistributions of source code must retain the above copyright notice, this list 19 | of conditions and the following disclaimer. Redistributions in binary form must 20 | reproduce the above copyright notice, this list of conditions and the following 21 | disclaimer in the documentation and/or other materials provided with the 22 | distribution. Neither the name of CSIRO nor the names of its contributors may be 23 | used to endorse or promote products derived from this software without specific 24 | prior written permission of CSIRO. EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT 25 | AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED 26 | "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, 27 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES 28 | OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, 29 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE 30 | OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 31 | DISCOVERABLE. 
32 | 33 | TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE 34 | LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR 35 | BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR 36 | OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS 37 | SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR 38 | OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF 39 | ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC 40 | LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY 41 | DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE 42 | SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED 43 | OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY. 44 | 45 | APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY 46 | REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY 47 | ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET 48 | OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE 49 | EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT 50 | PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE 51 | RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO’S OPTION) TO 52 | ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES: 53 | 54 | (a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR 55 | SUPPLYING RELEVANT SERVICES AGAIN; 56 | 57 | (b) THE REPAIR OF THE SOFTWARE; 58 | 59 | (c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT 60 | SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE 61 | REPAIRED. 62 | 63 | IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF 64 | THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON 65 | BEHALF OF THE RELEVANT THIRD PARTY. 66 | 67 | If you intend to access the Software in connection with your employment or as an 68 | agent for a principal, you should only accept this agreement if you have been 69 | authorised to do so by your employer or principal (as applicable). By accepting 70 | this agreement, you are warranting to CSIRO that you are authorised to do so on 71 | behalf of your employer or principal (as applicable). 72 | 73 | The Software may contain third party material obtained by CSIRO under licence. 74 | Your rights to such material as part of the Software under this agreement is 75 | subject to any separate licence terms identified by CSIRO as part of the 76 | Software release - including as part of the Supplementary Licence, or as a 77 | separate file. Those third party licence terms may require you to download the 78 | relevant software from a third party site, or may mean that the third party 79 | licensor (and not CSIRO) grants you a licence directly for those components of 80 | the Software. It is your responsibility to ensure that you have the necessary 81 | rights to such third party material. 
82 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX ?= g++ 2 | GENCODES ?= 50 3 | 4 | INCLUDE_DIRS = -I./src 5 | NVCC_FLAGS = -ccbin $(CXX) -std=c++11 -Xcompiler -Wall,-Wextra 6 | NVCC_OPT_FLAGS = -DNDEBUG 7 | NVCC_TEST_FLAGS = -lineinfo 8 | NVCC_DBG_FLAGS = -g -G 9 | NVCC_LIBS = -lstdc++ 10 | NVCC_TEST_LIBS = -lgtest 11 | 12 | all: 13 | @echo "Please run 'make check' or 'make bench'." 14 | 15 | tests/test-suite: tests/test-suite.cu 16 | nvcc $(NVCC_TEST_FLAGS) $(NVCC_FLAGS) $(GENCODES:%=--gpu-architecture=compute_%) $(GENCODES:%=--gpu-code=sm_%) $(INCLUDE_DIRS) $(NVCC_LIBS) $(NVCC_TEST_LIBS) -o $@ $< 17 | 18 | check: tests/test-suite 19 | @./tests/test-suite 20 | 21 | bench/bench: bench/bench.cu 22 | nvcc $(NVCC_OPT_FLAGS) $(NVCC_FLAGS) $(GENCODES:%=--gpu-architecture=compute_%) $(GENCODES:%=--gpu-code=sm_%) $(INCLUDE_DIRS) $(NVCC_LIBS) -o $@ $< 23 | 24 | bench: bench/bench 25 | 26 | .PHONY: clean 27 | clean: 28 | $(RM) tests/test-suite bench/bench 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cuda-fixnum 2 | 3 | **NB: This repository is no longer maintained. All development on cuda-fixnum now takes place at [unzvfu/cuda-fixnum](https://github.com/unzvfu/cuda-fixnum).** 4 | 5 | `cuda-fixnum` is a fixed-precision SIMD library that targets CUDA. It provides the apparatus necessary to easily create efficient functions that operate on vectors of _n_-bit integers, where _n_ can be much larger than the size of a usual machine or device register. Currently supported values of _n_ are 32, 64, 128, 256, 512, 1024, and 2048 (larger values will be possible in a forthcoming release). 6 | 7 | The primary use case for fast arithmetic of numbers in the range covered by `cuda-fixnum` is in cryptography and computational number theory; in particular it can form an integral part in accelerating homomorphic encryption primitives as used in privacy-preserving machine learning. As such, special attention is given to support modular arithmetic; this is used in an example implementation of the Paillier additively homomorphic encryption scheme and of elliptic curve scalar multiplication. Future releases will provide additional support for operations useful to implementing Ring-LWE-based somewhat homomorphic encryption schemes. 8 | 9 | Finally, the library is designed to be _fast_. Through exploitation of warp-synchronous programming, vote functions, and deferred carry handling, the primitives of the library are currently competitive with the state-of-the-art in the literature for modular multiplication and modular exponentiation on GPUs. The design of the library allows transparent substitution of the underlying arithmetic, allowing the user to select whichever performs best on the available hardware. Moreover, several algorithms, both novel and from the literature, will be incorporated shortly that will improve performance by a further 25-50%. 10 | 11 | The library is currently at the _alpha_ stage of development. It has many rough edges, but most features are present and it is performant enough to be competitive. Comments, questions and contributions are welcome! 12 | 13 | ## Example 14 | 15 | To get a feel for what it's like to use the library, let's consider a simple example. 
Here is an [implementation](cuda-fixnum/src/functions/paillier_encrypt.cu) of encryption in the [Paillier cryptosystem](https://en.wikipedia.org/wiki/Paillier_cryptosystem):
16 | ```cuda
17 | #include "functions/quorem_preinv.cu"
18 | #include "functions/modexp.cu"
19 | 
20 | using namespace cuFIXNUM;
21 | 
22 | template< typename fixnum >
23 | class paillier_encrypt {
24 |     fixnum n;                     // public key
25 |     fixnum n_sqr;                 // ciphertext modulus n^2
26 |     modexp<fixnum> pow;           // function x |--> x^n (mod n^2)
27 |     quorem_preinv<fixnum> mod_n2; // function x |--> x (mod n^2)
28 | 
29 |     // This is a utility whose purpose is to allow calculating
30 |     // and using n^2 in the constructor initialisation list.
31 |     __device__ fixnum square(fixnum n) {
32 |         fixnum n2;
33 |         fixnum::sqr_lo(n2, n);
34 |         return n2;
35 |     }
36 | 
37 | public:
38 |     /*
39 |      * Construct an encryption functor with public key n_.
40 |      * n_ must be less than fixnum::BITS/2 bits in order for n^2
41 |      * to fit in a fixnum.
42 |      */
43 |     __device__ paillier_encrypt(fixnum n_)
44 |         : n(n_)
45 |         , n_sqr(square(n_))
46 |         , pow(n_sqr, n_)
47 |         , mod_n2(n_sqr)
48 |     { }
49 | 
50 |     /*
51 |      * Encrypt the message m using the public key n and randomness r.
52 |      * Precisely, return
53 |      *
54 |      *   ctxt <- (1 + m*n) * r^n (mod n^2)
55 |      *
56 |      * m and r must be at most fixnum::BITS/2 bits (m is interpreted
57 |      * modulo n anyhow).
58 |      */
59 |     __device__ void operator()(fixnum &ctxt, fixnum m, fixnum r) const {
60 |         fixnum::mul_lo(m, m, n);            // m <- m * n (lo half mult)
61 |         fixnum::incr_cy(m);                 // m <- m + 1, i.e. 1 + m*n
62 |         pow(r, r);                          // r <- r^n (mod n^2)
63 |         fixnum c_hi, c_lo;                  // hi and lo halves of wide multiplication
64 |         fixnum::mul_wide(c_hi, c_lo, m, r); // (c_hi, c_lo) <- m * r (wide mult)
65 |         mod_n2(ctxt, c_hi, c_lo);           // ctxt <- (c_hi, c_lo) (mod n^2)
66 |     }
67 | };
68 | ```
69 | A few features will be common to most user-defined functions such as the one above: they will be template function objects that rely on a `fixnum`, which will be instantiated with one of the fixnum arithmetic implementations provided, usually the [`warp_fixnum`](cuda-fixnum/src/fixnum/warp_fixnum.cu). Functions in the `fixnum` class are static and (usually) return their results in the first one or two parameters. Complicated functions that might perform precomputation, such as [modular exponentiation (`modexp`)](cuda-fixnum/src/functions/modexp.cu) and [quotient & remainder with precomputed inverse (`quorem_preinv`)](cuda-fixnum/src/functions/quorem_preinv.cu), are instance variables in the object that are initialised in the constructor.
70 | 
71 | Although it is not (yet) the focus of this project to help optimise host-device communication, the [`fixnum_array`](cuda-fixnum/src/array/fixnum_array.h) facility is provided to make it easy to apply user-defined functions to data originating in the host. Using `fixnum_array` will often look something like this:
72 | ```C++
73 | using namespace cuFIXNUM;
74 | 
75 | // In this case we need to wrap paillier_encrypt above to read the
76 | // public key from memory and pass it to the constructor.
77 | __device__ uint8_t public_key[] = ...; // initialised earlier
78 | 
79 | template< typename fixnum >
80 | class my_paillier_encrypt {
81 |     paillier_encrypt<fixnum> encrypt;
82 | 
83 |     __device__ fixnum load_pkey() {
84 |         fixnum pkey;
85 |         fixnum::from_bytes(pkey, public_key, public_key_len);
86 |         return pkey;
87 |     }
88 | 
89 | public:
90 |     __device__ my_paillier_encrypt()
91 |         : encrypt(load_pkey())
92 |     { }
93 | 
94 |     __device__ void operator()(fixnum &ctxt, fixnum m, fixnum r) const {
95 |         fixnum c; // Always read into a register, then set the result.
96 |         encrypt(c, m, r);
97 |         ctxt = c;
98 |     }
99 | };
100 | 
101 | void host_function() {
102 |     ...
103 |     // fixnum represents 256-byte numbers, using a 64-bit "basic fixnum".
104 |     typedef warp_fixnum<256, u64_fixnum> fixnum;
105 |     typedef fixnum_array<fixnum> fixnum_array;
106 | 
107 |     fixnum_array *ctxts, *ptxts, *rands, *pkeys;
108 | 
109 |     int nelts = ...; // can be as much as 1e6 to 1e8
110 |     // Usually this could be as much as fixnum::BYTES == 256, however
111 |     // in this application it must be at most fixnum::BYTES/2 = 128.
112 |     int message_bytes = ...;
113 | 
114 |     // Input plaintexts
115 |     ptxts = fixnum_array::create(input_array, message_bytes, nelts);
116 |     // Randomness
117 |     rands = fixnum_array::create(random_data, fixnum::BYTES/2, nelts);
118 |     // Ciphertexts will be put here
119 |     ctxts = fixnum_array::create(ptxts->length());
120 | 
121 |     // Map ctxts <- [paillier_encrypt(p, r) for p, r in zip(ptxts, rands)]
122 |     fixnum_array::template map<my_paillier_encrypt>(ctxts, rands, ptxts);
123 | 
124 |     // Access results.
125 |     ctxts->retrieve_all(byte_buffer, buflen);
126 |     ...
127 | }
128 | ```
129 | 
130 | ## Building
131 | 
132 | The build system for cuda-fixnum is currently, shall we say, _primitive_. Basically you can run `make bench` to build the benchmarking program, or `make check` to build and run the test suite. The test suite requires the [Google Test framework](https://github.com/google/googletest) to be installed. The Makefile will read in the variables `CXX` and `GENCODES` from the environment as a convenient way to specify the C++ compiler to use and the CUDA compute capability codes that you want to compile with. The defaults are `CXX = g++` and `GENCODES = 50`.
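For example, to build and run the test suite for a device with compute capability 6.1, using Clang as the host compiler, an invocation along the following lines should work (the values here are purely illustrative; substitute your own hardware's compute capability and preferred compiler):

```
$ GENCODES=61 CXX=clang++ make check
```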
133 | 134 | ## Benchmarks 135 | 136 | Here is the output from a recent run of the benchmark with a GTX Titan X (Maxwell, 1GHz clock, 3072 cores): 137 | 138 | ``` 139 | $ bench/bench 5000000 140 | Function: mul_lo, #elts: 5000e3 141 | fixnum digit total data time Kops/s 142 | bits bits (MiB) (seconds) 143 | 32 32 19.1 0.000 24630541.9 144 | 64 32 38.1 0.000 11547344.1 145 | 128 32 76.3 0.001 5091649.7 146 | 256 32 152.6 0.003 1775568.2 147 | 512 32 305.2 0.008 619578.7 148 | 1024 32 610.4 0.030 166855.8 149 | 150 | 64 64 38.1 0.000 14619883.0 151 | 128 64 76.3 0.001 7824726.1 152 | 256 64 152.6 0.002 2908667.8 153 | 512 64 305.2 0.006 829875.5 154 | 1024 64 610.4 0.023 221749.2 155 | 2048 64 1220.7 0.087 57611.0 156 | 157 | 158 | Function: mul_wide, #elts: 5000e3 159 | fixnum digit total data time Kops/s 160 | bits bits (MiB) (seconds) 161 | 32 32 19.1 0.000 25906735.8 162 | 64 32 38.1 0.000 10775862.1 163 | 128 32 76.3 0.001 3861003.9 164 | 256 32 152.6 0.005 985998.8 165 | 512 32 305.2 0.018 271164.4 166 | 1024 32 610.4 0.060 83847.6 167 | 168 | 64 64 38.1 0.000 14662756.6 169 | 128 64 76.3 0.001 6765899.9 170 | 256 64 152.6 0.003 1904036.6 171 | 512 64 305.2 0.009 530278.9 172 | 1024 64 610.4 0.036 140024.6 173 | 2048 64 1220.7 0.129 38680.5 174 | 175 | 176 | Function: modexp, #elts: 50e3 177 | fixnum digit total data time Kops/s 178 | bits bits (MiB) (seconds) 179 | 32 32 0.2 0.000 292397.7 180 | 64 32 0.4 0.001 68306.0 181 | 128 32 0.8 0.003 16388.1 182 | 256 32 1.5 0.017 3015.5 183 | 512 32 3.1 0.108 463.6 184 | 1024 32 6.1 0.748 66.9 185 | 186 | 64 64 0.4 0.000 113378.7 187 | 128 64 0.8 0.002 20798.7 188 | 256 64 1.5 0.015 3403.2 189 | 512 64 3.1 0.105 476.8 190 | 1024 64 6.1 0.658 76.0 191 | 2048 64 12.2 4.959 10.1 192 | ``` 193 | 194 | It is interesting to note that performance is consistently better with 64-bit 195 | integer arithmetic, even though 64-bit registers are simulated on nVidia 196 | devices. Looks like my 32-bit integer arithmetic code could use some work! 197 | 198 | ## Sources 199 | 200 | The main sources for the algorithms in this library are 201 | 202 | - Brent, R. and Zimmermann, P., [_Modern Computer Arithmetic_](https://members.loria.fr/PZimmermann/mca/pub226.html), Cambridge University Press, 2010. 203 | - Menezes, A. J., van Oorschot, P. C. and Vanstone, S. A., [_Handbook of Applied Cryptography_](http://cacr.uwaterloo.ca/hac/), CRC Press, 5th printing, 2001. Chapter 14. 204 | - Granlund, T. and "the GMP development team", [_GNU MP: The GNU Multiple Precision Arithmetic Library_](https://gmplib.org), version 6.1.2. 205 | 206 | ## Author and licence 207 | 208 | The principal author of `cuda-fixnum` is Dr Hamish Ivey-Law (@unzvfu). Email: _hamish (at) ivey-law.name_ 209 | 210 | `cuda-fixnum` is copyright (c) 2016-2019 Commonwealth Scientific and Industrial Research Organisation (CSIRO). 211 | 212 | `cuda-fixnum` is released under a modified MIT/BSD Open Source Licence (see [LICENCE](LICENCE) for details). 
213 | -------------------------------------------------------------------------------- /bench/bench.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "fixnum/warp_fixnum.cu" 6 | #include "array/fixnum_array.h" 7 | #include "functions/modexp.cu" 8 | #include "functions/multi_modexp.cu" 9 | #include "modnum/modnum_monty_redc.cu" 10 | #include "modnum/modnum_monty_cios.cu" 11 | 12 | using namespace std; 13 | using namespace cuFIXNUM; 14 | 15 | template< typename fixnum > 16 | struct mul_lo { 17 | __device__ void operator()(fixnum &r, fixnum a) { 18 | fixnum s; 19 | fixnum::mul_lo(s, a, a); 20 | r = s; 21 | } 22 | }; 23 | 24 | template< typename fixnum > 25 | struct mul_wide { 26 | __device__ void operator()(fixnum &r, fixnum a) { 27 | fixnum rr, ss; 28 | fixnum::mul_wide(ss, rr, a, a); 29 | r = ss; 30 | } 31 | }; 32 | 33 | template< typename fixnum > 34 | struct sqr_wide { 35 | __device__ void operator()(fixnum &r, fixnum a) { 36 | fixnum rr, ss; 37 | fixnum::sqr_wide(ss, rr, a); 38 | r = ss; 39 | } 40 | }; 41 | 42 | template< typename modnum > 43 | struct my_modexp { 44 | typedef typename modnum::fixnum fixnum; 45 | 46 | __device__ void operator()(fixnum &z, fixnum x) { 47 | modexp me(x, x); 48 | fixnum zz; 49 | me(zz, x); 50 | z = zz; 51 | }; 52 | }; 53 | 54 | template< typename modnum > 55 | struct my_multi_modexp { 56 | typedef typename modnum::fixnum fixnum; 57 | 58 | __device__ void operator()(fixnum &z, fixnum x) { 59 | multi_modexp mme(x); 60 | fixnum zz; 61 | mme(zz, x, x); 62 | z = zz; 63 | }; 64 | }; 65 | 66 | template< int fn_bytes, typename word_fixnum, template class Func > 67 | void bench(int nelts) { 68 | typedef warp_fixnum fixnum; 69 | typedef fixnum_array fixnum_array; 70 | 71 | if (nelts == 0) { 72 | puts(" -*- nelts == 0; skipping... 
-*-"); 73 | return; 74 | } 75 | 76 | uint8_t *input = new uint8_t[fn_bytes * nelts]; 77 | for (int i = 0; i < fn_bytes * nelts; ++i) 78 | input[i] = (i * 17 + 11) % 256; 79 | 80 | fixnum_array *res, *in; 81 | in = fixnum_array::create(input, fn_bytes * nelts, fn_bytes); 82 | res = fixnum_array::create(nelts); 83 | 84 | // warm up 85 | fixnum_array::template map(res, in); 86 | 87 | clock_t c = clock(); 88 | fixnum_array::template map(res, in); 89 | c = clock() - c; 90 | 91 | double secinv = (double)CLOCKS_PER_SEC / c; 92 | double total_MiB = fixnum::BYTES * (double)nelts / (1 << 20); 93 | printf(" %4d %3d %6.1f %7.3f %12.1f\n", 94 | fixnum::BITS, fixnum::digit::BITS, total_MiB, 95 | 1/secinv, nelts * 1e-3 * secinv); 96 | 97 | delete in; 98 | delete res; 99 | delete[] input; 100 | } 101 | 102 | template< template class Func > 103 | void bench_func(const char *fn_name, int nelts) { 104 | printf("Function: %s, #elts: %de3\n", fn_name, (int)(nelts * 1e-3)); 105 | printf("fixnum digit total data time Kops/s\n"); 106 | printf(" bits bits (MiB) (seconds)\n"); 107 | bench<4, u32_fixnum, Func>(nelts); 108 | bench<8, u32_fixnum, Func>(nelts); 109 | bench<16, u32_fixnum, Func>(nelts); 110 | bench<32, u32_fixnum, Func>(nelts); 111 | bench<64, u32_fixnum, Func>(nelts); 112 | bench<128, u32_fixnum, Func>(nelts); 113 | puts(""); 114 | 115 | bench<8, u64_fixnum, Func>(nelts); 116 | bench<16, u64_fixnum, Func>(nelts); 117 | bench<32, u64_fixnum, Func>(nelts); 118 | bench<64, u64_fixnum, Func>(nelts); 119 | bench<128, u64_fixnum, Func>(nelts); 120 | bench<256, u64_fixnum, Func>(nelts); 121 | puts(""); 122 | } 123 | 124 | template< typename fixnum > 125 | using modexp_redc = my_modexp< modnum_monty_redc >; 126 | 127 | template< typename fixnum > 128 | using modexp_cios = my_modexp< modnum_monty_cios >; 129 | 130 | template< typename fixnum > 131 | using multi_modexp_redc = my_multi_modexp< modnum_monty_redc >; 132 | 133 | template< typename fixnum > 134 | using multi_modexp_cios = my_multi_modexp< modnum_monty_cios >; 135 | 136 | int main(int argc, char *argv[]) { 137 | long m = 1; 138 | if (argc > 1) 139 | m = atol(argv[1]); 140 | m = std::max(m, 1000L); 141 | 142 | bench_func("mul_lo", m); 143 | puts(""); 144 | bench_func("mul_wide", m); 145 | puts(""); 146 | bench_func("sqr_wide", m); 147 | puts(""); 148 | bench_func("modexp redc", m / 100); 149 | puts(""); 150 | bench_func("modexp cios", m / 100); 151 | puts(""); 152 | 153 | bench_func("multi modexp redc", m / 100); 154 | puts(""); 155 | bench_func("multi modexp cios", m / 100); 156 | puts(""); 157 | 158 | return 0; 159 | } 160 | -------------------------------------------------------------------------------- /src/array/fixnum_array.cu: -------------------------------------------------------------------------------- 1 | // for printing arrays 2 | #include 3 | #include 4 | #include 5 | #include 6 | // for min 7 | #include 8 | 9 | #include "util/cuda_wrap.h" 10 | #include "fixnum_array.h" 11 | 12 | namespace cuFIXNUM { 13 | 14 | // TODO: The only device function in this file is the dispatch kernel 15 | // mechanism, which could arguably be placed elsewhere, thereby 16 | // allowing this file to be compiled completely for the host. 17 | 18 | // Notes: Read programming guide Section K.3 19 | // - Can prefetch unified memory 20 | // - Can advise on location of unified memory 21 | 22 | // TODO: Can I use smart pointers? unique_ptr? 
23 | 24 | // TODO: Clean this up 25 | namespace { 26 | typedef std::uint8_t byte; 27 | 28 | template< typename T > 29 | static byte *as_byte_ptr(T *ptr) { 30 | return reinterpret_cast(ptr); 31 | } 32 | 33 | template< typename T > 34 | static const byte *as_byte_ptr(const T *ptr) { 35 | return reinterpret_cast(ptr); 36 | } 37 | 38 | // TODO: refactor from word_fixnum. 39 | template< typename T > 40 | T ceilquo(T n, T d) { 41 | return (n + d - 1) / d; 42 | } 43 | } 44 | 45 | template< typename fixnum > 46 | fixnum_array * 47 | fixnum_array::create(size_t nelts) { 48 | fixnum_array *a = new fixnum_array; 49 | a->nelts = nelts; 50 | if (nelts > 0) { 51 | size_t nbytes = nelts * fixnum::BYTES; 52 | cuda_malloc_managed(&a->ptr, nbytes); 53 | } 54 | return a; 55 | } 56 | 57 | template< typename fixnum > 58 | template< typename T > 59 | fixnum_array * 60 | fixnum_array::create(size_t nelts, T init) { 61 | fixnum_array *a = create(nelts); 62 | byte *p = as_byte_ptr(a->ptr); 63 | 64 | const byte *in = as_byte_ptr(&init); 65 | byte elt[fixnum::BYTES]; 66 | memset(elt, 0, fixnum::BYTES); 67 | std::copy(in, in + sizeof(T), elt); 68 | 69 | for (uint32_t i = 0; i < nelts; ++i, p += fixnum::BYTES) 70 | fixnum::from_bytes(p, elt, fixnum::BYTES); 71 | return a; 72 | } 73 | 74 | template< typename fixnum > 75 | fixnum_array * 76 | fixnum_array::create(const byte *data, size_t total_bytes, size_t bytes_per_elt) { 77 | // FIXME: Should handle this error more appropriately 78 | if (total_bytes == 0 || bytes_per_elt == 0) 79 | return nullptr; 80 | 81 | size_t nelts = ceilquo(total_bytes, bytes_per_elt); 82 | fixnum_array *a = create(nelts); 83 | 84 | byte *p = as_byte_ptr(a->ptr); 85 | const byte *d = data; 86 | for (size_t i = 0; i < nelts; ++i) { 87 | fixnum::from_bytes(p, d, bytes_per_elt); 88 | p += fixnum::BYTES; 89 | d += bytes_per_elt; 90 | } 91 | return a; 92 | } 93 | 94 | // TODO: This doesn't belong here. 95 | template< typename digit > 96 | void 97 | rotate_array(digit *out, const digit *in, int nelts, int words_per_elt, int i) { 98 | if (i < 0) { 99 | int j = -i; 100 | i += nelts * ceilquo(j, nelts); 101 | assert(i >= 0 && i < nelts); 102 | i = nelts - i; 103 | } else if (i >= nelts) 104 | i %= nelts; 105 | int pivot = i * words_per_elt; 106 | int nwords = nelts * words_per_elt; 107 | std::copy(in, in + nwords - pivot, out + pivot); 108 | std::copy(in + nwords - pivot, in + nwords, out); 109 | } 110 | 111 | 112 | // TODO: Find a way to return a wrapper that just modifies the requested indices 113 | // on the fly, rather than copying the whole array. Hard part will be making it 114 | // work with map/dispatch. 
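// Illustrative example (comment added for clarity): rotate(1) on a 4-element
// array {A, B, C, D} produces a new array {D, A, B, C} -- rotate_array moves
// each element up by i slots (in units of words_per_elt words), wrapping the
// tail around to the front.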
115 | template< typename fixnum > 116 | fixnum_array * 117 | fixnum_array::rotate(int i) { 118 | fixnum_array *a = create(length()); 119 | byte *p = as_byte_ptr(a->ptr); 120 | const byte *q = as_byte_ptr(ptr); 121 | rotate_array(p, q, nelts, fixnum::BYTES, i); 122 | return a; 123 | } 124 | 125 | template< typename fixnum > 126 | fixnum_array * 127 | fixnum_array::repeat(int ntimes) { 128 | fixnum_array *a = create(length() * ntimes); 129 | byte *p = as_byte_ptr(a->ptr); 130 | const byte *q = as_byte_ptr(ptr); 131 | int nbytes = nelts * fixnum::BYTES; 132 | for (int i = 0; i < ntimes; ++i, p += nbytes) 133 | std::copy(q, q + nbytes, p); 134 | return a; 135 | } 136 | 137 | template< typename fixnum > 138 | fixnum_array * 139 | fixnum_array::rotations(int ntimes) { 140 | fixnum_array *a = create(nelts * ntimes); 141 | byte *p = as_byte_ptr(a->ptr); 142 | const byte *q = as_byte_ptr(ptr); 143 | int nbytes = nelts * fixnum::BYTES; 144 | for (int i = 0; i < ntimes; ++i, p += nbytes) 145 | rotate_array(p, q, nelts, fixnum::BYTES, i); 146 | return a; 147 | } 148 | 149 | 150 | template< typename fixnum > 151 | int 152 | fixnum_array::set(int idx, const byte *data, size_t nbytes) { 153 | // FIXME: Better error handling 154 | if (idx < 0 || idx >= nelts) 155 | return -1; 156 | 157 | int off = idx * fixnum::BYTES; 158 | const byte *q = as_byte_ptr(ptr); 159 | return fixnum::from_bytes(q + off, data, nbytes); 160 | } 161 | 162 | template< typename fixnum > 163 | fixnum_array::~fixnum_array() { 164 | if (nelts > 0) 165 | cuda_free(ptr); 166 | } 167 | 168 | template< typename fixnum > 169 | int 170 | fixnum_array::length() const { 171 | return nelts; 172 | } 173 | 174 | template< typename fixnum > 175 | size_t 176 | fixnum_array::retrieve_into(byte *dest, size_t dest_space, int idx) const { 177 | if (idx < 0 || idx > nelts) { 178 | // FIXME: This is not the right way to handle an "index out of 179 | // bounds" error. 180 | return 0; 181 | } 182 | const byte *q = as_byte_ptr(ptr); 183 | return fixnum::to_bytes(dest, dest_space, q + idx * fixnum::BYTES); 184 | } 185 | 186 | // FIXME: Can return fewer than nelts elements. 187 | template< typename fixnum > 188 | void 189 | fixnum_array::retrieve_all(byte *dest, size_t dest_space, int *dest_nelts) const { 190 | const byte *p = as_byte_ptr(ptr); 191 | byte *d = dest; 192 | int max_dest_nelts = dest_space / fixnum::BYTES; 193 | *dest_nelts = std::min(nelts, max_dest_nelts); 194 | for (int i = 0; i < *dest_nelts; ++i) { 195 | fixnum::to_bytes(d, fixnum::BYTES, p); 196 | p += fixnum::BYTES; 197 | d += fixnum::BYTES; 198 | } 199 | } 200 | 201 | namespace { 202 | std::string 203 | fixnum_as_str(const uint8_t *fn, int nbytes) { 204 | std::ostringstream ss; 205 | 206 | for (int i = nbytes - 1; i >= 0; --i) { 207 | // These IO manipulators are forgotten after each use; 208 | // i.e. they don't apply to the next output operation (whether 209 | // it be in the next loop iteration or in the conditional 210 | // below. 
211 | ss << std::setfill('0') << std::setw(2) << std::hex; 212 | ss << static_cast(fn[i]); 213 | if (i && !(i & 3)) 214 | ss << ' '; 215 | } 216 | return ss.str(); 217 | } 218 | } 219 | 220 | template< typename fixnum > 221 | std::ostream & 222 | operator<<(std::ostream &os, const fixnum_array *fn_arr) { 223 | constexpr int fn_bytes = fixnum::BYTES; 224 | constexpr size_t bufsz = 4096; 225 | uint8_t arr[bufsz]; 226 | int nelts; 227 | 228 | fn_arr->retrieve_all(arr, bufsz, &nelts); 229 | os << "( "; 230 | if (nelts < fn_arr->length()) { 231 | os << "insufficient space to retrieve array"; 232 | } else if (nelts > 0) { 233 | os << fixnum_as_str(arr, fn_bytes); 234 | for (int i = 1; i < nelts; ++i) 235 | os << ", " << fixnum_as_str(arr + i*fn_bytes, fn_bytes); 236 | } 237 | os << " )" << std::flush; 238 | return os; 239 | } 240 | 241 | 242 | template< template class Func, typename fixnum, typename... Args > 243 | __global__ void 244 | dispatch(int nelts, Args... args) { 245 | // Get the slot index for the current thread. 246 | int blk_tid_offset = blockDim.x * blockIdx.x; 247 | int tid_in_blk = threadIdx.x; 248 | int idx = (blk_tid_offset + tid_in_blk) / fixnum::SLOT_WIDTH; 249 | 250 | if (idx < nelts) { 251 | // TODO: Find a way to load each argument into a register before passing 252 | // it to fn, and then unpack the return values where they belong. This 253 | // will guarantee that all operations happen on registers, rather than 254 | // inadvertently operating on memory. 255 | 256 | Func fn; 257 | // TODO: This offset calculation is entwined with fixnum layout and so 258 | // belongs somewhere else. 259 | int off = idx * fixnum::layout::WIDTH + fixnum::layout::laneIdx(); 260 | // TODO: This is hiding a sin against memory aliasing / management / 261 | // type-safety. 262 | fn(args[off]...); 263 | } 264 | } 265 | 266 | template< typename fixnum > 267 | template< template class Func, typename... Args > 268 | void 269 | fixnum_array::map(Args... args) { 270 | // TODO: Set this to the number of threads on a single SM on the host GPU. 271 | constexpr int BLOCK_SIZE = 192; 272 | 273 | // FIXME: WARPSIZE should come from slot_layout 274 | constexpr int WARPSIZE = 32; 275 | // BLOCK_SIZE must be a multiple of warpSize 276 | static_assert(!(BLOCK_SIZE % WARPSIZE), 277 | "block size must be a multiple of warpSize"); 278 | 279 | int nelts = std::min( { args->length()... 
} ); 280 | 281 | constexpr int fixnums_per_block = BLOCK_SIZE / fixnum::SLOT_WIDTH; 282 | 283 | // FIXME: nblocks could be too big for a single kernel call to handle 284 | int nblocks = ceilquo(nelts, fixnums_per_block); 285 | 286 | // nblocks > 0 iff nelts > 0 287 | if (nblocks > 0) { 288 | cudaStream_t stream; 289 | cuda_check(cudaStreamCreate(&stream), "create stream"); 290 | // cuda_stream_attach_mem(stream, src->ptr); 291 | // cuda_stream_attach_mem(stream, ptr); 292 | cuda_check(cudaStreamSynchronize(stream), "stream sync"); 293 | 294 | dispatch<<< nblocks, BLOCK_SIZE, 0, stream >>>(nelts, args->ptr...); 295 | 296 | cuda_check(cudaPeekAtLastError(), "kernel invocation/run"); 297 | cuda_check(cudaStreamSynchronize(stream), "stream sync"); 298 | cuda_check(cudaStreamDestroy(stream), "stream destroy"); 299 | 300 | // FIXME: Only synchronize when retrieving data from array 301 | cuda_device_synchronize(); 302 | } 303 | } 304 | 305 | } // End namespace cuFIXNUM 306 | -------------------------------------------------------------------------------- /src/array/fixnum_array.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace cuFIXNUM { 7 | 8 | // TODO: Copy over functionality and documentation from IntmodVector. 9 | template< typename fixnum > 10 | class fixnum_array { 11 | public: 12 | typedef std::uint8_t byte; 13 | 14 | static fixnum_array *create(size_t nelts); 15 | template< typename T > 16 | static fixnum_array *create(size_t nelts, T init); 17 | // NB: If bytes_per_elt doesn't divide len, the last len % bytes_per_elt 18 | // bytes are *dropped*. 19 | static fixnum_array *create(const byte *data, size_t total_bytes, size_t bytes_per_elt); 20 | 21 | fixnum_array *rotate(int i); 22 | fixnum_array *rotations(int ntimes); 23 | fixnum_array *repeat(int ntimes); 24 | const byte *get_ptr() const { return reinterpret_cast(ptr); } 25 | 26 | ~fixnum_array(); 27 | 28 | int length() const; 29 | 30 | int set(int idx, const byte *data, size_t len); 31 | size_t retrieve_into(byte *dest, size_t dest_space, int idx) const; 32 | void retrieve_all(byte *dest, size_t dest_space, int *nelts) const; 33 | 34 | template< template class Func, typename... Args > 35 | static void map(Args... 
args); 36 | 37 | private: 38 | fixnum *ptr; 39 | int nelts; 40 | 41 | fixnum_array() { } 42 | 43 | fixnum_array(const fixnum_array &); 44 | fixnum_array &operator=(const fixnum_array &); 45 | }; 46 | 47 | template< typename fixnum > 48 | std::ostream & 49 | operator<<(std::ostream &os, const fixnum_array *fn_arr); 50 | 51 | } // End namespace cuFIXNUM 52 | 53 | #include "fixnum_array.cu" 54 | -------------------------------------------------------------------------------- /src/fixnum/internal/primitives.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cuFIXNUM { 8 | 9 | namespace internal { 10 | typedef std::uint32_t u32; 11 | typedef std::uint64_t u64; 12 | 13 | __device__ __forceinline__ 14 | void 15 | addc(u32 &s, u32 a, u32 b) { 16 | asm ("addc.u32 %0, %1, %2;" 17 | : "=r"(s) 18 | : "r"(a), "r" (b)); 19 | } 20 | 21 | __device__ __forceinline__ 22 | void 23 | add_cc(u32 &s, u32 a, u32 b) { 24 | asm ("add.cc.u32 %0, %1, %2;" 25 | : "=r"(s) 26 | : "r"(a), "r" (b)); 27 | } 28 | 29 | __device__ __forceinline__ 30 | void 31 | addc_cc(u32 &s, u32 a, u32 b) { 32 | asm ("addc.cc.u32 %0, %1, %2;" 33 | : "=r"(s) 34 | : "r"(a), "r" (b)); 35 | } 36 | 37 | __device__ __forceinline__ 38 | void 39 | addc(u64 &s, u64 a, u64 b) { 40 | asm ("addc.u64 %0, %1, %2;" 41 | : "=l"(s) 42 | : "l"(a), "l" (b)); 43 | } 44 | 45 | __device__ __forceinline__ 46 | void 47 | add_cc(u64 &s, u64 a, u64 b) { 48 | asm ("add.cc.u64 %0, %1, %2;" 49 | : "=l"(s) 50 | : "l"(a), "l" (b)); 51 | } 52 | 53 | __device__ __forceinline__ 54 | void 55 | addc_cc(u64 &s, u64 a, u64 b) { 56 | asm ("addc.cc.u64 %0, %1, %2;" 57 | : "=l"(s) 58 | : "l"(a), "l" (b)); 59 | } 60 | 61 | /* 62 | * hi * 2^n + lo = a * b 63 | */ 64 | __device__ __forceinline__ 65 | void 66 | mul_hi(u32 &hi, u32 a, u32 b) { 67 | asm ("mul.hi.u32 %0, %1, %2;" 68 | : "=r"(hi) 69 | : "r"(a), "r"(b)); 70 | } 71 | 72 | __device__ __forceinline__ 73 | void 74 | mul_hi(u64 &hi, u64 a, u64 b) { 75 | asm ("mul.hi.u64 %0, %1, %2;" 76 | : "=l"(hi) 77 | : "l"(a), "l"(b)); 78 | } 79 | 80 | 81 | /* 82 | * hi * 2^n + lo = a * b 83 | */ 84 | __device__ __forceinline__ 85 | void 86 | mul_wide(u32 &hi, u32 &lo, u32 a, u32 b) { 87 | // TODO: Measure performance difference between this and the 88 | // equivalent: 89 | // mul.hi.u32 %0, %2, %3 90 | // mul.lo.u32 %1, %2, %3 91 | asm ("{\n\t" 92 | " .reg .u64 tmp;\n\t" 93 | " mul.wide.u32 tmp, %2, %3;\n\t" 94 | " mov.b64 { %1, %0 }, tmp;\n\t" 95 | "}" 96 | : "=r"(hi), "=r"(lo) 97 | : "r"(a), "r"(b)); 98 | } 99 | 100 | __device__ __forceinline__ 101 | void 102 | mul_wide(u64 &hi, u64 &lo, u64 a, u64 b) { 103 | asm ("mul.hi.u64 %0, %2, %3;\n\t" 104 | "mul.lo.u64 %1, %2, %3;" 105 | : "=l"(hi), "=l"(lo) 106 | : "l"(a), "l"(b)); 107 | } 108 | 109 | /* 110 | * (hi, lo) = a * b + c 111 | */ 112 | __device__ __forceinline__ 113 | void 114 | mad_wide(u32 &hi, u32 &lo, u32 a, u32 b, u32 c) { 115 | asm ("{\n\t" 116 | " .reg .u64 tmp;\n\t" 117 | " mad.wide.u32 tmp, %2, %3, %4;\n\t" 118 | " mov.b64 { %1, %0 }, tmp;\n\t" 119 | "}" 120 | : "=r"(hi), "=r"(lo) 121 | : "r"(a), "r"(b), "r"(c)); 122 | } 123 | 124 | __device__ __forceinline__ 125 | void 126 | mad_wide(u64 &hi, u64 &lo, u64 a, u64 b, u64 c) { 127 | asm ("mad.lo.cc.u64 %1, %2, %3, %4;\n\t" 128 | "madc.hi.u64 %0, %2, %3, 0;" 129 | : "=l"(hi), "=l"(lo) 130 | : "l"(a), "l" (b), "l"(c)); 131 | } 132 | 133 | // lo = a * b + c (mod 2^n) 134 | __device__ __forceinline__ 135 | void 136 | 
mad_lo(u32 &lo, u32 a, u32 b, u32 c) { 137 | asm ("mad.lo.u32 %0, %1, %2, %3;" 138 | : "=r"(lo) 139 | : "r"(a), "r" (b), "r"(c)); 140 | } 141 | 142 | __device__ __forceinline__ 143 | void 144 | mad_lo(u64 &lo, u64 a, u64 b, u64 c) { 145 | asm ("mad.lo.u64 %0, %1, %2, %3;" 146 | : "=l"(lo) 147 | : "l"(a), "l" (b), "l"(c)); 148 | } 149 | 150 | 151 | // as above but with carry in cy 152 | __device__ __forceinline__ 153 | void 154 | mad_lo_cc(u32 &lo, u32 a, u32 b, u32 c) { 155 | asm ("mad.lo.cc.u32 %0, %1, %2, %3;" 156 | : "=r"(lo) 157 | : "r"(a), "r" (b), "r"(c)); 158 | } 159 | 160 | __device__ __forceinline__ 161 | void 162 | mad_lo_cc(u64 &lo, u64 a, u64 b, u64 c) { 163 | asm ("mad.lo.cc.u64 %0, %1, %2, %3;" 164 | : "=l"(lo) 165 | : "l"(a), "l" (b), "l"(c)); 166 | } 167 | 168 | __device__ __forceinline__ 169 | void 170 | madc_lo_cc(u32 &lo, u32 a, u32 b, u32 c) { 171 | asm ("madc.lo.cc.u32 %0, %1, %2, %3;" 172 | : "=r"(lo) 173 | : "r"(a), "r" (b), "r"(c)); 174 | } 175 | 176 | __device__ __forceinline__ 177 | void 178 | madc_lo_cc(u64 &lo, u64 a, u64 b, u64 c) { 179 | asm ("madc.lo.cc.u64 %0, %1, %2, %3;" 180 | : "=l"(lo) 181 | : "l"(a), "l" (b), "l"(c)); 182 | } 183 | 184 | __device__ __forceinline__ 185 | void 186 | mad_hi(u32 &hi, u32 a, u32 b, u32 c) { 187 | asm ("mad.hi.u32 %0, %1, %2, %3;" 188 | : "=r"(hi) 189 | : "r"(a), "r" (b), "r"(c)); 190 | } 191 | 192 | __device__ __forceinline__ 193 | void 194 | mad_hi(u64 &hi, u64 a, u64 b, u64 c) { 195 | asm ("mad.hi.u64 %0, %1, %2, %3;" 196 | : "=l"(hi) 197 | : "l"(a), "l" (b), "l"(c)); 198 | } 199 | 200 | __device__ __forceinline__ 201 | void 202 | mad_hi_cc(u32 &hi, u32 a, u32 b, u32 c) { 203 | asm ("mad.hi.cc.u32 %0, %1, %2, %3;" 204 | : "=r"(hi) 205 | : "r"(a), "r" (b), "r"(c)); 206 | } 207 | 208 | __device__ __forceinline__ 209 | void 210 | mad_hi_cc(u64 &hi, u64 a, u64 b, u64 c) { 211 | asm ("mad.hi.cc.u64 %0, %1, %2, %3;" 212 | : "=l"(hi) 213 | : "l"(a), "l" (b), "l"(c)); 214 | } 215 | 216 | __device__ __forceinline__ 217 | void 218 | madc_hi_cc(u32 &hi, u32 a, u32 b, u32 c) { 219 | asm ("madc.hi.cc.u32 %0, %1, %2, %3;" 220 | : "=r"(hi) 221 | : "r"(a), "r" (b), "r"(c)); 222 | } 223 | 224 | __device__ __forceinline__ 225 | void 226 | madc_hi_cc(u64 &hi, u64 a, u64 b, u64 c) { 227 | asm ("madc.hi.cc.u64 %0, %1, %2, %3;\n\t" 228 | : "=l"(hi) 229 | : "l"(a), "l" (b), "l"(c)); 230 | } 231 | 232 | // Source: https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf 233 | __device__ __forceinline__ 234 | void 235 | lshift(u32 &out_hi, u32 &out_lo, u32 in_hi, u32 in_lo, unsigned b) { 236 | asm ("shf.l.clamp.b32 %1, %2, %3, %4;\n\t" 237 | "shl.b32 %0, %2, %4;" 238 | : "=r"(out_lo), "=r"(out_hi) : "r"(in_lo), "r"(in_hi), "r"(b)); 239 | } 240 | 241 | /* 242 | * Left shift by b bits; b <= 32. 
243 | * Source: https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf 244 | */ 245 | __device__ __forceinline__ 246 | void 247 | lshift_b32(u64 &out_hi, u64 &out_lo, u64 in_hi, u64 in_lo, unsigned b) { 248 | assert(b <= 32); 249 | asm ("{\n\t" 250 | " .reg .u32 t1;\n\t" 251 | " .reg .u32 t2;\n\t" 252 | " .reg .u32 t3;\n\t" 253 | " .reg .u32 t4;\n\t" 254 | " .reg .u32 t5;\n\t" 255 | " .reg .u32 t6;\n\t" 256 | " .reg .u32 t7;\n\t" 257 | " .reg .u32 t8;\n\t" 258 | // (t4, t3, t2, t1) = (in_hi, in_lo) 259 | " mov.b64 { t3, t4 }, %3;\n\t" 260 | " mov.b64 { t1, t2 }, %2;\n\t" 261 | " shf.l.clamp.b32 t8, t3, t4, %4;\n\t" 262 | " shf.l.clamp.b32 t7, t2, t3, %4;\n\t" 263 | " shf.l.clamp.b32 t6, t1, t2, %4;\n\t" 264 | " shl.b32 t5, t1, %4;\n\t" 265 | " mov.b64 %1, { t7, t8 };\n\t" 266 | " mov.b64 %0, { t5, t6 };\n\t" 267 | "}" 268 | : "=l"(out_lo), "=l"(out_hi) : "l"(in_lo), "l"(in_hi), "r"(b)); 269 | } 270 | 271 | __device__ __forceinline__ 272 | void 273 | lshift(u64 &out_hi, u64 &out_lo, u64 in_hi, u64 in_lo, unsigned b) { 274 | assert(b <= 64); 275 | unsigned c = min(b, 32); 276 | lshift_b32(out_hi, out_lo, in_hi, in_lo, c); 277 | lshift_b32(out_hi, out_lo, out_hi, out_lo, b - c); 278 | } 279 | 280 | // Source: https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf 281 | __device__ __forceinline__ 282 | void 283 | rshift(u32 &out_hi, u32 &out_lo, u32 in_hi, u32 in_lo, unsigned b) { 284 | asm ("shf.r.clamp.b32 %0, %2, %3, %4;\n\t" 285 | "shr.b32 %1, %2, %4;" 286 | : "=r"(out_lo), "=r"(out_hi) : "r"(in_lo), "r"(in_hi), "r"(b)); 287 | } 288 | 289 | /* 290 | * Right shift by b bits; b <= 32. 291 | * Source: https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf 292 | */ 293 | __device__ __forceinline__ 294 | void 295 | rshift_b32(u64 &out_hi, u64 &out_lo, u64 in_hi, u64 in_lo, unsigned b) { 296 | assert(b <= 32); 297 | asm ("{\n\t" 298 | " .reg .u32 t1;\n\t" 299 | " .reg .u32 t2;\n\t" 300 | " .reg .u32 t3;\n\t" 301 | " .reg .u32 t4;\n\t" 302 | " .reg .u32 t5;\n\t" 303 | " .reg .u32 t6;\n\t" 304 | " .reg .u32 t7;\n\t" 305 | " .reg .u32 t8;\n\t" 306 | // (t4, t3, t2, t1) = (in_hi, in_lo) 307 | " mov.b64 { t1, t2 }, %2;\n\t" 308 | " mov.b64 { t3, t4 }, %3;\n\t" 309 | " shf.r.clamp.b32 t5, t1, t2, %4;\n\t" 310 | " shf.r.clamp.b32 t6, t2, t3, %4;\n\t" 311 | " shf.r.clamp.b32 t7, t3, t4, %4;\n\t" 312 | " shr.b32 t8, t4, %4;\n\t" 313 | " mov.b64 %0, { t5, t6 };\n\t" 314 | " mov.b64 %1, { t7, t8 };\n\t" 315 | "}" 316 | : "=l"(out_lo), "=l"(out_hi) : "l"(in_lo), "l"(in_hi), "r"(b)); 317 | } 318 | 319 | __device__ __forceinline__ 320 | void 321 | rshift(u64 &out_hi, u64 &out_lo, u64 in_hi, u64 in_lo, unsigned b) { 322 | assert(b <= 64); 323 | unsigned c = min(b, 32); 324 | rshift_b32(out_hi, out_lo, in_hi, in_lo, c); 325 | rshift_b32(out_hi, out_lo, out_hi, out_lo, b - c); 326 | } 327 | 328 | /* 329 | * Count Leading Zeroes in x. 330 | */ 331 | __device__ __forceinline__ 332 | int 333 | clz(u32 x) { 334 | int n; 335 | asm ("clz.b32 %0, %1;" : "=r"(n) : "r"(x)); 336 | return n; 337 | } 338 | 339 | __device__ __forceinline__ 340 | int 341 | clz(u64 x) { 342 | int n; 343 | asm ("clz.b64 %0, %1;" : "=r"(n) : "l"(x)); 344 | return n; 345 | } 346 | 347 | /* 348 | * Count Trailing Zeroes in x. 
349 | */ 350 | __device__ __forceinline__ 351 | int 352 | ctz(u32 x) { 353 | int n; 354 | asm ("{\n\t" 355 | " .reg .u32 tmp;\n\t" 356 | " brev.b32 tmp, %1;\n\t" 357 | " clz.b32 %0, tmp;\n\t" 358 | "}" 359 | : "=r"(n) : "r"(x)); 360 | return n; 361 | } 362 | 363 | __device__ __forceinline__ 364 | int 365 | ctz(u64 x) { 366 | int n; 367 | asm ("{\n\t" 368 | " .reg .u64 tmp;\n\t" 369 | " brev.b64 tmp, %1;\n\t" 370 | " clz.b64 %0, tmp;\n\t" 371 | "}" 372 | : "=r"(n) : "l"(x)); 373 | return n; 374 | } 375 | 376 | __device__ __forceinline__ 377 | void 378 | min(u32 &m, u32 a, u32 b) { 379 | asm ("min.u32 %0, %1, %2;" : "=r"(m) : "r"(a), "r"(b)); 380 | } 381 | 382 | __device__ __forceinline__ 383 | void 384 | min(u64 &m, u64 a, u64 b) { 385 | asm ("min.u64 %0, %1, %2;" : "=l"(m) : "l"(a), "l"(b)); 386 | } 387 | 388 | __device__ __forceinline__ 389 | void 390 | max(u32 &m, u32 a, u32 b) { 391 | asm ("max.u32 %0, %1, %2;" : "=r"(m) : "r"(a), "r"(b)); 392 | } 393 | 394 | __device__ __forceinline__ 395 | void 396 | max(u64 &m, u64 a, u64 b) { 397 | asm ("max.u64 %0, %1, %2;" : "=l"(m) : "l"(a), "l"(b)); 398 | } 399 | 400 | __device__ __forceinline__ 401 | void 402 | modinv_2exp(u32 &x, u32 b) { 403 | assert(b & 1); 404 | 405 | x = (2U - b * b) * b; 406 | x *= 2U - b * x; 407 | x *= 2U - b * x; 408 | x *= 2U - b * x; 409 | } 410 | 411 | __device__ __forceinline__ 412 | void 413 | modinv_2exp(u64 &x, u64 b) { 414 | assert(b & 1); 415 | 416 | x = (2UL - b * b) * b; 417 | x *= 2UL - b * x; 418 | x *= 2UL - b * x; 419 | x *= 2UL - b * x; 420 | x *= 2UL - b * x; 421 | } 422 | 423 | /* 424 | * For 512 <= d < 1024, 425 | * 426 | * RECIPROCAL_TABLE_32[d - 512] = floor((2^24 - 2^14 + 2^9)/d) 427 | * 428 | * Total space at the moment is 512*2 = 1024 bytes. 429 | * 430 | * TODO: Investigate whether alternative storage layouts are better; examples: 431 | * 432 | * - redundantly store each element in a uint32_t 433 | * - pack two uint16_t values into each uint32_t 434 | * - is __constant__ the right storage specifier? Maybe load into shared memory? 435 | * Shared memory seems like an excellent choice (48k available per SM), though 436 | * I'll need to be mindful of bank conflicts (perhaps circumvent by having 437 | * many copies of the data in SM?). 438 | * - perhaps reading an element from memory is slower than simply calculating 439 | * floor((2^24 - 2^14 + 2^9)/d) in assembly? 
440 | */ 441 | __device__ __constant__ 442 | uint16_t 443 | RECIPROCAL_TABLE_32[0x200] = 444 | { 445 | 0x7fe1, 0x7fa1, 0x7f61, 0x7f22, 0x7ee3, 0x7ea4, 0x7e65, 0x7e27, 446 | 0x7de9, 0x7dab, 0x7d6d, 0x7d30, 0x7cf3, 0x7cb6, 0x7c79, 0x7c3d, 447 | 0x7c00, 0x7bc4, 0x7b89, 0x7b4d, 0x7b12, 0x7ad7, 0x7a9c, 0x7a61, 448 | 0x7a27, 0x79ec, 0x79b2, 0x7979, 0x793f, 0x7906, 0x78cc, 0x7894, 449 | 0x785b, 0x7822, 0x77ea, 0x77b2, 0x777a, 0x7742, 0x770b, 0x76d3, 450 | 0x769c, 0x7665, 0x762f, 0x75f8, 0x75c2, 0x758c, 0x7556, 0x7520, 451 | 0x74ea, 0x74b5, 0x7480, 0x744b, 0x7416, 0x73e2, 0x73ad, 0x7379, 452 | 0x7345, 0x7311, 0x72dd, 0x72aa, 0x7277, 0x7243, 0x7210, 0x71de, 453 | 0x71ab, 0x7179, 0x7146, 0x7114, 0x70e2, 0x70b1, 0x707f, 0x704e, 454 | 0x701c, 0x6feb, 0x6fba, 0x6f8a, 0x6f59, 0x6f29, 0x6ef9, 0x6ec8, 455 | 0x6e99, 0x6e69, 0x6e39, 0x6e0a, 0x6ddb, 0x6dab, 0x6d7d, 0x6d4e, 456 | 0x6d1f, 0x6cf1, 0x6cc2, 0x6c94, 0x6c66, 0x6c38, 0x6c0a, 0x6bdd, 457 | 0x6bb0, 0x6b82, 0x6b55, 0x6b28, 0x6afb, 0x6acf, 0x6aa2, 0x6a76, 458 | 0x6a49, 0x6a1d, 0x69f1, 0x69c6, 0x699a, 0x696e, 0x6943, 0x6918, 459 | 0x68ed, 0x68c2, 0x6897, 0x686c, 0x6842, 0x6817, 0x67ed, 0x67c3, 460 | 0x6799, 0x676f, 0x6745, 0x671b, 0x66f2, 0x66c8, 0x669f, 0x6676, 461 | 0x664d, 0x6624, 0x65fc, 0x65d3, 0x65aa, 0x6582, 0x655a, 0x6532, 462 | 0x650a, 0x64e2, 0x64ba, 0x6493, 0x646b, 0x6444, 0x641c, 0x63f5, 463 | 0x63ce, 0x63a7, 0x6381, 0x635a, 0x6333, 0x630d, 0x62e7, 0x62c1, 464 | 0x629a, 0x6275, 0x624f, 0x6229, 0x6203, 0x61de, 0x61b8, 0x6193, 465 | 0x616e, 0x6149, 0x6124, 0x60ff, 0x60da, 0x60b6, 0x6091, 0x606d, 466 | 0x6049, 0x6024, 0x6000, 0x5fdc, 0x5fb8, 0x5f95, 0x5f71, 0x5f4d, 467 | 0x5f2a, 0x5f07, 0x5ee3, 0x5ec0, 0x5e9d, 0x5e7a, 0x5e57, 0x5e35, 468 | 0x5e12, 0x5def, 0x5dcd, 0x5dab, 0x5d88, 0x5d66, 0x5d44, 0x5d22, 469 | 0x5d00, 0x5cde, 0x5cbd, 0x5c9b, 0x5c7a, 0x5c58, 0x5c37, 0x5c16, 470 | 0x5bf5, 0x5bd4, 0x5bb3, 0x5b92, 0x5b71, 0x5b51, 0x5b30, 0x5b10, 471 | 0x5aef, 0x5acf, 0x5aaf, 0x5a8f, 0x5a6f, 0x5a4f, 0x5a2f, 0x5a0f, 472 | 0x59ef, 0x59d0, 0x59b0, 0x5991, 0x5972, 0x5952, 0x5933, 0x5914, 473 | 0x58f5, 0x58d6, 0x58b7, 0x5899, 0x587a, 0x585b, 0x583d, 0x581f, 474 | 0x5800, 0x57e2, 0x57c4, 0x57a6, 0x5788, 0x576a, 0x574c, 0x572e, 475 | 0x5711, 0x56f3, 0x56d5, 0x56b8, 0x569b, 0x567d, 0x5660, 0x5643, 476 | 0x5626, 0x5609, 0x55ec, 0x55cf, 0x55b2, 0x5596, 0x5579, 0x555d, 477 | 0x5540, 0x5524, 0x5507, 0x54eb, 0x54cf, 0x54b3, 0x5497, 0x547b, 478 | 0x545f, 0x5443, 0x5428, 0x540c, 0x53f0, 0x53d5, 0x53b9, 0x539e, 479 | 0x5383, 0x5368, 0x534c, 0x5331, 0x5316, 0x52fb, 0x52e0, 0x52c6, 480 | 0x52ab, 0x5290, 0x5276, 0x525b, 0x5240, 0x5226, 0x520c, 0x51f1, 481 | 0x51d7, 0x51bd, 0x51a3, 0x5189, 0x516f, 0x5155, 0x513b, 0x5121, 482 | 0x5108, 0x50ee, 0x50d5, 0x50bb, 0x50a2, 0x5088, 0x506f, 0x5056, 483 | 0x503c, 0x5023, 0x500a, 0x4ff1, 0x4fd8, 0x4fbf, 0x4fa6, 0x4f8e, 484 | 0x4f75, 0x4f5c, 0x4f44, 0x4f2b, 0x4f13, 0x4efa, 0x4ee2, 0x4eca, 485 | 0x4eb1, 0x4e99, 0x4e81, 0x4e69, 0x4e51, 0x4e39, 0x4e21, 0x4e09, 486 | 0x4df1, 0x4dda, 0x4dc2, 0x4daa, 0x4d93, 0x4d7b, 0x4d64, 0x4d4d, 487 | 0x4d35, 0x4d1e, 0x4d07, 0x4cf0, 0x4cd8, 0x4cc1, 0x4caa, 0x4c93, 488 | 0x4c7d, 0x4c66, 0x4c4f, 0x4c38, 0x4c21, 0x4c0b, 0x4bf4, 0x4bde, 489 | 0x4bc7, 0x4bb1, 0x4b9a, 0x4b84, 0x4b6e, 0x4b58, 0x4b41, 0x4b2b, 490 | 0x4b15, 0x4aff, 0x4ae9, 0x4ad3, 0x4abd, 0x4aa8, 0x4a92, 0x4a7c, 491 | 0x4a66, 0x4a51, 0x4a3b, 0x4a26, 0x4a10, 0x49fb, 0x49e5, 0x49d0, 492 | 0x49bb, 0x49a6, 0x4990, 0x497b, 0x4966, 0x4951, 0x493c, 0x4927, 493 | 0x4912, 0x48fe, 0x48e9, 0x48d4, 0x48bf, 0x48ab, 0x4896, 0x4881, 494 | 0x486d, 0x4858, 
0x4844, 0x482f, 0x481b, 0x4807, 0x47f3, 0x47de, 495 | 0x47ca, 0x47b6, 0x47a2, 0x478e, 0x477a, 0x4766, 0x4752, 0x473e, 496 | 0x472a, 0x4717, 0x4703, 0x46ef, 0x46db, 0x46c8, 0x46b4, 0x46a1, 497 | 0x468d, 0x467a, 0x4666, 0x4653, 0x4640, 0x462c, 0x4619, 0x4606, 498 | 0x45f3, 0x45e0, 0x45cd, 0x45ba, 0x45a7, 0x4594, 0x4581, 0x456e, 499 | 0x455b, 0x4548, 0x4536, 0x4523, 0x4510, 0x44fe, 0x44eb, 0x44d8, 500 | 0x44c6, 0x44b3, 0x44a1, 0x448f, 0x447c, 0x446a, 0x4458, 0x4445, 501 | 0x4433, 0x4421, 0x440f, 0x43fd, 0x43eb, 0x43d9, 0x43c7, 0x43b5, 502 | 0x43a3, 0x4391, 0x437f, 0x436d, 0x435c, 0x434a, 0x4338, 0x4327, 503 | 0x4315, 0x4303, 0x42f2, 0x42e0, 0x42cf, 0x42bd, 0x42ac, 0x429b, 504 | 0x4289, 0x4278, 0x4267, 0x4256, 0x4244, 0x4233, 0x4222, 0x4211, 505 | 0x4200, 0x41ef, 0x41de, 0x41cd, 0x41bc, 0x41ab, 0x419a, 0x418a, 506 | 0x4179, 0x4168, 0x4157, 0x4147, 0x4136, 0x4125, 0x4115, 0x4104, 507 | 0x40f4, 0x40e3, 0x40d3, 0x40c2, 0x40b2, 0x40a2, 0x4091, 0x4081, 508 | 0x4071, 0x4061, 0x4050, 0x4040, 0x4030, 0x4020, 0x4010, 0x4000 509 | }; 510 | 511 | __device__ __forceinline__ 512 | uint32_t 513 | lookup_reciprocal(uint32_t d10) { 514 | assert((d10 >> 9) == 1); 515 | return RECIPROCAL_TABLE_32[d10 - 0x200]; 516 | } 517 | 518 | 519 | /* 520 | * Source: Niels Möller and Torbjörn Granlund, “Improved division by 521 | * invariant integers”, IEEE Transactions on Computers, 11 June 522 | * 2010. https://gmplib.org/~tege/division-paper.pdf 523 | */ 524 | __device__ __forceinline__ 525 | uint32_t 526 | quorem_reciprocal(uint32_t d) 527 | { 528 | // Top bit must be set, i.e. d must be already normalised. 529 | assert((d >> 31) == 1); 530 | 531 | uint32_t d0_mask, d10, d21, d31, v0, v1, v2, v3, e, t0, t1; 532 | 533 | d0_mask = -(uint32_t)(d & 1); // 0 if d&1=0, 0xFF..FF if d&1=1. 534 | d10 = d >> 22; 535 | d21 = (d >> 11) + 1; 536 | d31 = d - (d >> 1); // ceil(d/2) = d - floor(d/2) 537 | 538 | v0 = lookup_reciprocal(d10); // 15 bits 539 | mul_hi(t0, v0 * v0, d21); 540 | v1 = (v0 << 4) - t0 - 1; // 18 bits 541 | e = -(v1 * d31) + ((v1 >> 1) & d0_mask); 542 | mul_hi(t0, v1, e); 543 | v2 = (v1 << 15) + (t0 >> 1); // 33 bits (hi bit is implicit) 544 | mul_wide(t1, t0, v2, d); 545 | t1 += d + ((t0 + d) < d); 546 | v3 = v2 - t1; // 33 bits (hi bit is implicit) 547 | return v3; 548 | } 549 | 550 | /* 551 | * For 256 <= d < 512, 552 | * 553 | * RECIPROCAL_TABLE_64[d - 256] = floor((2^19 - 3*2^9)/d) 554 | * 555 | * Total space ATM is 256*2 = 512 bytes. Entries range from 10 to 11 556 | * bits, so with some clever handling of hi bits, we could get three 557 | * entries per 32 bit word, reducing the size to about 256*11/8 = 352 558 | * bytes. 559 | * 560 | * TODO: Investigate whether alternative storage layouts are better; 561 | * see RECIPROCAL_TABLE_32 above for ideas. 
562 | */ 563 | __device__ __constant__ 564 | uint16_t 565 | RECIPROCAL_TABLE_64[0x100] = 566 | { 567 | 0x7fd, 0x7f5, 0x7ed, 0x7e5, 0x7dd, 0x7d5, 0x7ce, 0x7c6, 568 | 0x7bf, 0x7b7, 0x7b0, 0x7a8, 0x7a1, 0x79a, 0x792, 0x78b, 569 | 0x784, 0x77d, 0x776, 0x76f, 0x768, 0x761, 0x75b, 0x754, 570 | 0x74d, 0x747, 0x740, 0x739, 0x733, 0x72c, 0x726, 0x720, 571 | 0x719, 0x713, 0x70d, 0x707, 0x700, 0x6fa, 0x6f4, 0x6ee, 572 | 0x6e8, 0x6e2, 0x6dc, 0x6d6, 0x6d1, 0x6cb, 0x6c5, 0x6bf, 573 | 0x6ba, 0x6b4, 0x6ae, 0x6a9, 0x6a3, 0x69e, 0x698, 0x693, 574 | 0x68d, 0x688, 0x683, 0x67d, 0x678, 0x673, 0x66e, 0x669, 575 | 0x664, 0x65e, 0x659, 0x654, 0x64f, 0x64a, 0x645, 0x640, 576 | 0x63c, 0x637, 0x632, 0x62d, 0x628, 0x624, 0x61f, 0x61a, 577 | 0x616, 0x611, 0x60c, 0x608, 0x603, 0x5ff, 0x5fa, 0x5f6, 578 | 0x5f1, 0x5ed, 0x5e9, 0x5e4, 0x5e0, 0x5dc, 0x5d7, 0x5d3, 579 | 0x5cf, 0x5cb, 0x5c6, 0x5c2, 0x5be, 0x5ba, 0x5b6, 0x5b2, 580 | 0x5ae, 0x5aa, 0x5a6, 0x5a2, 0x59e, 0x59a, 0x596, 0x592, 581 | 0x58e, 0x58a, 0x586, 0x583, 0x57f, 0x57b, 0x577, 0x574, 582 | 0x570, 0x56c, 0x568, 0x565, 0x561, 0x55e, 0x55a, 0x556, 583 | 0x553, 0x54f, 0x54c, 0x548, 0x545, 0x541, 0x53e, 0x53a, 584 | 0x537, 0x534, 0x530, 0x52d, 0x52a, 0x526, 0x523, 0x520, 585 | 0x51c, 0x519, 0x516, 0x513, 0x50f, 0x50c, 0x509, 0x506, 586 | 0x503, 0x500, 0x4fc, 0x4f9, 0x4f6, 0x4f3, 0x4f0, 0x4ed, 587 | 0x4ea, 0x4e7, 0x4e4, 0x4e1, 0x4de, 0x4db, 0x4d8, 0x4d5, 588 | 0x4d2, 0x4cf, 0x4cc, 0x4ca, 0x4c7, 0x4c4, 0x4c1, 0x4be, 589 | 0x4bb, 0x4b9, 0x4b6, 0x4b3, 0x4b0, 0x4ad, 0x4ab, 0x4a8, 590 | 0x4a5, 0x4a3, 0x4a0, 0x49d, 0x49b, 0x498, 0x495, 0x493, 591 | 0x490, 0x48d, 0x48b, 0x488, 0x486, 0x483, 0x481, 0x47e, 592 | 0x47c, 0x479, 0x477, 0x474, 0x472, 0x46f, 0x46d, 0x46a, 593 | 0x468, 0x465, 0x463, 0x461, 0x45e, 0x45c, 0x459, 0x457, 594 | 0x455, 0x452, 0x450, 0x44e, 0x44b, 0x449, 0x447, 0x444, 595 | 0x442, 0x440, 0x43e, 0x43b, 0x439, 0x437, 0x435, 0x432, 596 | 0x430, 0x42e, 0x42c, 0x42a, 0x428, 0x425, 0x423, 0x421, 597 | 0x41f, 0x41d, 0x41b, 0x419, 0x417, 0x414, 0x412, 0x410, 598 | 0x40e, 0x40c, 0x40a, 0x408, 0x406, 0x404, 0x402, 0x400 599 | }; 600 | 601 | __device__ __forceinline__ 602 | uint64_t 603 | lookup_reciprocal(uint64_t d9) { 604 | assert((d9 >> 8) == 1); 605 | return RECIPROCAL_TABLE_64[d9 - 0x100]; 606 | } 607 | 608 | /* 609 | * Source: Niels Möller and Torbjörn Granlund, “Improved division by 610 | * invariant integers”, IEEE Transactions on Computers, 11 June 611 | * 2010. https://gmplib.org/~tege/division-paper.pdf 612 | */ 613 | __device__ __forceinline__ 614 | uint64_t 615 | quorem_reciprocal(uint64_t d) 616 | { 617 | // Top bit must be set, i.e. d must be already normalised. 618 | assert((d >> 63) == 1); 619 | 620 | uint64_t d0_mask, d9, d40, d63, v0, v1, v2, v3, v4, e, t0, t1; 621 | 622 | d0_mask = -(uint64_t)(d & 1); // 0 if d&1=0, 0xFF..FF if d&1=1. 
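    // The steps below follow the division-by-invariant-integers paper cited
    // above: the table lookup gives an initial approximation of the
    // reciprocal, and each subsequent step is a Newton-style refinement that
    // roughly doubles the number of correct bits (v0: 11 bits, v1: 21,
    // v2: 34, v3 and v4: 65, with the high bit implicit).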
623 | d9 = d >> 55; 624 | d40 = (d >> 24) + 1; 625 | d63 = d - (d >> 1); // ceil(d/2) = d - floor(d/2) 626 | 627 | v0 = lookup_reciprocal(d9); // 11 bits 628 | t0 = v0 * v0 * d40; 629 | v1 = (v0 << 11) - (t0 >> 40) - 1; // 21 bits 630 | t0 = v1 * ((1UL << 60) - (v1 * d40)); 631 | v2 = (v1 << 13) + (t0 >> 47); // 34 bits 632 | 633 | e = -(v2 * d63) + ((v1 >> 1) & d0_mask); 634 | mul_hi(t0, v2, e); 635 | v3 = (v2 << 31) + (t0 >> 1); // 65 bits (hi bit is implicit) 636 | mul_wide(t1, t0, v3, d); 637 | t1 += d + ((t0 + d) < d); 638 | v4 = v3 - t1; // 65 bits (hi bit is implicit) 639 | return v4; 640 | } 641 | 642 | 643 | template< typename uint_tp > 644 | __device__ __forceinline__ 645 | int 646 | quorem_normalise_divisor(uint_tp &d) { 647 | int cnt = clz(d); 648 | d <<= cnt; 649 | return cnt; 650 | } 651 | 652 | template< typename uint_tp > 653 | __device__ __forceinline__ 654 | uint_tp 655 | quorem_normalise_dividend(uint_tp &u_hi, uint_tp &u_lo, int cnt) { 656 | // TODO: For 32 bit operands we can just do the following 657 | // asm ("shf.l.clamp.b32 %0, %1, %0, %2;\n\t" 658 | // "shl.b32 %1, %1, %2;" 659 | // : "+r"(u_hi), "+r"(u_lo) : "r"(cnt)); 660 | // 661 | // For 64 bits it's a bit more long-winded 662 | // Inspired by https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shf 663 | // asm ("{\n\t" 664 | // " .reg .u32 t1;\n\t" 665 | // " .reg .u32 t2;\n\t" 666 | // " .reg .u32 t3;\n\t" 667 | // " .reg .u32 t4;\n\t" 668 | // " mov.b64 { t1, t2 }, %0;\n\t" 669 | // " mov.b64 { t3, t4 }, %1;\n\t" 670 | // " shf.l.clamp.b32 t4, t3, t4, %2;\n\t" 671 | // " shf.l.clamp.b32 t3, t2, t3, %2;\n\t" 672 | // " shf.l.clamp.b32 t2, t1, t2, %2;\n\t" 673 | // " shl.b32 t1, t1, %2;\n\t" 674 | // " mov.b64 %0, { t1, t2 };\n\t" 675 | // " mov.b64 %1, { t3, t4 };\n\t" 676 | // "}" 677 | // : "+l"(u_lo), "+l"(u_hi) : "r"(cnt)); 678 | 679 | static constexpr int WORD_BITS = sizeof(uint_tp) * 8; 680 | uint_tp overflow = u_hi >> (WORD_BITS - cnt); 681 | uint_tp u_hi_lsb = u_lo >> (WORD_BITS - cnt); 682 | #ifndef __CUDA_ARCH__ 683 | // Compensate for the fact that, unlike CUDA, shifts by WORD_BITS 684 | // are undefined in C. 685 | // u_hi_lsb = 0 if cnt=0 or u_hi_lsb if cnt!=0. 686 | u_hi_lsb &= -(uint_tp)!!cnt; 687 | overflow &= -(uint_tp)!!cnt; 688 | #endif 689 | u_hi = (u_hi << cnt) | u_hi_lsb; 690 | u_lo <<= cnt; 691 | return overflow; 692 | } 693 | 694 | /* 695 | * Suppose Q and r satisfy U = Qd + r, where Q = (q_hi, q_lo) and U = 696 | * (u_hi, u_lo) are two-word numbers. This function returns q = min(Q, 697 | * 2^WORD_BITS - 1) and r = U - Qd if q = Q or r = q in the latter 698 | * case. v should be set to quorem_reciprocal(d). 699 | * 700 | * CAVEAT EMPTOR: d and {u_hi, u_lo} need to be normalised (using the 701 | * functions provided) PRIOR to being passed to this 702 | * function. Similarly, the resulting remainder r (but NOT the 703 | * quotient q) needs to be denormalised (i.e. right shift by the 704 | * normalisation factor) after receipt. 705 | * 706 | * Source: Niels Möller and Torbjörn Granlund, “Improved division by 707 | * invariant integers”, IEEE Transactions on Computers, 11 June 708 | * 2010. 
https://gmplib.org/~tege/division-paper.pdf 709 | */ 710 | template< typename uint_tp > 711 | __device__ 712 | void 713 | quorem_wide_normalised( 714 | uint_tp &q, uint_tp &r, 715 | uint_tp u_hi, uint_tp u_lo, uint_tp d, uint_tp v) 716 | { 717 | static_assert(std::is_unsigned::value == true, 718 | "template type must be unsigned"); 719 | if (u_hi > d) { 720 | q = r = (uint_tp)-1; 721 | return; 722 | } 723 | 724 | uint_tp q_hi, q_lo, mask; 725 | 726 | mul_wide(q_hi, q_lo, u_hi, v); 727 | q_lo += u_lo; 728 | q_hi += u_hi + (q_lo < u_lo) + 1; 729 | r = u_lo - q_hi * d; 730 | 731 | // Branch is unpredicable 732 | //if (r > q_lo) { --q_hi; r += d; } 733 | mask = -(uint_tp)(r > q_lo); 734 | q_hi += mask; 735 | r += mask & d; 736 | 737 | // Branch is very unlikely to be taken 738 | if (r >= d) { r -= d; ++q_hi; } 739 | //mask = -(uint_tp)(r >= d); 740 | //q_hi -= mask; 741 | //r -= mask & d; 742 | 743 | q = q_hi; 744 | } 745 | 746 | /* 747 | * As above, but calculate, then return, the precomputed inverse for d. 748 | * Normalisation of the divisor and dividend is performed then thrown away. 749 | */ 750 | template< typename uint_tp > 751 | __device__ __forceinline__ 752 | uint_tp 753 | quorem_wide( 754 | uint_tp &q, uint_tp &r, 755 | uint_tp u_hi, uint_tp u_lo, uint_tp d) 756 | { 757 | static_assert(std::is_unsigned::value == true, 758 | "template type must be unsigned"); 759 | int lz = quorem_normalise_divisor(d); 760 | uint_tp overflow = quorem_normalise_dividend(u_hi, u_lo, lz); 761 | uint_tp v = quorem_reciprocal(d); 762 | if (overflow) { q = r = (uint_tp)-1; return v; } 763 | quorem_wide_normalised(q, r, u_hi, u_lo, d, v); 764 | assert((r & (((uint_tp)1 << lz) - 1U)) == 0); 765 | r >>= lz; 766 | return v; 767 | } 768 | 769 | /* 770 | * As above, but uses a given precomputed inverse. If the precomputed 771 | * inverse comes from quorem_reciprocal() rather than from quorem_wide() 772 | * above, then make sure the divisor given to quorem_reciprocal() was 773 | * normalised with quorem_normalise_divisor() first. 
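 *
 * A possible usage sketch (illustrative only; u_hi, u_lo and d stand for the
 * caller's 64-bit values):
 *
 *     uint64_t dn = d;
 *     quorem_normalise_divisor(dn);           // normalise a copy of d
 *     uint64_t v = quorem_reciprocal(dn);     // hoist out of any loop
 *     uint64_t q, r;
 *     quorem_wide(q, r, u_hi, u_lo, d, v);    // pass the original d here
 *
 * i.e. the precomputed inverse corresponds to the normalised divisor, while
 * this function takes the original divisor and renormalises it internally.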
774 | */ 775 | template< typename uint_tp > 776 | __device__ __forceinline__ 777 | void 778 | quorem_wide( 779 | uint_tp &q, uint_tp &r, 780 | uint_tp u_hi, uint_tp u_lo, uint_tp d, uint_tp v) 781 | { 782 | static_assert(std::is_unsigned::value == true, 783 | "template type must be unsigned"); 784 | int lz = quorem_normalise_divisor(d); 785 | uint_tp overflow = quorem_normalise_dividend(u_hi, u_lo, lz); 786 | if (overflow) { q = r = -(uint_tp)1; } 787 | quorem_wide_normalised(q, r, u_hi, u_lo, d, v); 788 | assert((r & (((uint_tp)1 << lz) - 1U)) == 0); 789 | r >>= lz; 790 | } 791 | 792 | /* 793 | * ceiling(n / d) 794 | */ 795 | template< typename T > 796 | __device__ __forceinline__ 797 | void 798 | ceilquo(T &q, T n, T d) { 799 | q = (n + d - 1) / d; 800 | } 801 | 802 | } // End namespace internal 803 | 804 | } // End namespace cuFIXNUM 805 | -------------------------------------------------------------------------------- /src/fixnum/nail.cu: -------------------------------------------------------------------------------- 1 | 2 | 3 | template< typename digit > 4 | __device__ __forceinline__ void 5 | hand_add(digit &r, digit a, digit b) 6 | { 7 | r = a + b; 8 | } 9 | 10 | 11 | template< typename digit, int NAIL_BITS > 12 | struct nail_data 13 | { 14 | typedef typename digit digit; 15 | // FIXME: This doesn't work if digit is signed 16 | constexpr digit DIGIT_MAX = ~(digit)0; 17 | constexpr int DIGIT_BITS = sizeof(digit) * 8; 18 | constexpr int NON_NAIL_BITS = DIGIT_BITS - NAIL_BITS; 19 | constexpr digit NAIL_MASK = DIGIT_MAX << NON_NAIL_BITS; 20 | constexpr digit NON_NAIL_MASK = ~NAIL_MASK; 21 | constexpr digit NON_NAIL_MAX = NON_NAIL_MASK; // alias 22 | 23 | // A nail must fit in an int. 24 | static_assert(NAIL_BITS > 0 && NAIL_BITS < sizeof(int) * 8, 25 | "invalid number of nail bits"); 26 | }; 27 | 28 | 29 | // TODO: This is ugly 30 | template< typename digit, int NAIL_BITS > 31 | __device__ __forceinline__ int 32 | hand_extract_nail(digit &r) 33 | { 34 | typedef nail_data nd; 35 | 36 | // split r into nail and non-nail parts 37 | nail = r >> nd::NON_NAIL_BITS; 38 | r &= nd::NON_NAIL_MASK; 39 | return nail; 40 | } 41 | 42 | 43 | /* 44 | * Current cost of nail resolution is 4 vote functions. 45 | */ 46 | template< typename digit, int NAIL_BITS > 47 | __device__ int 48 | hand_resolve_nails(digit &r) 49 | { 50 | // TODO: Make this work with a general width 51 | constexpr int WIDTH = warpSize; 52 | // TODO: This is ugly 53 | typedef nail_data nd; 54 | typedef subwarp_data subwarp; 55 | 56 | int nail, nail_hi; 57 | nail = hand_extract_nail(r); 58 | nail_hi = subwarp::shfl(nail, subwarp::toplaneIdx); 59 | 60 | nail = subwarp::shfl_up0(nail, 1); 61 | r += nail; 62 | 63 | // nail is 0 or 1 this time 64 | nail = hand_extract_nail(r); 65 | 66 | return nail_hi + hand_resolve_cy(r, nail, nd::NON_NAIL_MAX); 67 | } 68 | 69 | 70 | template< typename digit, int NAIL_BITS, int WIDTH = warpSize > 71 | __device__ void 72 | hand_mullo_nail(digit &r, digit a, digit b) 73 | { 74 | // FIXME: We shouldn't need nail bits to divide the width 75 | static_assert(!(WIDTH % NAIL_BITS), "nail bits does not divide width"); 76 | // FIXME: also need to check that digit has enough space for the 77 | // accumulated nails. 78 | 79 | typedef subwarp_data subwarp; 80 | 81 | digit n = 0; // nails 82 | 83 | r = 0; 84 | for (int i = WIDTH - 1; i >= 0; --i) { 85 | // FIXME: Should this be NAIL_BITS/2? Because there are two 86 | // additions (hi & lo)? Maybe at most one of the two additions 87 | // will cause an overflow? 
For example, 0xff * 0xff = 0xfe01 88 | // so overflow is likely in the first case and unlikely in the 89 | // second... 90 | for (int j = 0; j < NAIL_BITS; ++j, --i) { 91 | digit aa = subwarp::shfl(a, i); 92 | 93 | // TODO: See if using umad.wide improves this. 94 | umad_hi(r, aa, b, r); 95 | r = subwarp::shfl_up0(r, 1); 96 | umad_lo(r, aa, b, r); 97 | } 98 | // FIXME: Supposed to shuffle up n by NAIL_BITS digits 99 | // too. Can this be avoided? 100 | n += hand_extract_nails(r); 101 | } 102 | n = subwarp::shfl_up0(n, 1); 103 | hand_add(r, r, n); 104 | hand_resolve_nails(r); 105 | } 106 | 107 | -------------------------------------------------------------------------------- /src/fixnum/slot_layout.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace cuFIXNUM { 6 | 7 | // For some reason the warpSize value provided by CUDA is not 8 | // considered a constant value, so cannot be used in constexprs or 9 | // template parameters or static_asserts. Hence we must use WARPSIZE 10 | // instead. 11 | static constexpr int WARPSIZE = 32; 12 | 13 | // TODO: Tidy up nomenclature: SUBWARP -> Slot 14 | /* 15 | * SUBWARPS: 16 | * 17 | * Most of these functions operate on the level of a "subwarp" (NB: 18 | * this is not standard terminology). A *warp* is a logical block of 19 | * 32 threads executed in lock-step by the GPU (thus obviating the 20 | * need for explicit synchronisation). For any w > 1 that divides 32, 21 | * a warp can be partitioned into 32/w subwarps of w threads. The 22 | * struct below takes a parameter "width" which specifies the subwarp 23 | * size, and which thereby specifies the size of the numbers on which 24 | * its functions operate. 25 | * 26 | * The term "warp" should be reserved for subwarps of width 32 27 | * (=warpSize). 28 | * 29 | * TODO: Work out if using __forceinline__ in these definitions 30 | * actually achieves anything. 31 | */ 32 | 33 | template 34 | struct slot_layout 35 | { 36 | static_assert(width > 0 && !(WARPSIZE & (width - 1)), 37 | "slot width must be a positive divisor of warpSize (=32)"); 38 | 39 | static constexpr int WIDTH = width; 40 | 41 | /* 42 | * Return the lane index within the slot. 43 | * 44 | * The lane index is the thread index modulo the width of the slot. 45 | */ 46 | static __device__ __forceinline__ 47 | int 48 | laneIdx() { 49 | // threadIdx.x % width = threadIdx.x & (width - 1) since width = 2^n 50 | return threadIdx.x & (width - 1); 51 | 52 | // TODO: Replace above with? 53 | // int L; 54 | // asm ("mov.b32 %0, %laneid;" : "=r"(L)); 55 | // return L; 56 | } 57 | 58 | /* 59 | * Index of the top lane of the current slot. 60 | * 61 | * The top lane of a slot is the one with index width - 1. 62 | */ 63 | static constexpr int toplaneIdx = width - 1; 64 | 65 | /* 66 | * Mask which selects the first width bits of a number. 67 | * 68 | * Useful in conjunction with offset() and __ballot(). 69 | */ 70 | static __device__ __forceinline__ 71 | std::uint32_t 72 | mask() { 73 | return ((1UL << width) - 1UL) << offset(); 74 | } 75 | 76 | /* 77 | * Return the thread index within the warp where the slot 78 | * containing this lane begins. Examples: 79 | * 80 | * - width 16: slot offset is 0 for threads 0-15, and 16 for 81 | * threads 16-31 82 | * 83 | * - width 8: slot offset is 0 for threads 0-7, 8 for threads 8-15, 84 | * 16 for threads 16-23, and 24 for threads 24-31. 85 | * 86 | * The slot offset at thread T in a slot of width w is given by 87 | * floor(T/w)*w. 
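     * For example, with width 8, thread 13 lies in the slot covering
     * threads 8-15, so its slot offset is floor(13/8)*8 = 8.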
88 | * 89 | * Useful in conjunction with mask() and __ballot(). 90 | */ 91 | static __device__ __forceinline__ 92 | int 93 | offset() { 94 | // Thread index within the (full) warp. 95 | int tid = threadIdx.x & (WARPSIZE - 1); 96 | 97 | // Recall: x mod y = x - y*floor(x/y), so 98 | // 99 | // slotOffset = width * floor(threadIdx/width) 100 | // = threadIdx - (threadIdx % width) 101 | // = threadIdx - (threadIdx & (width - 1)) 102 | // // TODO: Do use this last formulation! 103 | // = set bottom log2(width) bits of threadIdx to zero 104 | // = T & ~mask ?? or "(T >> width) << width" 105 | // 106 | // since width = 2^n. 107 | return tid - (tid & (width - 1)); 108 | } 109 | 110 | /* 111 | * Like ballot(tst) but restrict the result to the containing slot 112 | * of size width. 113 | */ 114 | __device__ __forceinline__ 115 | static uint32_t 116 | ballot(int tst) { 117 | uint32_t b = __ballot_sync(mask(), tst); 118 | return b >> offset(); 119 | } 120 | 121 | /* 122 | * Wrappers for notation consistency. 123 | */ 124 | __device__ __forceinline__ 125 | static T 126 | shfl(T var, int srcLane) { 127 | return __shfl_sync(mask(), var, srcLane, width); 128 | } 129 | 130 | __device__ __forceinline__ 131 | static T 132 | shfl_up(T var, unsigned int delta) { 133 | return __shfl_up_sync(mask(), var, delta, width); 134 | } 135 | 136 | __device__ __forceinline__ 137 | static T 138 | shfl_down(T var, unsigned int delta) { 139 | return __shfl_down_sync(mask(), var, delta, width); 140 | } 141 | 142 | // NB: Assumes delta <= width + L. (There should be no reason for 143 | // it ever to be more than width.) 144 | __device__ __forceinline__ 145 | static T 146 | rotate_up(T var, unsigned int delta) { 147 | int L = laneIdx(); 148 | // Don't need to reduce srcLane modulo width; that is done by __shfl. 149 | int srcLane = L + width - delta; // The +width is to ensure srcLane > 0 150 | return shfl(var, srcLane); 151 | } 152 | 153 | __device__ __forceinline__ 154 | static T 155 | rotate_down(T var, unsigned int delta) { 156 | int L = laneIdx(); 157 | // Don't need to reduce srcLane modulo width; that is done by __shfl. 158 | int srcLane = L + delta; 159 | return shfl(var, srcLane); 160 | } 161 | 162 | /* 163 | * Like shfl_up but set bottom delta variables to zero. 164 | */ 165 | __device__ __forceinline__ 166 | static T 167 | shfl_up0(T var, unsigned int delta) { 168 | T res = shfl_up(var, delta); 169 | //return res & -(T)(laneIdx() > 0); 170 | return laneIdx() < delta ? T(0) : res; 171 | } 172 | 173 | /* 174 | * Like shfl_down but set top delta variables to zero. 175 | */ 176 | __device__ __forceinline__ 177 | static T 178 | shfl_down0(T var, unsigned int delta) { 179 | T res = shfl_down(var, delta); 180 | //return res & -(T)(laneIdx() < toplaneIdx()); 181 | return laneIdx() >= (width - delta) ? T(0) : res; 182 | } 183 | 184 | private: 185 | slot_layout(); 186 | }; 187 | 188 | } // End namespace cuFIXNUM 189 | -------------------------------------------------------------------------------- /src/fixnum/warp_fixnum.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "slot_layout.cu" 4 | #include "word_fixnum.cu" 5 | 6 | namespace cuFIXNUM { 7 | 8 | /* 9 | * This is an archetypal implementation of a fixnum instruction 10 | * set. It defines the de facto interface for such implementations. 11 | * 12 | * All methods are defined for the device. It is someone else's 13 | * problem to get the data onto the device. 
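 *
 * For example, warp_fixnum<32, u32_fixnum> is a 256-bit number held as
 * eight 32-bit digits, one digit per lane of an 8-lane slot; its operations
 * are performed cooperatively by those eight threads.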
14 | */ 15 | template< int BYTES_, typename digit_ = u32_fixnum > 16 | class warp_fixnum { 17 | public: 18 | // NB: Language convention: Call something a 'digit' when it is constant 19 | // across the slot, and call it a 'fixnum' when it can vary between lanes in 20 | // the slot. Similarly, prefix a function call with 'digit::' when the 21 | // arguments are interpreted component-wise, and with 'fixnum::' when 22 | // they're interpreted "across the slot". 23 | typedef digit_ digit; 24 | typedef warp_fixnum fixnum; 25 | 26 | static constexpr int BYTES = BYTES_; 27 | static constexpr int BITS = 8 * BYTES; 28 | static constexpr int SLOT_WIDTH = BYTES / digit::BYTES; 29 | typedef slot_layout layout; 30 | 31 | static_assert(BYTES > 0, 32 | "Fixnum bytes must be positive."); 33 | static_assert(BYTES % digit::BYTES == 0, 34 | "Fixnum digit size must divide fixnum bytes."); 35 | // TODO: Specialise std::is_integral for fixnum_u32? 36 | //static_assert(std::is_integral< digit >::value, 37 | // "digit must be integral."); 38 | 39 | private: 40 | digit x; 41 | 42 | // TODO: These should be private 43 | public: 44 | __device__ __forceinline__ 45 | operator digit () const { return x; } 46 | 47 | __device__ __forceinline__ 48 | operator digit &() { return x; } 49 | 50 | public: 51 | __device__ __forceinline__ 52 | warp_fixnum() { } 53 | 54 | // TODO: Shouldn't this be equivalent to the digit_to_fixnum() function 55 | // below? 56 | __device__ __forceinline__ 57 | warp_fixnum(digit z) : x(z) { } 58 | 59 | /*************************** 60 | * Representation functions. 61 | */ 62 | 63 | /* 64 | * Set r using bytes, interpreting bytes as a base-256 unsigned 65 | * integer. Return the number of bytes used. If nbytes > 66 | * BYTES, then the last nbytes - BYTES are ignored. 67 | * 68 | * NB: Normally we would expect from_bytes to be exclusively a 69 | * device function, but it's the same for the host, so we leave it 70 | * in. 71 | */ 72 | __host__ __device__ static int from_bytes(uint8_t *r, const uint8_t *bytes, int nbytes) { 73 | int n = min(nbytes, BYTES); 74 | memcpy(r, bytes, n); 75 | memset(r + n, 0, BYTES - n); 76 | return n; 77 | } 78 | 79 | /* 80 | * Set bytes using r, converting r to a base-256 unsigned 81 | * integer. Return the number of bytes written. If nbytes < 82 | * BYTES, then the last BYTES - nbytes are ignored. 83 | * 84 | * NB: Normally we would expect from_bytes to be exclusively a 85 | * device function, but it's the same for the host, so we leave it 86 | * in. 87 | */ 88 | __host__ __device__ static int to_bytes(uint8_t *bytes, int nbytes, const uint8_t *r) { 89 | int n = min(nbytes, BYTES); 90 | memcpy(bytes, r, n); 91 | return n; 92 | } 93 | 94 | /* 95 | * Return digit at index idx. 96 | */ 97 | __device__ static digit get(fixnum var, int idx) { 98 | return layout::shfl(var, idx); 99 | } 100 | 101 | /* 102 | * Set var digit at index idx to be x. 103 | */ 104 | __device__ static void set(fixnum &var, digit x, int idx) { 105 | var = (layout::laneIdx() == idx) ? (fixnum)x : var; 106 | } 107 | 108 | /* 109 | * Return digit in most significant place. Might be zero. 110 | */ 111 | __device__ static digit top_digit(fixnum var) { 112 | return layout::shfl(var, layout::toplaneIdx); 113 | } 114 | 115 | /* 116 | * Return digit in the least significant place. Might be zero. 117 | * 118 | * TODO: Not clear how to interpret this function with more exotic fixnum 119 | * implementations such as RNS. 
120 | */ 121 | __device__ static digit bottom_digit(fixnum var) { 122 | return layout::shfl(var, 0); 123 | } 124 | 125 | /*********************** 126 | * Arithmetic functions. 127 | */ 128 | 129 | // TODO: Handle carry in 130 | // TODO: A more consistent syntax might be 131 | // fixnum add(fixnum a, fixnum b) 132 | // fixnum add_cc(fixnum a, fixnum b, int &cy_out) 133 | // fixnum addc(fixnum a, fixnum b, int cy_in) 134 | // fixnum addc_cc(fixnum a, fixnum b, int cy_in, int &cy_out) 135 | __device__ static void add_cy(fixnum &r, digit &cy_hi, fixnum a, fixnum b) { 136 | digit cy; 137 | digit::add_cy(r, cy, a, b); 138 | // r propagates carries iff r = FIXNUM_MAX 139 | digit r_cy = effective_carries(cy_hi, digit::is_max(r), cy); 140 | digit::add(r, r, r_cy); 141 | } 142 | 143 | __device__ static void add(fixnum &r, fixnum a, fixnum b) { 144 | digit cy; 145 | add_cy(r, cy, a, b); 146 | } 147 | 148 | // TODO: Handle borrow in 149 | __device__ static void sub_br(fixnum &r, digit &br_hi, fixnum a, fixnum b) { 150 | digit br; 151 | digit::sub_br(r, br, a, b); 152 | // r propagates borrows iff r = FIXNUM_MIN 153 | digit r_br = effective_carries(br_hi, digit::is_min(r), br); 154 | digit::sub(r, r, r_br); 155 | } 156 | 157 | __device__ static void sub(fixnum &r, fixnum a, fixnum b) { 158 | digit br; 159 | sub_br(r, br, a, b); 160 | } 161 | 162 | __device__ static fixnum zero() { 163 | return digit::zero(); 164 | } 165 | 166 | __device__ static fixnum one() { 167 | return digit(layout::laneIdx() == 0); 168 | } 169 | 170 | __device__ static fixnum two() { 171 | return digit(layout::laneIdx() == 0 ? 2 : 0); 172 | } 173 | 174 | __device__ static int is_zero(fixnum a) { 175 | return nonzero_mask(a) == 0; 176 | } 177 | 178 | __device__ static digit incr_cy(fixnum &r) { 179 | digit cy; 180 | add_cy(r, cy, r, one()); 181 | return cy; 182 | } 183 | 184 | __device__ static digit decr_br(fixnum &r) { 185 | digit br; 186 | sub_br(r, br, r, one()); 187 | return br; 188 | } 189 | 190 | __device__ static void neg(fixnum &r, fixnum a) { 191 | sub(r, zero(), a); 192 | } 193 | 194 | /* 195 | * r = a * u, where a is interpreted as a single word, and u a 196 | * full fixnum. a should be constant across the slot for the 197 | * result to make sense. 198 | * 199 | * TODO: Can this be refactored with mad_cy? 200 | * TODO: Come up with a better name for this function. It's 201 | * scalar multiplication in the vspace of polynomials... 202 | */ 203 | __device__ static digit mul_digit(fixnum &r, digit a, fixnum u) { 204 | fixnum hi, lo; 205 | digit cy, cy_hi; 206 | 207 | digit::mul_wide(hi, lo, a, u); 208 | cy_hi = top_digit(hi); 209 | hi = layout::shfl_up0(hi, 1); 210 | add_cy(lo, cy, lo, hi); 211 | 212 | return cy_hi + cy; 213 | } 214 | 215 | /* 216 | * r = lo_half(a * b) 217 | * 218 | * The "lo_half" is the product modulo 2^(8*BYTES), 219 | * i.e. the same size as the inputs. 220 | */ 221 | __device__ static void mul_lo(fixnum &r, fixnum a, fixnum b) { 222 | // TODO: Implement specific mul_lo function. 223 | digit cy = digit::zero(); 224 | 225 | r = zero(); 226 | for (int i = layout::WIDTH - 1; i >= 0; --i) { 227 | digit aa = layout::shfl(a, i); 228 | 229 | digit::mad_hi_cy(r, cy, aa, b, r); 230 | // TODO: Could use rotate here, which is slightly 231 | // cheaper than shfl_up0... 
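            // NB: The shuffle-up multiplies the running sum by the digit
            // radix 2^digit::BITS and discards the top digit, so the loop is
            // a Horner-style accumulation of sum_i a[i]*b*2^(i*digit::BITS)
            // from the most significant digit of a downwards, keeping only
            // the low WIDTH digits.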
232 | r = layout::shfl_up0(r, 1); 233 | cy = layout::shfl_up0(cy, 1); 234 | digit::mad_lo_cy(r, cy, aa, b, r); 235 | } 236 | cy = layout::shfl_up0(cy, 1); 237 | add(r, r, cy); 238 | } 239 | 240 | /* 241 | * (s, r) = a * b 242 | * 243 | * r is the "lo half" (see mul_lo above) and s is the 244 | * corresponding "hi half". 245 | */ 246 | __device__ static void mul_wide(fixnum &ss, fixnum &rr, fixnum a, fixnum b) { 247 | int L = layout::laneIdx(); 248 | 249 | fixnum r, s; 250 | r = fixnum::zero(); 251 | s = fixnum::zero(); 252 | digit cy = digit::zero(); 253 | 254 | fixnum ai = get(a, 0); 255 | digit::mul_lo(s, ai, b); 256 | r = L == 0 ? s : r; // r[0] = s[0]; 257 | s = layout::shfl_down0(s, 1); 258 | digit::mad_hi_cy(s, cy, ai, b, s); 259 | 260 | for (int i = 1; i < layout::WIDTH; ++i) { 261 | fixnum ai = get(a, i); 262 | digit::mad_lo_cc(s, ai, b, s); 263 | 264 | fixnum s0 = get(s, 0); 265 | r = (L == i) ? s0 : r; // r[i] = s[0] 266 | s = layout::shfl_down0(s, 1); 267 | 268 | // TODO: Investigate whether deferring this carry resolution until 269 | // after the loop improves performance much. 270 | digit::addc_cc(s, s, cy); // add carry from prev digit 271 | digit::addc(cy, 0, 0); // cy = CC.CF 272 | digit::mad_hi_cy(s, cy, ai, b, s); 273 | } 274 | cy = layout::shfl_up0(cy, 1); 275 | add(s, s, cy); 276 | rr = r; 277 | ss = s; 278 | } 279 | 280 | __device__ static void mul_hi(fixnum &s, fixnum a, fixnum b) { 281 | // TODO: Implement specific mul_hi function. 282 | fixnum r; 283 | mul_wide(s, r, a, b); 284 | } 285 | 286 | /* 287 | * Adapt "rediagonalisation" trick described in Figure 4 of Ozturk, 288 | * Guilford, Gopal (2013) "Large Integer Squaring on Intel 289 | * Architecture Processors". 290 | * 291 | * TODO: This function is only definitively faster than mul_wide when WIDTH 292 | * is 32 (but in that case it's ~50% faster). 293 | */ 294 | __device__ static void 295 | sqr_wide_(fixnum &ss, fixnum &rr, fixnum a) 296 | { 297 | constexpr int W = layout::WIDTH; 298 | int L = layout::laneIdx(); 299 | 300 | fixnum r, s; 301 | r = fixnum::zero(); 302 | s = fixnum::zero(); 303 | fixnum diag_lo = fixnum::zero(); 304 | digit cy = digit::zero(); 305 | 306 | for (int i = 0; i < W / 2; ++i) { 307 | fixnum a1, a2, s0; 308 | int lpi = L + i; 309 | // TODO: Explain how on Earth these formulae pick out the correct 310 | // terms for the squaring. 311 | // NB: Could achieve the same with iterative shuffle's; the expressions 312 | // would be clearer, but the shuffles would (presumably) be more expensive. 313 | a1 = get(a, lpi < W ? i : lpi - W/2); 314 | a2 = get(a, lpi < W ? lpi : W/2 + i); 315 | 316 | assert(L != 0 || digit::cmp(a1,a2)==0); // a1 = a2 when L == 0 317 | 318 | fixnum hi, lo; 319 | digit::mul_wide(hi, lo, a1, a2); 320 | 321 | // TODO: These two (almost identical) blocks cause lots of pipeline 322 | // stalls; need to find a way to reduce their data dependencies. 323 | digit::add_cyio(s, cy, s, lo); 324 | lo = get(lo, 0); 325 | diag_lo = (L == 2*i) ? lo : diag_lo; 326 | s0 = get(s, 0); 327 | r = (L == 2*i) ? s0 : r; // r[2i] = s[0] 328 | s = layout::shfl_down0(s, 1); 329 | 330 | digit::add_cyio(s, cy, s, hi); 331 | hi = get(hi, 0); 332 | diag_lo = (L == 2*i + 1) ? hi : diag_lo; 333 | s0 = get(s, 0); 334 | r = (L == 2*i + 1) ? s0 : r; // r[2i+1] = s[0] 335 | s = layout::shfl_down0(s, 1); 336 | } 337 | 338 | // TODO: All these carries and borrows into s should be accumulated into 339 | // one call. 
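        // Roughly: in a square, each cross product a_i*a_j (i != j) occurs
        // twice while each diagonal product a_i*a_i occurs once. Hence the
        // doubling below, the subtraction of one copy of the diagonal
        // products collected in diag_lo (a_i^2 for i < WIDTH/2, which the
        // doubling counts twice), and the final addition of the diagonal
        // products a_i^2 for i >= WIDTH/2, which the loop above did not
        // accumulate.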
340 | add(s, s, cy); 341 | 342 | fixnum overflow; 343 | lshift_small(s, s, 1); // s *= 2 344 | lshift_small(r, overflow, r, 1); // r *= 2 345 | add_cy(s, cy, s, overflow); // really a logior, since s was just lshifted. 346 | assert(digit::is_zero(cy)); 347 | 348 | // Doubling r above means we've doubled the diagonal terms, though they 349 | // shouldn't be. Compensate by subtracting a copy of them here. 350 | digit br; 351 | sub_br(r, br, r, diag_lo); 352 | br = (L == 0) ? br : digit::zero(); 353 | sub(s, s, br); 354 | 355 | // TODO: This is wasteful, since the odd lane lo's are discarded as are 356 | // the even lane hi's. 357 | fixnum lo, hi, ai = get(a, W/2 + L/2); 358 | digit::mul_lo(lo, ai, ai); 359 | digit::mul_hi(hi, ai, ai); 360 | fixnum diag_hi = L & 1 ? hi : lo; 361 | 362 | add(s, s, diag_hi); 363 | 364 | rr = r; 365 | ss = s; 366 | } 367 | 368 | __device__ __forceinline__ static void 369 | sqr_wide(fixnum &ss, fixnum &rr, fixnum a) { 370 | // Width below which the general multiplication function is used instead 371 | // of this one. TODO: 16 is very high; need to work out why we're not 372 | // doing better on smaller widths. 373 | constexpr int SQUARING_WIDTH_THRESHOLD = 16; 374 | if (layout::WIDTH < SQUARING_WIDTH_THRESHOLD) 375 | mul_wide(ss, rr, a, a); 376 | else 377 | sqr_wide_(ss, rr, a); 378 | } 379 | 380 | __device__ static void sqr_lo(fixnum &r, fixnum a) { 381 | // TODO: Implement specific sqr_lo function. 382 | fixnum s; 383 | sqr_wide(s, r, a); 384 | } 385 | 386 | __device__ static void sqr_hi(fixnum &s, fixnum a) { 387 | // TODO: Implement specific sqr_hi function. 388 | fixnum r; 389 | sqr_wide(s, r, a); 390 | } 391 | 392 | /* 393 | * Return a mask of width bits whose ith bit is set if and only if 394 | * the ith digit of r is nonzero. In particular, result is zero 395 | * iff r is zero. 396 | */ 397 | __device__ static uint32_t nonzero_mask(fixnum r) { 398 | return layout::ballot( ! digit::is_zero(r)); 399 | } 400 | 401 | /* 402 | * Return -1, 0, or 1, depending on whether x is less than, equal 403 | * to, or greater than y. 404 | */ 405 | __device__ static int cmp(fixnum x, fixnum y) { 406 | fixnum r; 407 | digit br; 408 | sub_br(r, br, x, y); 409 | // r != 0 iff x != y. If x != y, then br != 0 => x < y. 410 | return nonzero_mask(r) ? (br ? -1 : 1) : 0; 411 | } 412 | 413 | /* 414 | * Return the index of the most significant digit of x, or -1 if x is 415 | * zero. 416 | */ 417 | __device__ static int most_sig_dig(fixnum x) { 418 | // FIXME: Should be able to get this value from limits or numeric_limits 419 | // or whatever. 420 | enum { UINT32_BITS = 8 * sizeof(uint32_t) }; 421 | static_assert(UINT32_BITS == 32, "uint32_t isn't 32 bits"); 422 | 423 | uint32_t a = nonzero_mask(x); 424 | return UINT32_BITS - (internal::clz(a) + 1); 425 | } 426 | 427 | /* 428 | * Return the index of the most significant bit of x, or -1 if x is 429 | * zero. 430 | * 431 | * TODO: Give this function a better name; maybe floor_log2()? 432 | */ 433 | __device__ static int msb(fixnum x) { 434 | int b = most_sig_dig(x); 435 | if (b < 0) return b; 436 | digit y = layout::shfl(x, b); 437 | // TODO: These two lines are basically the same as most_sig_dig(); 438 | // refactor. 439 | int c = digit::clz(y); 440 | return digit::BITS - (c + 1) + digit::BITS * b; 441 | } 442 | 443 | /* 444 | * Return the 2-valuation of x, i.e. the integer k >= 0 such that 445 | * 2^k divides x but 2^(k+1) does not divide x. 
Depending on the 446 | * representation, can think of this as CTZ(x) ("Count Trailing 447 | * Zeros"). The 2-valuation of zero is *ahem* fixnum::BITS. 448 | * 449 | * TODO: Refactor common code between here, msb() and 450 | * most_sig_dig(). Perhaps write msb in terms of two_valuation? 451 | * 452 | * FIXME: Pretty sure this function is broken; e.g. if x is 0 but width < 453 | * warpSize, the answer is wrong. 454 | */ 455 | __device__ static int two_valuation(fixnum x) { 456 | uint32_t a = nonzero_mask(x); 457 | int b = internal::ctz(a), c = 0; 458 | if (b < SLOT_WIDTH) { 459 | digit y = layout::shfl(x, b); 460 | c = digit::ctz(y); 461 | } else 462 | b = SLOT_WIDTH; 463 | return c + b * digit::BITS; 464 | } 465 | 466 | __device__ 467 | static void 468 | lshift_small(fixnum &y, fixnum &overflow, fixnum x, int b) { 469 | assert(b >= 0); 470 | assert(b <= digit::BITS); 471 | int L = layout::laneIdx(); 472 | 473 | fixnum cy; 474 | digit::lshift(y, cy, x, b); 475 | overflow = top_digit(cy); 476 | overflow = (L == 0) ? overflow : fixnum::zero(); 477 | cy = layout::shfl_up0(cy, 1); 478 | digit::add(y, y, cy); // logior 479 | } 480 | 481 | __device__ 482 | static void 483 | lshift_small(fixnum &y, fixnum x, int b) { 484 | assert(b >= 0); 485 | assert(b <= digit::BITS); 486 | 487 | fixnum cy; 488 | digit::lshift(y, cy, x, b); 489 | cy = layout::shfl_up0(cy, 1); 490 | digit::add(y, y, cy); // logior 491 | } 492 | 493 | /* 494 | * Set y to be x shifted by b bits to the left; effectively 495 | * multiply by 2^b. Return the top b bits of x in overflow. 496 | * 497 | * FIXME: Currently assumes that fixnum is unsigned. 498 | * 499 | * TODO: Think of better names for these functions. Something like 500 | * mul_2exp. 501 | * 502 | * TODO: Could improve performance significantly by using the funnel shift 503 | * instruction: https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf 504 | */ 505 | __device__ 506 | static void 507 | lshift(fixnum &y, fixnum &overflow, fixnum x, int b) { 508 | assert(b >= 0); 509 | assert(b <= BITS); 510 | int q = b / digit::BITS, r = b % digit::BITS; 511 | 512 | y = layout::rotate_up(x, q); 513 | // Hi bits of y[i] (=overflow) become the lo bits of y[(i+1) % width] 514 | digit::lshift(y, overflow, y, r); 515 | overflow = layout::rotate_up(overflow, 1); 516 | // TODO: This was "y |= overflow"; any advantage to using logior? 517 | digit::add(y, y, overflow); 518 | 519 | fixnum t; 520 | int L = layout::laneIdx(); 521 | digit::set_if(overflow, y, L <= q); // Kill high (q-1) words of y; 522 | digit::rem_2exp(t, overflow, r); // Kill high BITS - r bits of overflow[q] 523 | set(overflow, t, q); 524 | digit::set_if(y, y, L >= q); // Kill low q words of y; 525 | digit::rshift(t, y, r); // Kill low r bits of y[q] 526 | digit::lshift(t, t, r); 527 | set(y, t, q); 528 | } 529 | 530 | __device__ 531 | static void 532 | lshift(fixnum &y, fixnum x, int b) { 533 | assert(b >= 0); 534 | assert(b <= BITS); 535 | int q = b / digit::BITS, r = b % digit::BITS; 536 | 537 | y = layout::shfl_up0(x, q); 538 | lshift_small(y, y, r); 539 | } 540 | 541 | /* 542 | * Set y to be x shifted by b bits to the right; effectively 543 | * divide by 2^b. Return the bottom b bits of x. 544 | * 545 | * TODO: Think of better names for these functions. Something like 546 | * mul_2exp. 
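     * (For this right-shift variant, div_2exp would be the natural analogue.)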
547 | */ 548 | __device__ 549 | static void 550 | rshift(fixnum &y, fixnum &underflow, fixnum x, int b) { 551 | lshift(underflow, y, x, BITS - b); 552 | } 553 | 554 | __device__ 555 | static void 556 | rshift(fixnum &y, fixnum x, int b) { 557 | fixnum underflow; 558 | rshift(y, underflow, x, b); 559 | } 560 | 561 | private: 562 | __device__ 563 | static void 564 | digit_to_fixnum(digit &c) { 565 | int L = layout::laneIdx(); 566 | // TODO: Try without branching? c &= -(digit)(L == 0); 567 | c = (L == 0) ? c : digit::zero(); 568 | } 569 | 570 | __device__ 571 | static digit 572 | effective_carries(digit &cy_hi, int propagate, int cy) { 573 | int L = layout::laneIdx(); 574 | uint32_t allcarries, p, g; 575 | 576 | g = layout::ballot(cy); // carry generate 577 | p = layout::ballot(propagate); // carry propagate 578 | allcarries = (p | g) + g; // propagate all carries 579 | // NB: There is no way to unify these two expressions to remove the 580 | // conditional. The conditional should be optimised away though, since 581 | // WIDTH is a compile-time constant. 582 | cy_hi = (layout::WIDTH == WARPSIZE) // detect hi overflow 583 | ? (allcarries < g) 584 | : ((allcarries >> layout::WIDTH) & 1); 585 | allcarries = (allcarries ^ p) | (g << 1); // get effective carries 586 | return (allcarries >> L) & 1; 587 | } 588 | }; 589 | 590 | } // End namespace cuFIXNUM 591 | -------------------------------------------------------------------------------- /src/fixnum/word_fixnum.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "fixnum/internal/primitives.cu" 4 | 5 | namespace cuFIXNUM { 6 | 7 | template< typename T > 8 | class word_fixnum { 9 | public: 10 | typedef T digit; 11 | typedef word_fixnum fixnum; 12 | 13 | static constexpr int BYTES = sizeof(T); 14 | static constexpr int BITS = BYTES * 8; 15 | 16 | private: 17 | digit x; 18 | 19 | // TODO: These should be private 20 | public: 21 | __device__ __forceinline__ 22 | operator digit () const { return x; } 23 | 24 | __device__ __forceinline__ 25 | operator digit &() { return x; } 26 | 27 | public: 28 | __device__ __forceinline__ 29 | word_fixnum() { } 30 | 31 | __device__ __forceinline__ 32 | word_fixnum(digit z) : x(z) { } 33 | 34 | __device__ __forceinline__ 35 | static void 36 | set_if(fixnum &s, fixnum a, int cond) { 37 | s = a & -(digit)cond; 38 | } 39 | 40 | // TODO: Implement/use something like numeric_limits::max() for this 41 | // and most_negative(). 42 | // FIXME: These two functions assume that T is unsigned. 43 | __device__ __forceinline__ 44 | static constexpr fixnum 45 | most_positive() { return ~(fixnum)0; } 46 | 47 | __device__ __forceinline__ 48 | static constexpr fixnum 49 | most_negative() { return zero(); }; 50 | 51 | __device__ __forceinline__ 52 | static constexpr fixnum 53 | zero() { return (fixnum)0; } 54 | 55 | __device__ __forceinline__ 56 | static constexpr fixnum 57 | one() { return (fixnum)1; } 58 | 59 | __device__ __forceinline__ 60 | static constexpr fixnum 61 | two() { return (fixnum)2; } 62 | 63 | __device__ __forceinline__ 64 | static void 65 | add(fixnum &s, fixnum a, fixnum b) { 66 | s = a + b; 67 | } 68 | 69 | // TODO: this function does not follow the convention of later '*_cy' 70 | // functions of accumulating the carry into cy. 
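    // For example, with 32-bit digits, add_cy(s, cy, 0xffffffffU, 1U) leaves
    // s = 0 and cy = 1, whereas the mad_*_cy functions below add their carry
    // into the existing value of cy.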
71 | __device__ __forceinline__ 72 | static void 73 | add_cy(fixnum &s, digit &cy, fixnum a, fixnum b) { 74 | s = a + b; 75 | cy = s < a; 76 | } 77 | 78 | __device__ __forceinline__ 79 | static void 80 | add_cyio(fixnum &s, digit &cy, fixnum a, fixnum b) { 81 | s = a + cy; 82 | cy = s < a; 83 | s += b; 84 | cy |= s < b; 85 | } 86 | 87 | __device__ __forceinline__ 88 | static void 89 | add_cc(fixnum &s, fixnum a, fixnum b) { 90 | internal::add_cc(s, a, b); 91 | } 92 | 93 | __device__ __forceinline__ 94 | static void 95 | addc(fixnum &s, fixnum a, fixnum b) { 96 | internal::addc(s, a, b); 97 | } 98 | 99 | __device__ __forceinline__ 100 | static void 101 | addc_cc(fixnum &s, fixnum a, fixnum b) { 102 | internal::addc_cc(s, a, b); 103 | } 104 | 105 | __device__ __forceinline__ 106 | static void 107 | incr(fixnum &s) { 108 | ++s; 109 | } 110 | 111 | __device__ __forceinline__ 112 | static void 113 | sub(fixnum &d, fixnum a, fixnum b) { 114 | d = a - b; 115 | } 116 | 117 | __device__ __forceinline__ 118 | static void 119 | sub_br(fixnum &d, digit &br, fixnum a, fixnum b) { 120 | d = a - b; 121 | br = d > a; 122 | } 123 | 124 | __device__ __forceinline__ 125 | static void 126 | neg(fixnum &ma, fixnum a) { 127 | ma = -a; 128 | } 129 | 130 | __device__ __forceinline__ 131 | static void 132 | mul_lo(fixnum &lo, fixnum a, fixnum b) { 133 | lo = a * b; 134 | } 135 | 136 | // hi * 2^32 + lo = a * b 137 | __device__ __forceinline__ 138 | static void 139 | mul_hi(fixnum &hi, fixnum a, fixnum b) { 140 | internal::mul_hi(hi, a, b); 141 | } 142 | 143 | // hi * 2^32 + lo = a * b 144 | __device__ __forceinline__ 145 | static void 146 | mul_wide(fixnum &hi, fixnum &lo, fixnum a, fixnum b) { 147 | internal::mul_wide(hi, lo, a, b); 148 | } 149 | 150 | // (hi, lo) = a * b + c 151 | __device__ __forceinline__ 152 | static void 153 | mad_wide(fixnum &hi, fixnum &lo, fixnum a, fixnum b, fixnum c) { 154 | internal::mad_wide(hi, lo, a, b, c); 155 | } 156 | 157 | // lo = a * b + c (mod 2^32) 158 | __device__ __forceinline__ 159 | static void 160 | mad_lo(fixnum &lo, fixnum a, fixnum b, fixnum c) { 161 | internal::mad_lo(lo, a, b, c); 162 | } 163 | 164 | // as above but increment cy by the mad carry 165 | __device__ __forceinline__ 166 | static void 167 | mad_lo_cy(fixnum &lo, fixnum &cy, fixnum a, fixnum b, fixnum c) { 168 | internal::mad_lo_cc(lo, a, b, c); 169 | internal::addc(cy, cy, 0); 170 | } 171 | 172 | __device__ __forceinline__ 173 | static void 174 | mad_hi(fixnum &hi, fixnum a, fixnum b, fixnum c) { 175 | internal::mad_hi(hi, a, b, c); 176 | } 177 | 178 | // as above but increment cy by the mad carry 179 | __device__ __forceinline__ 180 | static void 181 | mad_hi_cy(fixnum &hi, fixnum &cy, fixnum a, fixnum b, fixnum c) { 182 | internal::mad_hi_cc(hi, a, b, c); 183 | internal::addc(cy, cy, 0); 184 | } 185 | 186 | // TODO: There are weird and only included for mul_wide 187 | __device__ __forceinline__ 188 | static void 189 | mad_lo_cc(fixnum &lo, fixnum a, fixnum b, fixnum c) { 190 | internal::mad_lo_cc(lo, a, b, c); 191 | } 192 | 193 | // Returns the reciprocal for d. 194 | __device__ __forceinline__ 195 | static fixnum 196 | quorem(fixnum &q, fixnum &r, fixnum n, fixnum d) { 197 | return quorem_wide(q, r, zero(), n, d); 198 | } 199 | 200 | // Accepts a reciprocal for d. 201 | __device__ __forceinline__ 202 | static void 203 | quorem(fixnum &q, fixnum &r, fixnum n, fixnum d, fixnum v) { 204 | quorem_wide(q, r, zero(), n, d, v); 205 | } 206 | 207 | // Returns the reciprocal for d. 
208 | // NB: returns q = r = fixnum::MAX if n_hi > d. 209 | __device__ __forceinline__ 210 | static fixnum 211 | quorem_wide(fixnum &q, fixnum &r, fixnum n_hi, fixnum n_lo, fixnum d) { 212 | return internal::quorem_wide(q, r, n_hi, n_lo, d); 213 | } 214 | 215 | // Accepts a reciprocal for d. 216 | // NB: returns q = r = fixnum::MAX if n_hi > d. 217 | __device__ __forceinline__ 218 | static void 219 | quorem_wide(fixnum &q, fixnum &r, fixnum n_hi, fixnum n_lo, fixnum d, fixnum v) { 220 | internal::quorem_wide(q, r, n_hi, n_lo, d, v); 221 | } 222 | 223 | __device__ __forceinline__ 224 | static void 225 | rem_2exp(fixnum &r, fixnum n, unsigned k) { 226 | unsigned kp = BITS - k; 227 | r = (n << kp) >> kp; 228 | } 229 | 230 | /* 231 | * Count Leading Zeroes in x. 232 | * 233 | * TODO: This is not an intrinsic quality of a digit, so probably shouldn't 234 | * be in the interface. 235 | */ 236 | __device__ __forceinline__ 237 | static int 238 | clz(fixnum x) { 239 | return internal::clz(x); 240 | } 241 | 242 | /* 243 | * Count Trailing Zeroes in x. 244 | * 245 | * TODO: This is not an intrinsic quality of a digit, so probably shouldn't 246 | * be in the interface. 247 | */ 248 | __device__ __forceinline__ 249 | static int 250 | ctz(fixnum x) { 251 | return internal::ctz(x); 252 | } 253 | 254 | __device__ __forceinline__ 255 | static int 256 | cmp(fixnum a, fixnum b) { 257 | // TODO: There is probably a PTX instruction for this. 258 | int br = (a - b) > a; 259 | return br ? -br : (a != b); 260 | } 261 | 262 | __device__ __forceinline__ 263 | static int 264 | is_max(fixnum a) { return a == most_positive(); } 265 | 266 | __device__ __forceinline__ 267 | static int 268 | is_min(fixnum a) { return a == most_negative(); } 269 | 270 | __device__ __forceinline__ 271 | static int 272 | is_zero(fixnum a) { return a == zero(); } 273 | 274 | __device__ __forceinline__ 275 | static void 276 | min(fixnum &m, fixnum a, fixnum b) { 277 | internal::min(m, a, b); 278 | } 279 | 280 | __device__ __forceinline__ 281 | static void 282 | max(fixnum &m, fixnum a, fixnum b) { 283 | internal::max(m, a, b); 284 | } 285 | 286 | __device__ __forceinline__ 287 | static void 288 | lshift(fixnum &z, fixnum x, unsigned b) { 289 | z = x << b; 290 | } 291 | 292 | __device__ __forceinline__ 293 | static void 294 | lshift(fixnum &z, fixnum &overflow, fixnum x, unsigned b) { 295 | internal::lshift(overflow, z, 0, x, b); 296 | } 297 | 298 | __device__ __forceinline__ 299 | static void 300 | rshift(fixnum &z, fixnum x, unsigned b) { 301 | z = x >> b; 302 | } 303 | 304 | __device__ __forceinline__ 305 | static void 306 | rshift(fixnum &z, fixnum &underflow, fixnum x, unsigned b) { 307 | internal::rshift(z, underflow, x, 0, b); 308 | } 309 | 310 | /* 311 | * Return 1/b (mod 2^BITS) where b is odd. 312 | * 313 | * Source: MCA, Section 2.5. 314 | */ 315 | __device__ __forceinline__ 316 | static void 317 | modinv_2exp(fixnum &x, fixnum b) { 318 | internal::modinv_2exp(x, b); 319 | } 320 | 321 | /* 322 | * Return 1 if x = 2^n for some n, 0 otherwise. (Caveat: Returns 1 for x = 0 323 | * which is not a binary power.) 324 | * 325 | * FIXME: This doesn't belong here. 326 | */ 327 | template< typename uint_type > 328 | __device__ __forceinline__ 329 | static int 330 | is_binary_power(uint_type x) { 331 | //static_assert(std::is_unsigned::value == true, 332 | // "template type must be unsigned"); 333 | return ! (x & (x - 1)); 334 | } 335 | 336 | /* 337 | * y >= x such that y = 2^n for some n. NB: This really is "inclusive" 338 | * next, i.e. 
if x is a binary power we just return it. 339 | * 340 | * FIXME: This doesn't belong here. 341 | */ 342 | __device__ __forceinline__ 343 | static fixnum 344 | next_binary_power(fixnum x) { 345 | return is_binary_power(x) 346 | ? x 347 | : (fixnum)((digit)1 << (BITS - clz(x))); 348 | } 349 | }; 350 | 351 | typedef word_fixnum u32_fixnum; 352 | typedef word_fixnum u64_fixnum; 353 | 354 | } // End namespace cuFIXNUM 355 | -------------------------------------------------------------------------------- /src/functions/chinese.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/quorem_preinv.cu" 4 | #include "functions/multi_modexp.cu" 5 | #include "modnum/modnum_monty_cios.cu" 6 | 7 | namespace cuFIXNUM { 8 | 9 | template< typename fixnum > 10 | class chinese { 11 | public: 12 | __device__ chinese(fixnum p, fixnum q); 13 | 14 | __device__ void operator()(fixnum &m, fixnum mp, fixnum mq) const; 15 | 16 | private: 17 | // TODO: These all have width = WIDTH/2, so this is a waste of 18 | // space, and (worse) the operations below waste cycles. 19 | fixnum p, q, c; // c = p^-1 (mod q) 20 | 21 | quorem_preinv mod_q; 22 | }; 23 | 24 | template< typename fixnum > 25 | __device__ 26 | chinese::chinese(fixnum p_, fixnum q_) 27 | : p(p_), q(q_), mod_q(q) 28 | { 29 | typedef modnum_monty_cios modnum; 30 | 31 | // TODO: q is now stored here and in mod_q; need to work out how 32 | // to share q between them. Probably best just to provide quorem_preinv 33 | // with an accessor to the divisor. 34 | 35 | // TODO: Make modinv use xgcd and use modinv instead. 36 | // Use a^(q-2) = 1 (mod q) 37 | fixnum qm2, two = fixnum::two(); 38 | fixnum::sub(qm2, q, two); 39 | multi_modexp minv(q); 40 | minv(c, p, qm2); 41 | } 42 | 43 | 44 | /* 45 | * CRT on Mp and Mq. 46 | * 47 | * Mp, Mq, p, q must all be WIDTH/2 digits long 48 | * 49 | * Source HAC, Note 14.75. 50 | */ 51 | template< typename fixnum > 52 | __device__ void 53 | chinese::operator()(fixnum &m, fixnum mp, fixnum mq) const 54 | { 55 | typedef typename fixnum::digit digit; 56 | // u = (mq - mp) * c (mod q) 57 | fixnum u, t, hi, lo; 58 | digit br; 59 | fixnum::sub_br(u, br, mq, mp); 60 | 61 | // TODO: It would be MUCH better to ensure that the mul_wide 62 | // and mod_q parts of this condition occur on the main 63 | // execution path to avoid long warp divergence. 64 | if (br) { 65 | // Mp > Mq 66 | // TODO: Can't I get this from u above? Need a negation 67 | // function; maybe use "method of complements". 68 | fixnum::sub_br(u, br, mp, mq); 69 | assert(digit::is_zero(br)); 70 | 71 | // TODO: Replace mul_wide with the equivalent mul_lo 72 | //digit_mul(hi, lo, u, c, width/2); 73 | fixnum::mul_wide(hi, lo, u, c); 74 | assert(digit::is_zero(hi)); 75 | 76 | t = fixnum::zero(); 77 | //quorem_rem(mod_q, t, hi, lo, width/2); 78 | mod_q(t, hi, lo); 79 | 80 | // TODO: This is a mess. 81 | if ( ! 
fixnum::is_zero(t)) { 82 | fixnum::sub_br(u, br, q, t); 83 | assert(digit::is_zero(br)); 84 | } else { 85 | u = t; 86 | } 87 | } else { 88 | // Mp < Mq 89 | // TODO: Replace mul_wide with the equivalent mul_lo 90 | //digit_mul(hi, lo, u, c, width/2); 91 | fixnum::mul_wide(hi, lo, u, c); 92 | assert(digit::is_zero(hi)); 93 | 94 | u = fixnum::zero(); 95 | //quorem_rem(mod_q, u, hi, lo, width/2); 96 | mod_q(u, hi, lo); 97 | } 98 | // TODO: Replace mul_wide with the equivalent mul_lo 99 | //digit_mul(hi, lo, u, p, width/2); 100 | fixnum::mul_wide(hi, lo, u, p); 101 | //shfl_up(hi, width/2, width); 102 | //t = (L < width/2) ? lo : hi; 103 | assert(digit::is_zero(hi)); 104 | t = lo; 105 | 106 | //digit_add(m, mp, t, width); 107 | fixnum::add(m, mp, t); 108 | } 109 | 110 | } // End namespace cuFIXNUM 111 | -------------------------------------------------------------------------------- /src/functions/divexact.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/modinv.cu" 4 | 5 | namespace cuFIXNUM { 6 | 7 | template< typename fixnum > 8 | class divexact { 9 | public: 10 | __device__ divexact(fixnum divisor) { 11 | b = divisor; 12 | 13 | // divisor must be odd 14 | // TODO: Handle even divisor. Should be easy: just make sure 15 | // the 2-part of the divisor and dividend are the same and 16 | // then remove them. 17 | typename fixnum::digit b0 = fixnum::get(b, 0); 18 | assert(b0 & 1); 19 | 20 | // Calculate b inverse 21 | modinv minv; 22 | minv(bi, b, fixnum::BITS/2); 23 | } 24 | 25 | /* 26 | * q = a / b, assuming b divides a. 27 | * 28 | * Source: MCA Algorithm 1.10. 29 | */ 30 | __device__ void operator()(fixnum &q, fixnum a) const { 31 | fixnum t, w = fixnum::zero(); 32 | 33 | // w <- a bi (mod 2^(NBITS / 2)) 34 | 35 | // FIXME: This is wasteful since we only want the bottom half of the 36 | // result. Could we do something like: 37 | // 38 | // create half_fixnum which is fixnum< FIXNUM_BYTES / 2 > but 39 | // with same slot_layout. Then use half_fixnum::mul_lo(w, a, bi) 40 | // 41 | fixnum::mul_lo(w, a, bi); 42 | // FIXME: This doesn't work when SLOT_WIDTH = 0 43 | //w = (fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH / 2) ? w : 0; 44 | 45 | // TODO: Can use the "middle product" to speed this up a 46 | // bit. See MCA Section 1.4.5. 47 | // t <- b w (mod 2^NBITS) 48 | fixnum::mul_lo(t, b, w); 49 | // t <- a - b w (mod 2^NBITS) 50 | fixnum::sub(t, a, t); 51 | // t <- bi (a - b w) (mod 2^NBITS) 52 | fixnum::mul_lo(t, bi, t); 53 | // w <- w + bi (a - b w) 54 | fixnum::add(w, w, t); 55 | 56 | q = w; 57 | } 58 | 59 | private: 60 | // Divisor 61 | fixnum b; 62 | // 1/b (mod 2^(NBITS/2)) where NBITS := FIXNUM_BITS. bi is 63 | // nevertheless treated as an NBITS fixnum, so its hi half must be 64 | // all zeros. 65 | fixnum bi; 66 | }; 67 | 68 | } // End namespace cuFIXNUM 69 | -------------------------------------------------------------------------------- /src/functions/internal/modexp_impl.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace cuFIXNUM { 4 | 5 | namespace internal 6 | { 7 | /* 8 | * Return floor(log2(x)). In particular, if x = 2^b, return b. 9 | */ 10 | __device__ 11 | constexpr unsigned 12 | floorlog2(unsigned x) { 13 | return x == 1 ? 0 : 1 + floorlog2(x >> 1); 14 | } 15 | 16 | /* 17 | * The following function gives a reasonable choice of WINDOW_SIZE in the k-ary 18 | * modular exponentiation method for a fixnum of B = 2^b bytes. 
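 * For example, a 256-byte (2048-bit) fixnum indexes entry floorlog2(256) = 8
 * of the table below, giving a k-ary window size of 6.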
19 | * 20 | * The origin of the table is as follows. The expected number of multiplications 21 | * for the k-ary method with n-bit exponent and d-bit window is given by 22 | * 23 | * T(n, d) = 2^d - 2 + n - d + (n/d - 1)*(1 - 2^-d) 24 | * 25 | * (see Koç, C. K., 1995, "Analysis of Sliding Window Techniques for 26 | * Exponentiation", Equation 1). The following GP script calculates the values 27 | * of n at which the window size should increase (maximum n = 65536): 28 | * 29 | * ? T(n,d) = 2^d - 2 + n - d + (n/d - 1) * (1 - 2^-d); 30 | * ? M = [ vecsort([[n, d, T(n, d)*1.] | d <- [1 .. 16]], 3)[1][2] | n <- [1 .. 65536] ]; 31 | * ? maxd = M[65536] 32 | * 10 33 | * ? [[d, vecmin([n | n <- [1 .. 65536], M[n] == d])] | d <- [1 .. maxd]] 34 | * [[1, 1], [2, 7], [3, 35], [4, 122], [5, 369], [6, 1044], [7, 2823], [8, 7371], [9, 18726], [10, 46490]] 35 | * 36 | * Table entry i is the window size for a fixnum of 8*(2^i) bits (e.g. 512 = 37 | * 8*2^6 bits falls between 369 and 1044, so the window size is that of the 38 | * smaller, 369, so 5 is in place i = 6). 39 | */ 40 | // NB: For some reason we're not allowed to put this table in the definition 41 | // of bytes_to_window_size(). 42 | constexpr int BYTES_TO_K_ARY_WINDOW_SIZE_TABLE[] = { 43 | -1, 44 | -1, //bytes bits 45 | 2, // 2^2 32 46 | 3, // 2^3 64 47 | 4, // 2^4 128 48 | 4, // 2^5 256 49 | 5, // 2^6 512 50 | 5, // 2^7 1024 51 | 6, // 2^8 2048 52 | 7, // 2^9 4096 53 | 8, //2^10 8192 54 | 8, //2^11 16384 55 | 9, //2^12 32768 56 | 10,//2^13 65536 57 | }; 58 | 59 | __device__ 60 | constexpr int 61 | bytes_to_k_ary_window_size(unsigned bytes) { 62 | return BYTES_TO_K_ARY_WINDOW_SIZE_TABLE[floorlog2(bytes)]; 63 | } 64 | 65 | 66 | /* 67 | * This Table 2 from Koç, C. K., 1995, "Analysis of Sliding Window 68 | * Techniques for Exponentiation". 69 | * 70 | * The resolution of this table is higher than the one above because it's 71 | * used in the fixed exponent modexp code and can benefit from using the 72 | * precise bit length of the exponent, whereas the table above has to 73 | * accommodate multiple different exponents simultaneously. 74 | */ 75 | __constant__ 76 | int BYTES_TO_CLNW_WINDOW_SIZE_TABLE[] = { 77 | -1, // bits 78 | 4, // 128 79 | 5, // 256 80 | 5, // 384 81 | 5, // 512 82 | 6, // 640 83 | 6, // 768 84 | 6, // 896 85 | 6, // 1024 86 | 6, // 1152 87 | 6, // 1280 88 | 6, // 1408 89 | 6, // 1536 90 | 6, // 1664 91 | 7, // 1792 92 | 7, // 1920 93 | 7, // 2048 94 | }; 95 | 96 | __device__ 97 | constexpr int 98 | bits_to_clnw_window_size(unsigned bits) { 99 | // The chained ternary condition is forced upon us by the Draconian 100 | // constraints of C++11 constexpr functions. 101 | return 102 | bits < 64 ? 2 : 103 | bits < 128 ? 3 : 104 | bits > 2048 ? 7 : 105 | BYTES_TO_CLNW_WINDOW_SIZE_TABLE[(bits / 8) / 16]; 106 | } 107 | 108 | } // End namespace internal 109 | 110 | } // End namespace cuFIXNUM 111 | -------------------------------------------------------------------------------- /src/functions/modexp.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/internal/modexp_impl.cu" 4 | #include "modnum/modnum_monty_cios.cu" 5 | 6 | namespace cuFIXNUM { 7 | 8 | template< typename modnum_tp > 9 | class modexp { 10 | typedef typename modnum_tp::fixnum fixnum; 11 | typedef typename fixnum::digit digit; 12 | 13 | // Decomposition of the exponent for use in the constant-width sliding-window 14 | // algorithm. Allocated & deallocated once per thread block. 
Ref: 15 | // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#per-thread-block-allocation 16 | // TODO: Consider storing the whole exp_wins array in shared memory. 17 | uint32_t *exp_wins; 18 | int exp_wins_len; 19 | int window_size; 20 | 21 | const modnum_tp modnum; 22 | 23 | // Helper functions for decomposing the exponent into windows. 24 | __device__ uint32_t 25 | scan_window(int &hi_idx, fixnum &n, int max_window_bits); 26 | 27 | __device__ int 28 | scan_zero_window(int &hi_idx, fixnum &n); 29 | 30 | __device__ uint32_t 31 | scan_nonzero_window(int &hi_idx, fixnum &n, int max_window_bits); 32 | 33 | public: 34 | /* 35 | * NB: It is assumed that the caller has reduced exp and mod using knowledge 36 | * of their properties (e.g. reducing exp modulo phi(mod), CRT, etc.). 37 | */ 38 | __device__ modexp(fixnum mod, fixnum exp); 39 | 40 | __device__ ~modexp(); 41 | 42 | __device__ void operator()(fixnum &z, fixnum x) const; 43 | }; 44 | 45 | 46 | template< typename modnum_tp > 47 | __device__ uint32_t 48 | modexp::scan_nonzero_window(int &hi_idx, fixnum &n, int max_window_bits) { 49 | uint32_t bits_remaining = hi_idx + 1, win_bits; 50 | digit w, lsd = fixnum::bottom_digit(n); 51 | 52 | internal::min(win_bits, bits_remaining, max_window_bits); 53 | digit::rem_2exp(w, lsd, win_bits); 54 | fixnum::rshift(n, n, win_bits); 55 | hi_idx -= win_bits; 56 | 57 | return w; 58 | } 59 | 60 | 61 | template< typename modnum_tp > 62 | __device__ int 63 | modexp::scan_zero_window(int &hi_idx, fixnum &n) { 64 | int nzeros = fixnum::two_valuation(n); 65 | fixnum::rshift(n, n, nzeros); 66 | hi_idx -= nzeros; 67 | return nzeros; 68 | } 69 | 70 | 71 | template< typename modnum_tp > 72 | __device__ uint32_t 73 | modexp::scan_window(int &hi_idx, fixnum &n, int max_window_bits) { 74 | int nzeros; 75 | uint32_t window; 76 | nzeros = scan_zero_window(hi_idx, n); 77 | window = scan_nonzero_window(hi_idx, n, max_window_bits); 78 | // top half is the odd window, bottom half is nzeros 79 | // TODO: fix magic number 80 | return (window << 16) | nzeros; 81 | } 82 | 83 | 84 | template< typename modnum_tp > 85 | __device__ 86 | modexp::modexp(fixnum mod, fixnum exp) 87 | : modnum(mod) 88 | { 89 | // sliding window decomposition 90 | int hi_idx; 91 | 92 | hi_idx = fixnum::msb(exp); 93 | window_size = internal::bits_to_clnw_window_size(hi_idx + 1); 94 | 95 | uint32_t *data; 96 | int L = fixnum::layout::laneIdx(); 97 | // TODO: This does one malloc per slot; the sliding window exponentiation 98 | // only really makes sense with fixed exponent, so we should be able to arrange 99 | // things so we only need one malloc per warp or even one malloc per thread block. 100 | if (L == 0) { 101 | int max_windows; 102 | internal::ceilquo(max_windows, fixnum::BITS, window_size); 103 | // NB: Default heap on the device is 8MB. 104 | data = (uint32_t *) malloc(max_windows * sizeof(uint32_t)); 105 | // FIXME: Handle this error properly. 106 | assert(data != nullptr); 107 | } 108 | // Broadcast data to each thread in the slot. 
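    // (Only lane 0 called malloc above; the shuffle below copies lane 0's
    // pointer value to every lane in the slot, so all lanes index the same
    // heap buffer.)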
109 | exp_wins = (uint32_t *) __shfl_sync(fixnum::layout::mask(), (uintptr_t)data, 0, fixnum::layout::WIDTH); 110 | uint32_t *ptr = exp_wins; 111 | while (hi_idx >= 0) 112 | *ptr++ = scan_window(hi_idx, exp, window_size); 113 | exp_wins_len = ptr - exp_wins; 114 | } 115 | 116 | 117 | template< typename modnum_tp > 118 | __device__ 119 | modexp::~modexp() 120 | { 121 | if (fixnum::layout::laneIdx() == 0) 122 | free(exp_wins); 123 | } 124 | 125 | 126 | template< typename modnum_tp > 127 | __device__ void 128 | modexp::operator()(fixnum &z, fixnum x) const 129 | { 130 | static constexpr int WINDOW_MAX_BITS = 16; 131 | static constexpr int WINDOW_LEN_MASK = (1UL << WINDOW_MAX_BITS) - 1UL; 132 | // TODO: Actual maximum is 16 at the moment (see above), but it will very 133 | // rarely need to be more than 7. Consider storing G in shared memory to 134 | // remove the need for WINDOW_MAX_BITS altogether. 135 | static constexpr int WINDOW_MAX_BITS_REDUCED = 7; 136 | static constexpr int WINDOW_MAX_VAL_REDUCED = 1U << WINDOW_MAX_BITS_REDUCED; 137 | assert(window_size <= WINDOW_MAX_BITS_REDUCED); 138 | 139 | // We need to know that exp_wins_len > 0 when z is initialised just before 140 | // the main loop. 141 | if (exp_wins_len == 0) { 142 | //z = fixnum::one(); 143 | // TODO: This complicated way of producing a 1 is to 144 | // accommodate the possibility that monty.is_valid is false. 145 | modnum.from_modnum(z, modnum.one()); 146 | return; 147 | } 148 | 149 | // TODO: handle case of small exponent specially 150 | 151 | int window_max = 1U << window_size; 152 | /* G[t] = z^(2t + 1) t >= 0 (odd powers of z) */ 153 | fixnum G[WINDOW_MAX_VAL_REDUCED / 2]; 154 | modnum.to_modnum(z, x); 155 | G[0] = z; 156 | if (window_size > 1) { 157 | modnum.sqr(z, z); 158 | for (int t = 1; t < window_max / 2; ++t) { 159 | G[t] = G[t - 1]; 160 | modnum.mul(G[t], G[t], z); 161 | } 162 | } 163 | 164 | // Iterate over windows from most significant window to least significant 165 | // (i.e. reverse order from the order they're stored). 166 | const uint32_t *windows = exp_wins + exp_wins_len - 1; 167 | uint32_t win = *windows--; 168 | uint16_t two_val = win & WINDOW_LEN_MASK; 169 | uint16_t e = win >> WINDOW_MAX_BITS; 170 | 171 | z = G[e / 2]; 172 | while (two_val-- > 0) 173 | modnum.sqr(z, z); 174 | 175 | while (windows >= exp_wins) { 176 | two_val = window_size; 177 | while (two_val-- > 0) 178 | modnum.sqr(z, z); 179 | 180 | win = *windows--; 181 | two_val = win & WINDOW_LEN_MASK; 182 | e = win >> WINDOW_MAX_BITS; 183 | 184 | modnum.mul(z, z, G[e / 2]); 185 | while (two_val-- > 0) 186 | modnum.sqr(z, z); 187 | } 188 | modnum.from_modnum(z, z); 189 | } 190 | 191 | } // End namespace cuFIXNUM 192 | -------------------------------------------------------------------------------- /src/functions/modinv.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace cuFIXNUM { 4 | 5 | /* 6 | * Calculate the modular inverse. 7 | * TODO: Only supports moduli of the form 2^k at the moment. 8 | */ 9 | template< typename fixnum > 10 | struct modinv { 11 | /* 12 | * Return x = 1/b (mod 2^k). Must have 0 < k <= BITS. 13 | * 14 | * Source: MCA Algorithm 1.10. 
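     *
     * The loop below is a Newton/Hensel lift: if b*x = 1 (mod 2^j), then
     * x <- x + x*(1 - b*x) satisfies b*x = 1 (mod 2^(2j)), so each pass
     * doubles the number of correct low-order bits.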
15 | * 16 | * TODO: Calculate this using the multiple inversion trick (MCA 2.5.1) 17 | */ 18 | __device__ void operator()(fixnum &x, fixnum b, int k) const { 19 | typedef typename fixnum::digit digit; 20 | // b must be odd 21 | digit b0 = fixnum::get(b, 0); 22 | assert(k > 0 && k <= fixnum::BITS); 23 | 24 | digit binv; 25 | digit::modinv_2exp(binv, b0); 26 | x = fixnum::zero(); 27 | fixnum::set(x, binv, 0); 28 | if (k <= digit::BITS) { 29 | digit::rem_2exp(x, x, k); 30 | return; 31 | } 32 | 33 | // Hensel lift x from (mod 2^WORD_BITS) to (mod 2^k) 34 | // FIXME: Double-check this condition on k! 35 | while (k >>= 1) { 36 | fixnum t; 37 | // TODO: Make multiplications faster by using the "middle 38 | // product" (see MCA 1.4.5 and 3.3.2). 39 | fixnum::mul_lo(t, b, x); 40 | fixnum::sub(t, fixnum::one(), t); 41 | fixnum::mul_lo(t, t, x); 42 | fixnum::add(x, x, t); 43 | } 44 | } 45 | }; 46 | 47 | } // End namespace cuFIXNUM 48 | -------------------------------------------------------------------------------- /src/functions/multi_modexp.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/internal/modexp_impl.cu" 4 | #include "modnum/modnum_monty_cios.cu" 5 | 6 | namespace cuFIXNUM { 7 | 8 | template< 9 | typename modnum_tp, 10 | int WINDOW_SIZE = internal::bytes_to_k_ary_window_size(modnum_tp::fixnum::BYTES) > 11 | class multi_modexp { 12 | static_assert(WINDOW_SIZE >= 1 && WINDOW_SIZE < modnum_tp::fixnum::digit::BITS, 13 | "Invalid window size."); 14 | 15 | // TODO: Generalise multi_modexp so that it can work with any modular 16 | // multiplication algorithm. 17 | const modnum_tp modnum; 18 | 19 | public: 20 | typedef typename modnum_tp::fixnum fixnum; 21 | 22 | __device__ multi_modexp(fixnum mod) 23 | : modnum(mod) { } 24 | 25 | __device__ void operator()(fixnum &z, fixnum x, fixnum e) const; 26 | }; 27 | 28 | 29 | /* 30 | * Left-to-right k-ary exponentiation (see [HAC, Algorithm 14.82]). 31 | * 32 | * If it is known that the exponents given to this function will be small, then 33 | * a better window size can be chosen. The window size should be the left value 34 | * in the pair below whose right value is the largest less than the exponent. 35 | * For example, exponents of 192 bits should take the window 4 corresponding to 36 | * 122. 37 | * 38 | * [[1, 1], [2, 7], [3, 35], [4, 122], [5, 369], [6, 1044], 39 | * [7, 2823], [8, 7371], [9, 18726], [10, 46490]] 40 | * 41 | * See the documentation in "functions/internal/modexp_impl.cu" for more 42 | * information. 43 | * 44 | * TODO: The basic algorithm is applied to each word of the exponent in turn, so 45 | * the last window used on each exponent word will be smaller than WINDOW_SIZE. 46 | * Need a better way to scan the exponent so that the same WINDOW_SIZE is used 47 | * throughout. 48 | * 49 | * TODO: Should only start the algorithm at the msb of e; it will result in many 50 | * idle threads, but the current setup means they do pointless work; at least if 51 | * they're idle they might make space for other work to be done. Document the 52 | * fact that inputs should be ordered such that groups with similar exponents 53 | * are together. 54 | * 55 | * NB: I don't immediately see how to use the "modified" variant [HAC, Algo 56 | * 14.83] since there the number of squarings depends on the 2-adic valuation of 57 | * the window value. 
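 *
 * Sketch of the main loop for WINDOW_SIZE = 4: G[t] = x^t for t = 0..15 (in
 * Montgomery form), and each 4-bit window w of the exponent, scanned from
 * the most significant end, costs four squarings followed by one
 * multiplication by G[w].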
58 | */ 59 | template< typename modnum_tp, int WINDOW_SIZE > 60 | __device__ void 61 | multi_modexp::operator()(fixnum &z, fixnum x, fixnum e) const 62 | { 63 | typedef typename modnum_tp::fixnum::digit digit; 64 | static constexpr int WIDTH = fixnum::SLOT_WIDTH; 65 | 66 | // Window decomposition: digit::BITS = q * WINDOW_SIZE + r. 67 | static constexpr int WINDOW_REM_BITS = digit::BITS % WINDOW_SIZE; 68 | static constexpr int WINDOW_MAX = (1U << WINDOW_SIZE); 69 | 70 | /* G[t] = z^t, t >= 0 */ 71 | fixnum G[WINDOW_MAX]; 72 | modnum.to_modnum(z, x); 73 | G[0] = modnum.one(); 74 | for (int t = 1; t < WINDOW_MAX; ++t) { 75 | G[t] = G[t - 1]; 76 | modnum.mul(G[t], G[t], z); 77 | } 78 | 79 | z = G[0]; 80 | for (int i = WIDTH - 1; i >= 0; --i) { 81 | digit f = fixnum::get(e, i); 82 | 83 | // TODO: The squarings are noops on the first iteration (i = 84 | // w-1) and should be removed. 85 | digit win; // TODO: Morally this should be an int 86 | for (int j = digit::BITS - WINDOW_SIZE; j >= 0; j -= WINDOW_SIZE) { 87 | // TODO: For some bizarre reason, it is significantly 88 | // faster to do this loop than it is to unroll the 5 89 | // statements manually. Idem for the remainder below. 90 | // Investigate how this is even possible! 91 | for (int k = 0; k < WINDOW_SIZE; ++k) 92 | modnum.sqr(z, z); 93 | digit fj; 94 | // win = (f >> j) & WINDOW_MAIN_MASK; 95 | digit::rshift(fj, f, j); 96 | digit::rem_2exp(win, fj, WINDOW_SIZE); 97 | modnum.mul(z, z, G[win]); 98 | } 99 | 100 | // Remainder 101 | for (int k = 0; k < WINDOW_REM_BITS; ++k) 102 | modnum.sqr(z, z); 103 | //win = f & WINDOW_REM_MASK; 104 | digit::rem_2exp(win, f, WINDOW_REM_BITS); 105 | modnum.mul(z, z, G[win]); 106 | } 107 | modnum.from_modnum(z, z); 108 | } 109 | 110 | } // End namespace cuFIXNUM 111 | -------------------------------------------------------------------------------- /src/functions/paillier_decrypt.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/quorem_preinv.cu" 4 | #include "functions/divexact.cu" 5 | #include "functions/chinese.cu" 6 | #include "functions/multi_modexp.cu" 7 | #include "modnum/modnum_monty_cios.cu" 8 | 9 | namespace cuFIXNUM { 10 | 11 | template< typename fixnum > 12 | class paillier_decrypt_mod; 13 | 14 | template< typename fixnum > 15 | class paillier_decrypt { 16 | public: 17 | __device__ paillier_decrypt(fixnum p, fixnum q) 18 | : n(prod(p, q)) 19 | , crt(p, q) 20 | , decrypt_modp(p, n) 21 | , decrypt_modq(q, n) { } 22 | 23 | __device__ void operator()(fixnum &ptxt, fixnum ctxt_hi, fixnum ctxt_lo) const; 24 | 25 | private: 26 | // We only need this in the constructor to initialise decrypt_mod[pq], but we 27 | // store it here because it's the only way to save the computation and pass 28 | // it to the constructors of decrypt_mod[pq]. 29 | fixnum n; 30 | 31 | // Secret key is (p, q). 32 | paillier_decrypt_mod decrypt_modp, decrypt_modq; 33 | 34 | // TODO: crt and decrypt_modq both compute and hold quorem_preinv(q); find a 35 | // way to share them. 36 | chinese crt; 37 | 38 | // TODO: It is flipping stupid that this is necessary. 
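    // prod() only exists so that n = p*q can be formed in the constructor's
    // initialiser list above: fixnum::mul_lo returns its result through an
    // out-parameter, so it cannot be called inline as an initialiser expression.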
39 | __device__ fixnum prod(fixnum p, fixnum q) { 40 | fixnum n; 41 | // TODO: These don't work when SLOT_WIDTH = 0 42 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || p == 0); 43 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || q == 0); 44 | fixnum::mul_lo(n, p, q); 45 | return n; 46 | } 47 | }; 48 | 49 | /** 50 | * Decrypt the ciphertext c = (c_hi, c_lo) and put the resulting plaintext in m. 51 | * 52 | * m, c_hi and c_lo must be PLAINTEXT_DIGITS long. 53 | */ 54 | template< typename fixnum > 55 | __device__ void 56 | paillier_decrypt::operator()(fixnum &ptxt, fixnum ctxt_hi, fixnum ctxt_lo) const 57 | { 58 | fixnum mp, mq; 59 | decrypt_modp(mp, ctxt_hi, ctxt_lo); 60 | decrypt_modq(mq, ctxt_hi, ctxt_lo); 61 | crt(ptxt, mp, mq); 62 | } 63 | 64 | 65 | template< typename fixnum > 66 | class paillier_decrypt_mod { 67 | public: 68 | __device__ paillier_decrypt_mod(fixnum p, fixnum n); 69 | 70 | __device__ void operator()(fixnum &mp, fixnum c_hi, fixnum c_lo) const; 71 | 72 | private: 73 | // FIXME: These all have width = WIDTH/2, so this is a waste of 74 | // space, and (worse) the operations below waste cycles. 75 | 76 | // Precomputation of 77 | // L((1 + n)^(p - 1) mod p^2)^-1 (mod p) 78 | // for CRT, where n = pq is the public key, and L(x) = (x-1)/p. 79 | fixnum h; 80 | 81 | // We only need this in the constructor to initialise mod_p2 and pow, but we 82 | // store it here because it's the only way to save the computation and pass 83 | // it to the constructors of mod_p2 and pow. 84 | fixnum p_sqr; 85 | 86 | // Exact division by p 87 | divexact div_p; 88 | // Remainder after division by p. 89 | quorem_preinv mod_p; 90 | // Remainder after division by p^2. 91 | quorem_preinv mod_p2; 92 | 93 | // Modexp for x |--> x^(p - 1) (mod p^2) 94 | typedef modnum_monty_cios modnum; 95 | modexp pow; 96 | 97 | // TODO: It is flipping stupid that these are necessary. 98 | __device__ fixnum square(fixnum p) { 99 | fixnum p2; 100 | // TODO: This doesn't work when SLOT_WIDTH = 0 101 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || p == 0); 102 | fixnum::sqr_lo(p2, p); 103 | return p2; 104 | } 105 | __device__ fixnum sub1(fixnum p) { 106 | fixnum pm1; 107 | fixnum::sub(pm1, p, fixnum::one()); 108 | return pm1; 109 | } 110 | }; 111 | 112 | 113 | template< typename fixnum > 114 | __device__ 115 | paillier_decrypt_mod::paillier_decrypt_mod(fixnum p, fixnum n) 116 | : p_sqr(square(p)) 117 | , div_p(p) 118 | , mod_p(p) 119 | , mod_p2(p_sqr) 120 | , pow(p_sqr, sub1(p)) 121 | { 122 | typedef typename fixnum::digit digit; 123 | digit cy; 124 | fixnum t = n; 125 | cy = fixnum::incr_cy(t); 126 | // n is the product of primes, and 2^(2^k) - 1 has (at least) k factors, 127 | // hence n is less than 2^FIXNUM_BITS - 1, hence incrementing n shouldn't 128 | // overflow. 129 | assert(digit::is_zero(cy)); 130 | // TODO: Check whether reducing t is necessary. 131 | mod_p2(t, fixnum::zero(), t); 132 | pow(t, t); 133 | fixnum::decr_br(t); 134 | div_p(t, t); 135 | 136 | // TODO: Make modinv use xgcd and use modinv instead. 137 | // Use a^(p-2) = 1 (mod p) 138 | fixnum pm2; 139 | fixnum::sub(pm2, p, fixnum::two()); 140 | multi_modexp minv(p); 141 | minv(h, t, pm2); 142 | } 143 | 144 | /* 145 | * Decrypt ciphertext (c_hi, c_lo) and put the result in mp. 146 | * 147 | * Decryption mod p of c is put in the (bottom half of) mp. 
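 *
 * The computation below is the standard CRT form of Paillier decryption:
 * m_p = L(c^(p-1) mod p^2) * h (mod p), where L(x) = (x - 1)/p and h is the
 * precomputed inverse described at the declaration of h above.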
148 | */ 149 | template< typename fixnum > 150 | __device__ void 151 | paillier_decrypt_mod::operator()(fixnum &mp, fixnum c_hi, fixnum c_lo) const 152 | { 153 | fixnum c, u, hi, lo; 154 | // mp = c_hi * 2^n + c_lo (mod p^2) which is nonzero because p != q 155 | mod_p2(c, c_hi, c_lo); 156 | 157 | pow(u, c); 158 | fixnum::decr_br(u); 159 | div_p(u, u); 160 | // Check that the high half of u is now zero. 161 | // TODO: This doesn't work when SLOT_WIDTH = 0 162 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || u == 0); 163 | 164 | // TODO: make use of the fact that u and h are half-width. 165 | fixnum::mul_wide(hi, lo, u, h); 166 | assert(fixnum::is_zero(hi)); 167 | mod_p(mp, hi, lo); 168 | } 169 | 170 | } // End namespace cuFIXNUM 171 | -------------------------------------------------------------------------------- /src/functions/paillier_encrypt.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/quorem_preinv.cu" 4 | #include "functions/multi_modexp.cu" 5 | #include "modnum/modnum_monty_cios.cu" 6 | 7 | namespace cuFIXNUM { 8 | 9 | template< typename fixnum > 10 | class paillier_encrypt { 11 | public: 12 | __device__ paillier_encrypt(fixnum n_) 13 | : n(n_), n_sqr(square(n_)), pow(n_sqr, n_), mod_n2(n_sqr) { } 14 | 15 | /* 16 | * NB: In reality, the values r^n should be calculated out-of-band or 17 | * stock-piled and piped into an encryption function. 18 | */ 19 | __device__ void operator()(fixnum &ctxt, fixnum m, fixnum r) const { 20 | // TODO: test this properly 21 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || m == 0); 22 | fixnum::mul_lo(m, m, n); 23 | fixnum::incr_cy(m); 24 | pow(r, r); 25 | fixnum c_hi, c_lo; 26 | fixnum::mul_wide(c_hi, c_lo, m, r); 27 | mod_n2(ctxt, c_hi, c_lo); 28 | } 29 | 30 | private: 31 | typedef modnum_monty_cios modnum; 32 | 33 | fixnum n; 34 | fixnum n_sqr; 35 | modexp pow; 36 | quorem_preinv mod_n2; 37 | 38 | // TODO: It is flipping stupid that this is necessary. 39 | __device__ fixnum square(fixnum n) { 40 | fixnum n2; 41 | // TODO: test this properly 42 | //assert(fixnum::slot_layout::laneIdx() < fixnum::SLOT_WIDTH/2 || n == 0); 43 | fixnum::sqr_lo(n2, n); 44 | return n2; 45 | } 46 | }; 47 | 48 | } // End namespace cuFIXNUM 49 | -------------------------------------------------------------------------------- /src/functions/quorem.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace cuFIXNUM { 4 | 5 | /* 6 | * Quotient and remainder via long-division. 7 | * 8 | * Source: MCA Algo 1.6, HAC Algo 14.20. 9 | * 10 | * TODO: Implement Svoboda divisor preconditioning (using 11 | * Newton-Raphson iteration to calculate floor(beta^(n+1)/div)) (see 12 | * MCA Algo 1.7). 13 | */ 14 | template< typename fixnum > 15 | class quorem { 16 | static constexpr int WIDTH = fixnum::SLOT_WIDTH; 17 | typedef typename fixnum::digit digit; 18 | 19 | public: 20 | __device__ void operator()( 21 | fixnum &q, fixnum &r, 22 | fixnum A, fixnum div) const; 23 | 24 | // TODO: These functions obviously belong somewhere else. The need 25 | // to be available to both quorem (here) and quorem_preinv. 
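    // normalise_divisor() shifts div left until its top bit is set and returns
    // the shift amount; the normalise_dividend() overloads apply the same left
    // shift to the dividend and return the bits shifted out of the top, which
    // the callers assert to be zero.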
26 | static __device__ int normalise_divisor(fixnum &div); 27 | static __device__ fixnum normalise_dividend(fixnum &u, int k); 28 | static __device__ fixnum normalise_dividend(fixnum &u_hi, fixnum &u_lo, int k); 29 | static __device__ void quorem_with_candidate_quotient( 30 | fixnum &quo, fixnum &rem, 31 | fixnum A_hi, fixnum A_lo, fixnum div, fixnum q); 32 | }; 33 | 34 | template< typename fixnum > 35 | __device__ int 36 | quorem::normalise_divisor(fixnum &div) { 37 | static constexpr int BITS = fixnum::BITS; 38 | int lz = BITS - (fixnum::msb(div) + 1); 39 | fixnum overflow; 40 | fixnum::lshift(div, overflow, div, lz); 41 | assert(fixnum::is_zero(overflow)); 42 | return lz; 43 | } 44 | 45 | // TODO: Ideally the algos would be written to incorporate the 46 | // normalisation factor, rather than "physically" normalising the 47 | // dividend. 48 | template< typename fixnum > 49 | __device__ fixnum 50 | quorem::normalise_dividend(fixnum &u, int k) { 51 | fixnum overflow; 52 | fixnum::lshift(u, overflow, u, k); 53 | return overflow; 54 | } 55 | 56 | // TODO: Ideally the algos would be written to incorporate the 57 | // normalisation factor, rather than "physically" normalising the 58 | // dividend. 59 | template< typename fixnum > 60 | __device__ fixnum 61 | quorem::normalise_dividend(fixnum &u_hi, fixnum &u_lo, int k) { 62 | fixnum hi_part, middle_part; 63 | fixnum::lshift(u_hi, hi_part, u_hi, k); 64 | fixnum::lshift(u_lo, middle_part, u_lo, k); 65 | digit cy; 66 | fixnum::add_cy(u_hi, cy, u_hi, middle_part); 67 | assert(digit::is_zero(cy)); 68 | return hi_part; 69 | } 70 | 71 | template< typename fixnum > 72 | __device__ void 73 | quorem::quorem_with_candidate_quotient( 74 | fixnum &quo, fixnum &rem, 75 | fixnum A_hi, fixnum A_lo, fixnum div, fixnum q) 76 | { 77 | fixnum hi, lo, r, t, msw; 78 | digit br; 79 | int L = fixnum::layout::laneIdx(); 80 | 81 | // (hi, lo) = q*d 82 | fixnum::mul_wide(hi, lo, q, div); 83 | 84 | // (msw, r) = A - q*d 85 | fixnum::sub_br(r, br, A_lo, lo); 86 | fixnum::sub_br(msw, t, A_hi, hi); 87 | assert(digit::is_zero(t)); // A_hi >= hi 88 | 89 | // TODO: Could skip these two lines if we could pass br to the last 90 | // sub_br above as a "borrow in". 91 | // Make br into a fixnum 92 | br = (L == 0) ? br : digit::zero(); // digit to fixnum 93 | fixnum::sub_br(msw, t, msw, br); 94 | assert(digit::is_zero(t)); // msw >= br 95 | assert((L == 0 && digit::cmp(msw, 4) < 0) 96 | || digit::is_zero(msw)); // msw < 4 (TODO: possibly should have msw < 3) 97 | // Broadcast 98 | msw = fixnum::layout::shfl(msw, 0); 99 | 100 | // NB: Could call incr_cy in the loops instead; as is, it will 101 | // incur an extra add_cy even when msw is 0 and r < d. 102 | digit q_inc = digit::zero(); 103 | while ( ! digit::is_zero(msw)) { 104 | fixnum::sub_br(r, br, r, div); 105 | digit::sub(msw, msw, br); 106 | digit::incr(q_inc); 107 | } 108 | fixnum::sub_br(t, br, r, div); 109 | while (digit::is_zero(br)) { 110 | r = t; 111 | digit::incr(q_inc); 112 | fixnum::sub_br(t, br, r, div); 113 | } 114 | // TODO: Replace loops above with something like the one below, 115 | // which will reduce warp divergence a bit. 116 | #if 0 117 | fixnum tmp, q_inc; 118 | while (1) { 119 | br = fixnum::sub_br(tmp, r, div); 120 | if (msw == 0 && br == 1) 121 | break; 122 | msr -= br; 123 | ++q_inc; 124 | r = tmp; 125 | } 126 | #endif 127 | 128 | q_inc = (L == 0) ? 
q_inc : digit::zero(); 129 | fixnum::add(q, q, q_inc); 130 | 131 | quo = q; 132 | rem = r; 133 | } 134 | 135 | #if 0 136 | template< typename fixnum > 137 | __device__ void 138 | quorem::operator()( 139 | fixnum &q_hi, fixnum &q_lo, fixnum &r, 140 | fixnum A_hi, fixnum A_lo, fixnum div) const 141 | { 142 | int k = normalise_divisor(div); 143 | fixnum t = normalise_dividend(A_hi, A_lo, k); 144 | assert(t == 0); // dividend too big. 145 | 146 | fixnum r_hi; 147 | (*this)(q_hi, r_hi, A_hi, div); 148 | 149 | // FIXME WRONG! r_hi is not a good enough candidate quotient! 150 | // Do div2by1 of (r_hi, A_lo) by div using that r_hi < div. 151 | // r_hi is now the candidate quotient 152 | fixnum qq = r_hi; 153 | if (fixnum::cmp(A_lo, div) > 0) 154 | fixnum::incr_cy(qq); 155 | 156 | quorem_with_candidate_quotient(q_lo, r, r_hi, A_lo, div, qq); 157 | 158 | digit lo_bits = fixnum::rshift(r, r, k); 159 | assert(lo_bits == 0); 160 | } 161 | #endif 162 | 163 | // TODO: Implement a specifically *parallel* algorithm for division, 164 | // such as those of Takahashi. 165 | template< typename fixnum > 166 | __device__ void 167 | quorem::operator()( 168 | fixnum &q, fixnum &r, fixnum A, fixnum div) const 169 | { 170 | int n = fixnum::most_sig_dig(div) + 1; 171 | assert(n >= 0); // division by zero. 172 | 173 | digit div_msw = fixnum::get(div, n - 1); 174 | 175 | // TODO: Factor out the normalisation code. 176 | int k = digit::clz(div_msw); // guaranteed to be >= 0, since div_msw != 0 177 | 178 | // div is normalised when its msw is >= 2^(WORD_BITS - 1), 179 | // i.e. when its highest bit is on, i.e. when the number of 180 | // leading zeros of msw is 0. 181 | if (k > 0) { 182 | fixnum h; 183 | // Normalise div by shifting it to the left. 184 | fixnum::lshift(div, h, div, k); 185 | assert(fixnum::is_zero(h)); 186 | fixnum::lshift(A, h, A, k); 187 | // FIXME: We should be able to handle this case. 188 | assert(fixnum::is_zero(h)); // FIXME: check if h == 0 using cmp() and zero() 189 | digit::lshift(div_msw, div_msw, k); 190 | } 191 | 192 | int m = fixnum::most_sig_dig(A) - n + 1; 193 | // FIXME: Just return div in this case 194 | assert(m >= 0); // dividend too small 195 | 196 | // TODO: Work out if we can just incorporate the normalisation factor k 197 | // into the subsequent algorithm, rather than actually modifying div and A. 198 | 199 | q = r = fixnum::zero(); 200 | 201 | // Set q_m 202 | digit qj; 203 | fixnum dj, tmp; 204 | // TODO: Urgh. 205 | typedef typename fixnum::layout layout; 206 | dj = layout::shfl_up0(div, m); 207 | digit br; 208 | fixnum::sub_br(tmp, br, A, dj); 209 | if (br) qj = fixnum::zero(); // dj > A 210 | else { qj = fixnum::one(); A = tmp; } 211 | 212 | fixnum::set(q, qj, m); 213 | 214 | digit dinv = internal::quorem_reciprocal(div_msw); 215 | for (int j = m - 1; j >= 0; --j) { 216 | digit a_hi, a_lo, hi, dummy; 217 | 218 | // (q_hi, q_lo) = floor((a_{n+j} B + a_{n+j-1}) / div_msw) 219 | // TODO: a_{n+j} is a_{n+j-1} from the previous iteration; hence I 220 | // should be able to get away with just one call to get() per 221 | // iteration. 222 | // TODO: Could normalise A on the fly here, one word at a time. 223 | a_hi = fixnum::get(A, n + j); 224 | a_lo = fixnum::get(A, n + j - 1); 225 | 226 | // TODO: uquorem_wide has a bad branch at the start which will 227 | // cause trouble when div_msw < a_hi is not universally true 228 | // across the warp. Need to investigate ways to alleviate that. 
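        // quorem_wide() produces a candidate quotient digit qj from the top
        // two remaining words of A and the (normalised) top word of div; the
        // correction loop below can only decrement qj, and does so at most
        // twice, which is what the assert on iters checks.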
229 | digit::quorem_wide(qj, dummy, a_hi, a_lo, div_msw, dinv); 230 | 231 | dj = layout::shfl_up0(div, j); 232 | hi = fixnum::mul_digit(tmp, qj, dj); 233 | assert(digit::is_zero(hi)); 234 | 235 | int iters = 0; 236 | fixnum AA; 237 | while (1) { 238 | fixnum::sub_br(AA, br, A, tmp); 239 | if (!br) 240 | break; 241 | fixnum::sub_br(tmp, br, tmp, dj); 242 | assert(digit::is_zero(br)); 243 | --qj; 244 | ++iters; 245 | } 246 | A = AA; 247 | assert(iters <= 2); // MCA, Proof of Theorem 1.3. 248 | fixnum::set(q, qj, j); 249 | } 250 | // Denormalise A to produce r. 251 | fixnum::rshift(r, tmp, A, k); 252 | assert(fixnum::is_zero(tmp)); // Above division should be exact. 253 | } 254 | 255 | } // End namespace cuFIXNUM 256 | -------------------------------------------------------------------------------- /src/functions/quorem_preinv.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/quorem.cu" 4 | 5 | namespace cuFIXNUM { 6 | 7 | /* 8 | * Quotient and remainder via Barrett reduction. 9 | * 10 | * div: the divisor 11 | * mu: floor(2^(2*NBITS) / div) where NBITS = FIXNUM_BITS (note: mu has an 12 | * implicit hi bit). 13 | */ 14 | template< typename fixnum > 15 | class quorem_preinv { 16 | public: 17 | __device__ quorem_preinv(fixnum div); 18 | 19 | // Assume clz(A) <= clz(div) 20 | __device__ void operator()(fixnum &q, fixnum &r, fixnum A_hi, fixnum A_lo) const; 21 | 22 | // Just return the remainder. 23 | __device__ void operator()(fixnum &r, fixnum A_hi, fixnum A_lo) const { 24 | fixnum q; 25 | (*this)(q, r, A_hi, A_lo); 26 | } 27 | 28 | // TODO: This should be somewhere more appropriate. 29 | __device__ static void reciprocal_approx(fixnum &mu, fixnum div); 30 | 31 | private: 32 | static constexpr int WIDTH = fixnum::SLOT_WIDTH; 33 | typedef typename fixnum::digit digit; 34 | 35 | // Note that mu has an implicit hi bit that is always on. 36 | fixnum div, mu; 37 | int lz; 38 | }; 39 | 40 | // Assumes div has been normalised already. 41 | // NB: Result has implicit hi bit on. 42 | // TODO: This function should be generalised and made available more widely 43 | template< typename fixnum > 44 | __device__ void 45 | quorem_preinv::reciprocal_approx(fixnum &mu, fixnum div) 46 | { 47 | // Let B = 2^FIXNUM_BITS 48 | 49 | // Initial estimate is 2*B - div = B + (B - div) (implicit hi bit) 50 | // TODO: Use better initial estimate: (48/17) - (32/17)*div (see 51 | // https://en.wikipedia.org/wiki/Division_algorithm#Newton-Raphson_division) 52 | fixnum::neg(mu, div); 53 | 54 | // If we initialise mu = 2*B - div, then the error is 1.0 - mu*div/B^2 < 1/4. 55 | // In general, the error after iteration k = 0, 1, ... less than 1/(4^(2^k)). 56 | // We need an error less than 1/B^2, hence k >= log2(log2(B)). 57 | static constexpr uint32_t BITS = fixnum::BITS; 58 | // FIXME: For some reason this code doesn't converge as fast as it should. 59 | const int NITERS = internal::ctz(BITS); // TODO: Make ctz, hence NITERS, a constexpr 60 | int L = fixnum::layout::laneIdx(); 61 | 62 | // TODO: Instead of calculating/using floor(B^2/div), calculate/use the 63 | // equivalent floor((B^2 - 1)/div) - B as described in the Möller & Granlund 64 | // paper; this should allow simplification because there's no implicit hi bit 65 | // in mu to account for. 66 | for (int i = 0; i < NITERS; ++i) { 67 | digit cy, br; 68 | fixnum a, b, c, d, e; 69 | 70 | // (hi, lo) = B^2 - mu*div. This is always positive. 
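        // Taken together, the steps below perform one Newton iteration
        // mu <- mu + mu*e/B^2 with e = B^2 - mu*div (equivalently
        // mu <- mu*(2 - mu*div/B^2)), written out to account for mu's implicit
        // high bit; each iteration is intended to roughly double the number of
        // correct bits in mu (cf. the FIXME above about convergence).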
71 | fixnum::mul_wide(a, b, mu, div); 72 | fixnum::add_cy(a, cy, a, div); // account for hi bit of mu 73 | // cy will be 1 when mu = floor(B^2/div), which happens on the last iteration 74 | assert(digit::is_zero(cy)); 75 | fixnum::sub_br(b, br, fixnum::zero(), b); // br == 0 iff b == 0. 76 | br = (L == 0) ? br : digit::zero(); 77 | fixnum::neg(a, a); 78 | fixnum::sub(a, a, br); 79 | 80 | // TODO: a + c is actually correct to within a single bit; investigate 81 | // whether using a mu that is off by one bit matters? If it does, we 82 | // should only do this correction on the last iteration. 83 | // TODO: Implement fused-multiply-add and use it here for "a*mu + b". 84 | fixnum::mul_wide(c, d, a, mu); 85 | fixnum::add_cy(d, cy, d, b); 86 | cy = (L == 0) ? cy : digit::zero(); 87 | fixnum::add_cy(c, cy, c, cy); 88 | assert(digit::is_zero(cy)); 89 | 90 | // cy is the single extra bit that propogates to (a + c) 91 | fixnum::mul_hi(e, mu, b); 92 | fixnum::add_cy(d, cy, d, e); 93 | cy = (L == 0) ? cy : digit::zero(); 94 | 95 | // mu += a + c + cy_in 96 | fixnum::add_cy(a, cy, a, cy); assert(digit::is_zero(cy)); 97 | fixnum::add_cy(mu, cy, mu, c); assert(digit::is_zero(cy)); 98 | fixnum::add_cy(mu, cy, mu, a); assert(digit::is_zero(cy)); 99 | } 100 | } 101 | 102 | 103 | /* 104 | * Create a quorem_preinv object. 105 | * 106 | * Raise an error if div does not have a sufficiently high bit switched 107 | * on. 108 | */ 109 | template< typename fixnum > 110 | __device__ 111 | quorem_preinv::quorem_preinv(fixnum div_) 112 | : div(div_) 113 | { 114 | lz = quorem::normalise_divisor(div); 115 | reciprocal_approx(mu, div); 116 | } 117 | 118 | /* 119 | * Return the quotient and remainder of A after division by div. 120 | * 121 | * Uses Barret reduction. See HAC, Algo 14.42, and MCA, Algo 2.5. 122 | */ 123 | template< typename fixnum > 124 | __device__ void 125 | quorem_preinv::operator()( 126 | fixnum &q, fixnum &r, fixnum A_hi, fixnum A_lo) const 127 | { 128 | fixnum t; 129 | int L = fixnum::layout::laneIdx(); 130 | 131 | // Normalise A 132 | // TODO: Rather than normalising A, we should incorporate the 133 | // normalisation factor into the algorithm at the appropriate 134 | // place. 135 | t = quorem::normalise_dividend(A_hi, A_lo, lz); 136 | assert(fixnum::is_zero(t)); 137 | 138 | // q = "A_hi * mu / 2^NBITS" 139 | // TODO: the lower half of the product, t, is unused, so we might 140 | // be able to use a mul_hi() function that only calculates an 141 | // approximate answer (see Short Product discussion at MCA, 142 | // Section 3.3 (from Section 2.4.1, p59)). 143 | fixnum::mul_wide(q, t, A_hi, mu); 144 | // TODO: For some reason (void)cy; does stop the compiler complaining about 145 | // cy being assigned but not used. Find a better way to avoid the warning 146 | // than this preprocessor crap. 
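    // Both branches below add A_hi to q: the stored mu omits its implicit
    // high bit (the true reciprocal is B + mu), so the candidate quotient is
    // hi(A_hi * mu) + A_hi = floor(A_hi * (B + mu) / B).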
147 | #ifndef NDEBUG 148 | digit cy; 149 | fixnum::add_cy(q, cy, q, A_hi); // mu has implicit hi bit 150 | assert(digit::is_zero(cy)); 151 | #else 152 | fixnum::add(q, q, A_hi); // mu has implicit hi bit 153 | #endif 154 | 155 | quorem::quorem_with_candidate_quotient(q, r, A_hi, A_lo, div, q); 156 | 157 | // Denormalise r 158 | fixnum lo_bits; 159 | fixnum::rshift(r, lo_bits, r, lz); 160 | assert(fixnum::is_zero(lo_bits)); 161 | } 162 | 163 | } // End namespace cuFIXNUM 164 | -------------------------------------------------------------------------------- /src/modnum/internal/monty.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/quorem_preinv.cu" 4 | 5 | namespace cuFIXNUM { 6 | 7 | namespace internal { 8 | 9 | template< typename fixnum_ > 10 | class monty { 11 | public: 12 | typedef fixnum_ fixnum; 13 | typedef fixnum modnum; 14 | 15 | __device__ monty(fixnum modulus); 16 | 17 | __device__ void add(modnum &z, modnum x, modnum y) const { 18 | fixnum::add(z, x, y); 19 | if (fixnum::cmp(z, mod) >= 0) 20 | fixnum::sub(z, z, mod); 21 | } 22 | 23 | __device__ void neg(modnum &z, modnum x) const { 24 | fixnum::sub(z, mod, x); 25 | } 26 | 27 | __device__ void sub(modnum &z, modnum x, modnum y) const { 28 | fixnum my; 29 | neg(my, y); 30 | fixnum::add(z, x, my); 31 | if (fixnum::cmp(z, mod) >= 0) 32 | fixnum::sub(z, z, mod); 33 | } 34 | 35 | /* 36 | * Return the Montgomery image of one. 37 | */ 38 | __device__ modnum one() const { 39 | return R_mod; 40 | } 41 | 42 | /* 43 | * Return the Montgomery image of one. 44 | */ 45 | __device__ modnum zero() const { 46 | return fixnum::zero(); 47 | } 48 | 49 | // FIXME: Get rid of this hack 50 | int is_valid; 51 | 52 | // Modulus for Monty arithmetic 53 | fixnum mod; 54 | // R_mod = 2^fixnum::BITS % mod 55 | modnum R_mod; 56 | // Rsqr = R^2 % mod 57 | modnum Rsqr_mod; 58 | 59 | // TODO: We save this after using it in the constructor; work out 60 | // how to make it available for later use. For example, it could 61 | // be used to reduce arguments to modexp prior to the main 62 | // iteration. 63 | quorem_preinv modrem; 64 | 65 | __device__ void normalise(modnum &x, int msb) const; 66 | }; 67 | 68 | 69 | template< typename fixnum > 70 | __device__ 71 | monty::monty(fixnum modulus) 72 | : mod(modulus), modrem(modulus) 73 | { 74 | // mod must be odd > 1 in order to calculate R^-1 mod "mod". 75 | // FIXME: Handle these errors properly 76 | if (fixnum::two_valuation(modulus) != 0 //fixnum::get(modulus, 0) & 1 == 0 77 | || fixnum::cmp(modulus, fixnum::one()) == 0) { 78 | is_valid = 0; 79 | return; 80 | } 81 | is_valid = 1; 82 | 83 | fixnum Rsqr_hi, Rsqr_lo; 84 | 85 | // R_mod = R % mod 86 | modrem(R_mod, fixnum::one(), fixnum::zero()); 87 | fixnum::sqr_wide(Rsqr_hi, Rsqr_lo, R_mod); 88 | // Rsqr_mod = R^2 % mod 89 | modrem(Rsqr_mod, Rsqr_hi, Rsqr_lo); 90 | } 91 | 92 | /* 93 | * Let X = x + msb * 2^64. Then return X -= m if X > m. 94 | * 95 | * Assumes X < 2*m, i.e. msb = 0 or 1, and if msb = 1, then x < m. 96 | */ 97 | template< typename fixnum > 98 | __device__ void 99 | monty::normalise(modnum &x, int msb) const { 100 | typedef typename fixnum::digit digit; 101 | modnum r; 102 | digit br; 103 | 104 | // br = 0 ==> x >= mod 105 | fixnum::sub_br(r, br, x, mod); 106 | if (msb || digit::is_zero(br)) { 107 | // If the msb was set, then we must have had to borrow. 
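        // (When msb == 1 the precondition gives x < mod, so the subtraction
        // above certainly borrowed; hence msb == 1 implies br == 1.)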
108 | assert(!msb || msb == br); 109 | x = r; 110 | } 111 | } 112 | 113 | } // End namespace internal 114 | 115 | } // End namespace cuFIXNUM 116 | -------------------------------------------------------------------------------- /src/modnum/modnum_monty_cios.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "functions/modinv.cu" 4 | #include "modnum/internal/monty.cu" 5 | 6 | namespace cuFIXNUM { 7 | 8 | template< typename fixnum_ > 9 | class modnum_monty_cios { 10 | public: 11 | typedef fixnum_ fixnum; 12 | typedef fixnum modnum; 13 | 14 | __device__ modnum_monty_cios(fixnum modulus); 15 | 16 | __device__ modnum zero() const { return monty.zero(); } 17 | __device__ modnum one() const { return monty.one(); } 18 | __device__ void add(modnum &z, modnum x, modnum y) const { monty.add(z, x, y); } 19 | __device__ void sub(modnum &z, modnum x, modnum y) const { monty.sub(z, x, y); } 20 | __device__ void neg(modnum &z, modnum x, modnum y) const { monty.neg(z, x); } 21 | 22 | /** 23 | * z <- x * y 24 | */ 25 | __device__ void mul(modnum &z, modnum x, modnum y) const; 26 | 27 | /** 28 | * z <- x^2 29 | */ 30 | __device__ void sqr(modnum &z, modnum x) const { 31 | mul(z, x, x); 32 | } 33 | 34 | // TODO: Might be worth specialising multiplication for this case, since one of 35 | // the operands is known. 36 | __device__ void to_modnum(modnum &z, fixnum x) const { 37 | mul(z, x, monty.Rsqr_mod); 38 | } 39 | 40 | // TODO: Might be worth specialising multiplication for this case, since one of 41 | // the operands is known. 42 | __device__ void from_modnum(fixnum &z, modnum x) const { 43 | mul(z, x, fixnum::one()); 44 | } 45 | 46 | private: 47 | typedef typename fixnum::digit digit; 48 | // TODO: Check whether we can get rid of this declaration 49 | static constexpr int WIDTH = fixnum::SLOT_WIDTH; 50 | 51 | internal::monty monty; 52 | 53 | // inv_mod * mod = -1 % 2^digit::BITS. 54 | digit inv_mod; 55 | }; 56 | 57 | 58 | template< typename fixnum > 59 | __device__ 60 | modnum_monty_cios::modnum_monty_cios(fixnum mod) 61 | : monty(mod) 62 | { 63 | if ( ! monty.is_valid) 64 | return; 65 | 66 | // TODO: Tidy this up. 67 | modinv minv; 68 | fixnum im; 69 | minv(im, mod, digit::BITS); 70 | digit::neg(inv_mod, im); 71 | // TODO: Ugh. 72 | typedef typename fixnum::layout layout; 73 | // TODO: Can we avoid this broadcast? 74 | inv_mod = layout::shfl(inv_mod, 0); 75 | assert(1 + inv_mod * layout::shfl(mod, 0) == 0); 76 | } 77 | 78 | /* 79 | * z = x * y (mod) in Monty form. 80 | * 81 | * Spliced multiplication/reduction implementation of Montgomery 82 | * modular multiplication. Specifically it is the CIOS (coursely 83 | * integrated operand scanning) splice. 84 | */ 85 | template< typename fixnum > 86 | __device__ void 87 | modnum_monty_cios::mul(modnum &z, modnum x, modnum y) const 88 | { 89 | typedef typename fixnum::layout layout; 90 | // FIXME: Fix this hack! 
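    // (The hack: when the modulus was rejected in the constructor we simply
    // return zero.) Note that operands and result are in Montgomery form:
    // this routine computes z = x*y/R (mod m), so for x = a*R and y = b*R it
    // yields z = a*b*R (mod m), which is why to_modnum() multiplies by R^2
    // and from_modnum() multiplies by 1.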
91 | z = zero(); 92 | if (!monty.is_valid) { return; } 93 | 94 | int L = layout::laneIdx(); 95 | digit tmp; 96 | digit::mul_lo(tmp, x, inv_mod); 97 | digit::mul_lo(tmp, tmp, fixnum::get(y, 0)); 98 | digit cy = digit::zero(); 99 | 100 | for (int i = 0; i < WIDTH; ++i) { 101 | digit u; 102 | digit xi = fixnum::get(x, i); 103 | digit z0 = fixnum::get(z, 0); 104 | digit tmpi = fixnum::get(tmp, i); 105 | 106 | digit::mad_lo(u, z0, inv_mod, tmpi); 107 | 108 | digit::mad_lo_cy(z, cy, monty.mod, u, z); 109 | digit::mad_lo_cy(z, cy, y, xi, z); 110 | 111 | assert(L || digit::is_zero(z)); // z[0] must be 0 112 | z = layout::shfl_down0(z, 1); // Shift right one word 113 | 114 | digit::add_cy(z, cy, z, cy); 115 | 116 | digit::mad_hi_cy(z, cy, monty.mod, u, z); 117 | digit::mad_hi_cy(z, cy, y, xi, z); 118 | } 119 | // Resolve carries 120 | digit msw = fixnum::top_digit(cy); 121 | cy = layout::shfl_up0(cy, 1); // left shift by 1 122 | fixnum::add_cy(z, cy, z, cy); 123 | digit::add(msw, msw, cy); 124 | assert(msw == !!msw); // msw = 0 or 1. 125 | 126 | monty.normalise(z, (int)msw); 127 | } 128 | 129 | } // End namespace cuFIXNUM 130 | -------------------------------------------------------------------------------- /src/modnum/modnum_monty_redc.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "modnum/internal/monty.cu" 4 | 5 | namespace cuFIXNUM { 6 | 7 | template< typename fixnum_ > 8 | class modnum_monty_redc { 9 | public: 10 | typedef fixnum_ fixnum; 11 | typedef fixnum modnum; 12 | 13 | __device__ modnum_monty_redc(fixnum mod) 14 | : monty(mod) { 15 | if ( ! monty.is_valid) return; 16 | 17 | modinv minv; 18 | minv(inv_mod, mod, fixnum::BITS); 19 | fixnum::neg(inv_mod, inv_mod); 20 | #ifndef NDEBUG 21 | fixnum tmp; 22 | fixnum::mul_lo(tmp, inv_mod, mod); 23 | fixnum::add(tmp, tmp, fixnum::one()); 24 | assert(fixnum::is_zero(tmp)); 25 | #endif 26 | } 27 | 28 | __device__ modnum zero() const { return monty.zero(); } 29 | __device__ modnum one() const { return monty.one(); } 30 | __device__ void add(modnum &z, modnum x, modnum y) const { monty.add(z, x, y); } 31 | __device__ void sub(modnum &z, modnum x, modnum y) const { monty.sub(z, x, y); } 32 | __device__ void neg(modnum &z, modnum x, modnum y) const { monty.neg(z, x); } 33 | 34 | __device__ void sqr(modnum &z, modnum x) const { 35 | // FIXME: Fix this hack! 36 | z = zero(); 37 | if (!monty.is_valid) return; 38 | 39 | modnum a_hi, a_lo; 40 | fixnum::sqr_wide(a_hi, a_lo, x); 41 | redc(z, a_hi, a_lo); 42 | } 43 | 44 | __device__ void mul(modnum &z, modnum x, modnum y) const { 45 | // FIXME: Fix this hack! 46 | z = zero(); 47 | if (!monty.is_valid) return; 48 | 49 | modnum a_hi, a_lo; 50 | fixnum::mul_wide(a_hi, a_lo, x, y); 51 | redc(z, a_hi, a_lo); 52 | } 53 | 54 | // TODO: Might be worth specialising multiplication for this case, since one of 55 | // the operands is known. 56 | __device__ void to_modnum(modnum &z, fixnum x) const { 57 | mul(z, x, monty.Rsqr_mod); 58 | } 59 | 60 | __device__ void from_modnum(fixnum &z, modnum x) const { 61 | //mul(z, x, fixnum::one()); 62 | redc(z, fixnum::zero(), x); 63 | } 64 | 65 | private: 66 | internal::monty monty; 67 | // inv_mod * mod = -1 % 2^fixnum::BITS. 
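    // (That is, inv_mod = -mod^-1 (mod R) with R = 2^fixnum::BITS, the usual
    // Montgomery constant; redc() below uses it to compute
    // (a_hi*R + a_lo) * R^-1 (mod mod) from a double-width value.)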
68 | fixnum inv_mod; 69 | 70 | __device__ void redc(fixnum &r, fixnum a_hi, fixnum a_lo) const; 71 | }; 72 | 73 | 74 | template< typename fixnum > 75 | __device__ void 76 | modnum_monty_redc::redc(fixnum &r, fixnum a_hi, fixnum a_lo) const { 77 | typedef typename fixnum::digit digit; 78 | fixnum b, s_hi, s_lo; 79 | digit cy, c; 80 | 81 | // FIXME: Fix this hack! 82 | r = zero(); 83 | if (!monty.is_valid) return; 84 | 85 | fixnum::mul_lo(b, a_lo, inv_mod); 86 | 87 | // This section is essentially s = floor(mad_wide(b, mod, a) / R) 88 | 89 | // TODO: Can we employ the trick to avoid a multiplication because we 90 | // know b = am' (mod R)? 91 | fixnum::mul_wide(s_hi, s_lo, b, monty.mod); 92 | // TODO: Only want the carry; find a cheaper way to determine that 93 | // without doing the full addition. 94 | fixnum::add_cy(s_lo, cy, s_lo, a_lo); 95 | 96 | // TODO: The fact that we need to turn cy into a fixnum before using it in 97 | // arithmetic should be handled more cleanly. Also, this code is already in 98 | // the private function digit_to_fixnum() in ''warp_fixnum.cu'. 99 | int L = fixnum::layout::laneIdx(); 100 | cy = (L == 0) ? cy : digit::zero(); 101 | 102 | // TODO: The assert below fails; work out why. 103 | #if 0 104 | // NB: b = am' (mod R) => a + bm = a + amm' = 2a (mod R). So surely 105 | // all I need to propagate is the top bit of a_lo? 106 | fixnum top_bit, dummy; 107 | fixnum::lshift(dummy, top_bit, a_lo, 1); 108 | assert(digit::cmp(cy, top_bit) == 0); 109 | #endif 110 | fixnum::add_cy(r, cy, s_hi, cy); 111 | fixnum::add_cy(r, c, r, a_hi); 112 | digit::add(cy, cy, c); 113 | assert(cy == !!cy); // cy = 0 or 1. 114 | 115 | monty.normalise(r, cy); 116 | } 117 | 118 | } // End namespace cuFIXNUM 119 | -------------------------------------------------------------------------------- /src/util/cuda_wrap.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace cuFIXNUM { 7 | 8 | /* 9 | * Convenience wrappers around some CUDA library functions 10 | */ 11 | static inline void 12 | cuda_print_errmsg(cudaError err, const char *msg, const char *file, const int line) 13 | { 14 | if (err != cudaSuccess) { 15 | fprintf(stderr, "Fatal CUDA error at %s:%d : %s : %s\n", 16 | file, line, msg, cudaGetErrorString(err)); 17 | if (cudaDeviceReset() != cudaSuccess) 18 | fprintf(stderr, " ...and failed to reset the device!\n"); 19 | exit(EXIT_FAILURE); 20 | } 21 | } 22 | 23 | } // End namespace cuFIXNUM 24 | 25 | #define cuda_check(err, msg) \ 26 | ::cuFIXNUM::cuda_print_errmsg(err, msg, __FILE__, __LINE__) 27 | 28 | #define cuda_malloc(ptr, size) \ 29 | cuda_check(cudaMalloc(ptr, size), "memory allocation") 30 | #define cuda_malloc_managed(ptr, size) \ 31 | cuda_check(cudaMallocManaged(ptr, size), \ 32 | "unified memory allocation (default attach)") 33 | #define cuda_malloc_managed_host(ptr, size) \ 34 | cuda_check(cudaMallocManaged(ptr, size, cudaMemAttachHost), \ 35 | "unified memory allocation (host attach)") 36 | #define cuda_stream_attach_mem(stream, ptr) \ 37 | cuda_check(cudaStreamAttachMemAsync(stream, ptr), "attach unified memory to stream") 38 | #define cuda_free(ptr) \ 39 | cuda_check(cudaFree(ptr), "memory deallocation") 40 | #define cuda_memcpy_to_device(dest, src, size) \ 41 | cuda_check(cudaMemcpy(dest, src, size, cudaMemcpyHostToDevice), "copy to device") 42 | #define cuda_memcpy_from_device(dest, src, size) \ 43 | cuda_check(cudaMemcpy(dest, src, size, cudaMemcpyDeviceToHost), "copy from 
device") 44 | #define cuda_memcpy_on_device(dest, src, size) \ 45 | cuda_check(cudaMemcpy(dest, src, size, cudaMemcpyDeviceToDevice), "copy on device") 46 | #define cuda_memset(dest, val, size) \ 47 | cuda_check(cudaMemset(dest, val, size), "memset on device") 48 | #define cuda_device_synchronize() \ 49 | cuda_check(cudaDeviceSynchronize(), "device synchronize") 50 | 51 | -------------------------------------------------------------------------------- /tests/gentests.py: -------------------------------------------------------------------------------- 1 | from itertools import chain, product 2 | from collections import deque 3 | from timeit import default_timer as timer 4 | from gmpy2 import is_prime 5 | 6 | def write_int(dest, sz, n): 7 | dest.write(n.to_bytes(sz, byteorder = 'little')) 8 | 9 | def write_vector(dest, elt_sz, v): 10 | for n in v: 11 | write_int(dest, elt_sz, n) 12 | 13 | def mktests(op, xs, nargs, bits): 14 | # FIXME: Refactor this. 15 | if nargs == 1: 16 | yield zip(*[op(x, bits) for x in xs]) 17 | elif nargs == 2: 18 | ys = deque(xs) 19 | for i in range(len(xs)): 20 | yield zip(*[op(x, y, bits) for x, y in zip(xs, ys)]) 21 | ys.rotate(1) 22 | elif nargs == 3: 23 | ys = deque(xs) 24 | zs = deque(xs) 25 | for _ in range(len(xs)): 26 | for _ in range(len(xs)): 27 | yield list(zip(*[op(x, y, z, bits) for x, y, z in zip(xs, ys, zs)])) 28 | zs.rotate(1) 29 | ys.rotate(1) 30 | elif nargs == 4: 31 | ys = deque(xs) 32 | zs = deque(xs) 33 | ws = deque(xs) 34 | for _ in range(len(xs)): 35 | for _ in range(len(xs)): 36 | for _ in range(len(xs)): 37 | yield list(zip(*[op(x, y, z, w, bits) for x, y, z, w in zip(xs, ys, zs, ws)])) 38 | ws.rotate(1) 39 | zs.rotate(1) 40 | ys.rotate(1) 41 | else: 42 | raise NotImplementedError() 43 | 44 | def write_tests(fname, arg): 45 | op, xs, nargs, nres, bits = arg 46 | vec_len = len(xs) 47 | ntests = vec_len**nargs 48 | t = timer() 49 | print('Writing {} tests into "{}"... '.format(ntests, fname), end='', flush=True) 50 | with open(fname, 'wb') as f: 51 | fixnum_bytes = bits >> 3 52 | write_int(f, 4, fixnum_bytes) 53 | write_int(f, 4, vec_len) 54 | write_int(f, 4, nres) 55 | write_vector(f, fixnum_bytes, xs) 56 | for v in mktests(op, xs, nargs, bits): 57 | v = list(v) 58 | assert len(v) == nres, 'bad result length; expected {}, got {}'.format(nres, len(v)) 59 | for res in v: 60 | write_vector(f, fixnum_bytes, res) 61 | t = timer() - t 62 | print('done ({:.2f}s).'.format(t)) 63 | return fname 64 | 65 | def add_cy(x, y, bits): 66 | return [(x + y) & ((1<> bits] 67 | 68 | def sub_br(x, y, bits): 69 | return [(x - y) & ((1<> bits] 73 | 74 | def sqr_wide(x, bits): 75 | return [(x * x) & ((1<> bits] 76 | 77 | def modexp(x, y, z, bits): 78 | # FIXME: Handle these cases properly! 79 | if z % 2 == 0: 80 | return [0] 81 | return [pow(x, y, z)] 82 | 83 | def paillier_encrypt(p, q, r, m, bits): 84 | n = p * q 85 | n2 = n * n 86 | return [((1 + m * n) * pow(r, n, n2)) % n2] 87 | 88 | def test_inputs(nbytes): 89 | q = nbytes // 4 90 | res = [0] 91 | 92 | nums = [1, 2, 3]; 93 | nums.extend([2**32 - n for n in nums]) 94 | 95 | for i in range(q): 96 | res.extend(n << 32*i for n in nums) 97 | 98 | lognbits = (32*q).bit_length() 99 | for i in range(2, lognbits - 1): 100 | # b = 0xF, 0xFF, 0xFFFF, 0xFFFFFFFF, ... 101 | e = 1 << i 102 | b = (1 << e) - 1 103 | c = sum(b << 2*e*j for j in range(32*q // (2*e))) 104 | res.extend([c, (1 << 32*q) - c - 1]) 105 | return res 106 | 107 | def prev_prime(n): 108 | # subtract 1 or 2 from n, depending on whether n is even or odd. 
109 | n -= 1 + n%2 110 | while not (is_prime(n) or n < 3): 111 | n -= 2 112 | assert n >= 3, 'Failed to find a prime' 113 | return n 114 | 115 | def test_primes(nbytes): 116 | ps = [3, 5, 7] 117 | for i in range(nbytes, nbytes + 1): # i in range(1, ...): 118 | p = prev_prime(1 << (8*i)) 119 | ps.append(p) 120 | p = prev_prime(p) 121 | ps.append(p) 122 | p = prev_prime(p) 123 | ps.append(p) 124 | return ps 125 | 126 | def generate_tests(nbytes, tests): 127 | assert nbytes >= 4 and (nbytes & (nbytes - 1)) == 0, "nbytes must be a binary power at least 4" 128 | print('Generating input arguments... ', end='', flush=True) 129 | bits = nbytes * 8 130 | 131 | t = timer() 132 | xs = test_inputs(nbytes) 133 | t = timer() - t 134 | print('done ({:.2f}s). Created {} arguments.'.format(t, len(xs))) 135 | 136 | # ps is only used by paillier_encrypt, which needs primes 1/4 the size of 137 | # the ciphertext. 138 | ps = test_primes(nbytes // 4) 139 | print('primes = ', ps) 140 | 141 | ops = { 142 | 'add_cy': (add_cy, xs, 2, 2, bits), 143 | 'sub_br': (sub_br, xs, 2, 2, bits), 144 | 'mul_wide': (mul_wide, xs, 2, 2, bits), 145 | 'sqr_wide': (sqr_wide, xs, 1, 2, bits), 146 | 'modexp': (modexp, xs, 3, 1, bits), 147 | 'paillier_encrypt' : (paillier_encrypt, ps, 4, 1, bits) 148 | } 149 | test_names = ops.keys() & tests if len(tests) > 0 else ops.keys() 150 | test_fns = { fn: ops[fn] for fn in test_names } 151 | fnames = map(lambda fn: fn + '_' + str(nbytes), test_fns.keys()) 152 | return list(map(write_tests, fnames, test_fns.values())) 153 | 154 | 155 | def print_usage(progname): 156 | print("""Please specify the functions for which you want to generate test cases: 157 | 158 | $ python3 {} ... 159 | 160 | where each is one of 'add_cy', 'sub_br', 'mul_wide', 'sqr_wide', 'modexp', 'paillier_encrypt'. 
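For example, passing 'mul_wide modexp' regenerates just the mul_wide_* and
modexp_* input files, one per fixnum size.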
161 | Specifying no functions will generate all of them (you will need to do this at least once).""".format(progname)) 162 | 163 | 164 | if __name__ == '__main__': 165 | import sys 166 | if len(sys.argv[1:]) > 0 and sys.argv[1] == '-h': 167 | print_usage(sys.argv[0]) 168 | else: 169 | for i in range(2, 9): 170 | generate_tests(1 << i, sys.argv[1:]) 171 | -------------------------------------------------------------------------------- /tests/test-suite.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "array/fixnum_array.h" 12 | #include "fixnum/word_fixnum.cu" 13 | #include "fixnum/warp_fixnum.cu" 14 | #include "modnum/modnum_monty_cios.cu" 15 | #include "modnum/modnum_monty_redc.cu" 16 | #include "functions/modexp.cu" 17 | #include "functions/paillier_encrypt.cu" 18 | #include "functions/paillier_decrypt.cu" 19 | 20 | using namespace std; 21 | using namespace cuFIXNUM; 22 | 23 | typedef vector byte_array; 24 | 25 | void die_if(bool p, const string &msg) { 26 | if (p) { 27 | cerr << "Error: " << msg << endl; 28 | abort(); 29 | } 30 | } 31 | 32 | int 33 | arrays_are_equal( 34 | const uint8_t *expected, size_t expected_len, 35 | const uint8_t *actual, size_t actual_len) 36 | { 37 | if (expected_len > actual_len) 38 | return actual_len; 39 | size_t i; 40 | for (i = 0; i < expected_len; ++i) { 41 | if (expected[i] != actual[i]) 42 | return i; 43 | } 44 | for (; i < actual_len; ++i) { 45 | if (actual[i] != 0) 46 | return i; 47 | } 48 | return -1; 49 | } 50 | 51 | 52 | template< typename fixnum_ > 53 | struct TypedPrimitives : public ::testing::Test { 54 | typedef fixnum_ fixnum; 55 | 56 | TypedPrimitives() {} 57 | }; 58 | 59 | typedef ::testing::Types< 60 | warp_fixnum<4, u32_fixnum>, 61 | warp_fixnum<8, u32_fixnum>, 62 | warp_fixnum<16, u32_fixnum>, 63 | warp_fixnum<32, u32_fixnum>, 64 | warp_fixnum<64, u32_fixnum>, 65 | warp_fixnum<128, u32_fixnum>, 66 | 67 | warp_fixnum<8, u64_fixnum>, 68 | warp_fixnum<16, u64_fixnum>, 69 | warp_fixnum<32, u64_fixnum>, 70 | warp_fixnum<64, u64_fixnum>, 71 | warp_fixnum<128, u64_fixnum>, 72 | warp_fixnum<256, u64_fixnum> 73 | > FixnumImplTypes; 74 | 75 | TYPED_TEST_CASE(TypedPrimitives, FixnumImplTypes); 76 | 77 | void read_into(ifstream &file, uint8_t *buf, size_t nbytes) { 78 | file.read(reinterpret_cast(buf), nbytes); 79 | die_if( ! file.good(), "Read error."); 80 | die_if(static_cast(file.gcount()) != nbytes, "Expected more data."); 81 | } 82 | 83 | uint32_t read_int(ifstream &file) { 84 | uint32_t res; 85 | file.read(reinterpret_cast(&res), sizeof(res)); 86 | return res; 87 | } 88 | 89 | template 90 | void read_tcases( 91 | vector &res, 92 | fixnum_array *&xs, 93 | const string &fname, 94 | int nargs) { 95 | static constexpr int fixnum_bytes = fixnum::BYTES; 96 | ifstream file(fname + "_" + std::to_string(fixnum_bytes)); 97 | die_if( ! file.good(), "Couldn't open file."); 98 | 99 | uint32_t fn_bytes, vec_len, noutvecs; 100 | fn_bytes = read_int(file); 101 | vec_len = read_int(file); 102 | noutvecs = read_int(file); 103 | 104 | stringstream ss; 105 | ss << "Inconsistent reporting of fixnum bytes. 
" 106 | << "Expected " << fixnum_bytes << " got " << fn_bytes << "."; 107 | die_if(fixnum_bytes != fn_bytes, ss.str()); 108 | 109 | size_t nbytes = fixnum_bytes * vec_len; 110 | uint8_t *buf = new uint8_t[nbytes]; 111 | 112 | read_into(file, buf, nbytes); 113 | xs = fixnum_array::create(buf, nbytes, fixnum_bytes); 114 | 115 | // ninvecs = number of input combinations 116 | uint32_t ninvecs = 1; 117 | for (int i = 1; i < nargs; ++i) 118 | ninvecs *= vec_len; 119 | res.reserve(noutvecs * ninvecs); 120 | for (uint32_t i = 0; i < ninvecs; ++i) { 121 | for (uint32_t j = 0; j < noutvecs; ++j) { 122 | read_into(file, buf, nbytes); 123 | res.emplace_back(buf, buf + nbytes); 124 | } 125 | } 126 | 127 | delete[] buf; 128 | } 129 | 130 | template< typename fixnum, typename tcase_iter > 131 | void check_result( 132 | tcase_iter &tcase, uint32_t vec_len, 133 | initializer_list *> args, 134 | int skip = 1, 135 | uint32_t nvecs = 1) 136 | { 137 | static constexpr int fixnum_bytes = fixnum::BYTES; 138 | size_t total_vec_len = vec_len * nvecs; 139 | size_t nbytes = fixnum_bytes * total_vec_len; 140 | // TODO: The fixnum_arrays are in managed memory; there isn't really any 141 | // point to copying them into buf. 142 | byte_array buf(nbytes); 143 | 144 | int arg_idx = 0; 145 | for (auto arg : args) { 146 | auto buf_iter = buf.begin(); 147 | for (uint32_t i = 0; i < nvecs; ++i) { 148 | std::copy(tcase->begin(), tcase->end(), buf_iter); 149 | buf_iter += fixnum_bytes*vec_len; 150 | tcase += skip; 151 | } 152 | int r = arrays_are_equal(buf.data(), nbytes, arg->get_ptr(), nbytes); 153 | EXPECT_TRUE(r < 0) << "failed near byte " << r << " in argument " << arg_idx; 154 | ++arg_idx; 155 | } 156 | } 157 | 158 | template< typename fixnum > 159 | struct add_cy { 160 | __device__ void operator()(fixnum &r, fixnum &cy, fixnum a, fixnum b) { 161 | typedef typename fixnum::digit digit; 162 | digit c; 163 | fixnum::add_cy(r, c, a, b); 164 | // TODO: This is like digit_to_fixnum 165 | cy = (fixnum::layout::laneIdx() == 0) ? c : digit::zero(); 166 | } 167 | }; 168 | 169 | TYPED_TEST(TypedPrimitives, add_cy) { 170 | typedef typename TestFixture::fixnum fixnum; 171 | typedef fixnum_array fixnum_array; 172 | 173 | fixnum_array *res, *cys, *xs; 174 | vector tcases; 175 | 176 | read_tcases(tcases, xs, "tests/add_cy", 2); 177 | int vec_len = xs->length(); 178 | res = fixnum_array::create(vec_len); 179 | cys = fixnum_array::create(vec_len); 180 | 181 | auto tcase = tcases.begin(); 182 | for (int i = 0; i < vec_len; ++i) { 183 | fixnum_array *ys = xs->rotate(i); 184 | fixnum_array::template map(res, cys, xs, ys); 185 | check_result(tcase, vec_len, {res, cys}); 186 | delete ys; 187 | } 188 | delete res; 189 | delete cys; 190 | delete xs; 191 | } 192 | 193 | 194 | template< typename fixnum > 195 | struct sub_br { 196 | __device__ void operator()(fixnum &r, fixnum &br, fixnum a, fixnum b) { 197 | typedef typename fixnum::digit digit; 198 | digit bb; 199 | fixnum::sub_br(r, bb, a, b); 200 | br = (fixnum::layout::laneIdx() == 0) ? 
bb : digit::zero(); 201 | } 202 | }; 203 | 204 | TYPED_TEST(TypedPrimitives, sub_br) { 205 | typedef typename TestFixture::fixnum fixnum; 206 | typedef fixnum_array fixnum_array; 207 | 208 | fixnum_array *res, *brs, *xs; 209 | vector tcases; 210 | 211 | read_tcases(tcases, xs, "tests/sub_br", 2); 212 | int vec_len = xs->length(); 213 | res = fixnum_array::create(vec_len); 214 | brs = fixnum_array::create(vec_len); 215 | 216 | auto tcase = tcases.begin(); 217 | for (int i = 0; i < vec_len; ++i) { 218 | fixnum_array *ys = xs->rotate(i); 219 | fixnum_array::template map(res, brs, xs, ys); 220 | check_result(tcase, vec_len, {res, brs}); 221 | delete ys; 222 | } 223 | delete res; 224 | delete brs; 225 | delete xs; 226 | } 227 | 228 | template< typename fixnum > 229 | struct mul_lo { 230 | __device__ void operator()(fixnum &r, fixnum a, fixnum b) { 231 | fixnum rr; 232 | fixnum::mul_lo(rr, a, b); 233 | r = rr; 234 | } 235 | }; 236 | 237 | TYPED_TEST(TypedPrimitives, mul_lo) { 238 | typedef typename TestFixture::fixnum fixnum; 239 | typedef fixnum_array fixnum_array; 240 | 241 | fixnum_array *res, *xs; 242 | vector tcases; 243 | 244 | read_tcases(tcases, xs, "tests/mul_wide", 2); 245 | int vec_len = xs->length(); 246 | res = fixnum_array::create(vec_len); 247 | 248 | auto tcase = tcases.begin(); 249 | for (int i = 0; i < vec_len; ++i) { 250 | fixnum_array *ys = xs->rotate(i); 251 | fixnum_array::template map(res, xs, ys); 252 | check_result(tcase, vec_len, {res}, 2); 253 | delete ys; 254 | } 255 | delete res; 256 | delete xs; 257 | } 258 | 259 | template< typename fixnum > 260 | struct mul_hi { 261 | __device__ void operator()(fixnum &r, fixnum a, fixnum b) { 262 | fixnum rr; 263 | fixnum::mul_hi(rr, a, b); 264 | r = rr; 265 | } 266 | }; 267 | 268 | TYPED_TEST(TypedPrimitives, mul_hi) { 269 | typedef typename TestFixture::fixnum fixnum; 270 | typedef fixnum_array fixnum_array; 271 | 272 | fixnum_array *res, *xs; 273 | vector tcases; 274 | 275 | read_tcases(tcases, xs, "tests/mul_wide", 2); 276 | int vec_len = xs->length(); 277 | res = fixnum_array::create(vec_len); 278 | 279 | auto tcase = tcases.begin() + 1; 280 | for (int i = 0; i < vec_len; ++i) { 281 | fixnum_array *ys = xs->rotate(i); 282 | fixnum_array::template map(res, xs, ys); 283 | check_result(tcase, vec_len, {res}, 2); 284 | delete ys; 285 | } 286 | delete res; 287 | delete xs; 288 | } 289 | 290 | template< typename fixnum > 291 | struct mul_wide { 292 | __device__ void operator()(fixnum &s, fixnum &r, fixnum a, fixnum b) { 293 | fixnum rr, ss; 294 | fixnum::mul_wide(ss, rr, a, b); 295 | s = ss; 296 | r = rr; 297 | } 298 | }; 299 | 300 | TYPED_TEST(TypedPrimitives, mul_wide) { 301 | typedef typename TestFixture::fixnum fixnum; 302 | typedef fixnum_array fixnum_array; 303 | 304 | fixnum_array *his, *los, *xs; 305 | vector tcases; 306 | 307 | read_tcases(tcases, xs, "tests/mul_wide", 2); 308 | int vec_len = xs->length(); 309 | his = fixnum_array::create(vec_len); 310 | los = fixnum_array::create(vec_len); 311 | 312 | auto tcase = tcases.begin(); 313 | for (int i = 0; i < vec_len; ++i) { 314 | fixnum_array *ys = xs->rotate(i); 315 | fixnum_array::template map(his, los, xs, ys); 316 | check_result(tcase, vec_len, {los, his}); 317 | delete ys; 318 | } 319 | delete his; 320 | delete los; 321 | delete xs; 322 | } 323 | 324 | template< typename fixnum > 325 | struct sqr_lo { 326 | __device__ void operator()(fixnum &r, fixnum a) { 327 | fixnum rr; 328 | fixnum::sqr_lo(rr, a); 329 | r = rr; 330 | } 331 | }; 332 | 333 | 
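// The sqr_wide_* (and mul_wide_*) test files store, for each input, the
// low-half result vector followed by the high-half vector. That is why
// sqr_lo/mul_lo walk the expected results with a stride of 2 starting at
// index 0, sqr_hi/mul_hi start at index 1, and the *_wide tests consume both
// vectors at once.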
TYPED_TEST(TypedPrimitives, sqr_lo) { 334 | typedef typename TestFixture::fixnum fixnum; 335 | typedef fixnum_array fixnum_array; 336 | 337 | fixnum_array *res, *xs; 338 | vector tcases; 339 | 340 | read_tcases(tcases, xs, "tests/sqr_wide", 1); 341 | int vec_len = xs->length(); 342 | res = fixnum_array::create(vec_len); 343 | 344 | fixnum_array::template map(res, xs); 345 | auto tcase = tcases.begin(); 346 | check_result(tcase, vec_len, {res}, 2); 347 | 348 | delete res; 349 | delete xs; 350 | } 351 | 352 | template< typename fixnum > 353 | struct sqr_hi { 354 | __device__ void operator()(fixnum &r, fixnum a) { 355 | fixnum rr; 356 | fixnum::sqr_hi(rr, a); 357 | r = rr; 358 | } 359 | }; 360 | 361 | TYPED_TEST(TypedPrimitives, sqr_hi) { 362 | typedef typename TestFixture::fixnum fixnum; 363 | typedef fixnum_array fixnum_array; 364 | 365 | fixnum_array *res, *xs; 366 | vector tcases; 367 | 368 | read_tcases(tcases, xs, "tests/sqr_wide", 1); 369 | int vec_len = xs->length(); 370 | res = fixnum_array::create(vec_len); 371 | 372 | fixnum_array::template map(res, xs); 373 | auto tcase = tcases.begin() + 1; 374 | check_result(tcase, vec_len, {res}, 2); 375 | 376 | delete res; 377 | delete xs; 378 | } 379 | 380 | template< typename fixnum > 381 | struct sqr_wide { 382 | __device__ void operator()(fixnum &s, fixnum &r, fixnum a) { 383 | fixnum rr, ss; 384 | fixnum::sqr_wide(ss, rr, a); 385 | s = ss; 386 | r = rr; 387 | } 388 | }; 389 | 390 | TYPED_TEST(TypedPrimitives, sqr_wide) { 391 | typedef typename TestFixture::fixnum fixnum; 392 | typedef fixnum_array fixnum_array; 393 | 394 | fixnum_array *his, *los, *xs; 395 | vector tcases; 396 | 397 | read_tcases(tcases, xs, "tests/sqr_wide", 1); 398 | int vec_len = xs->length(); 399 | his = fixnum_array::create(vec_len); 400 | los = fixnum_array::create(vec_len); 401 | 402 | fixnum_array::template map(his, los, xs); 403 | auto tcase = tcases.begin(); 404 | check_result(tcase, vec_len, {los, his}); 405 | 406 | delete his; 407 | delete los; 408 | delete xs; 409 | } 410 | 411 | template< typename modnum > 412 | struct my_modexp { 413 | typedef typename modnum::fixnum fixnum; 414 | 415 | __device__ void operator()(fixnum &z, fixnum x, fixnum e, fixnum m) { 416 | modexp me(m, e); 417 | fixnum zz; 418 | me(zz, x); 419 | z = zz; 420 | }; 421 | }; 422 | 423 | // TODO: Refactor the modexp tests; need to fix check_result(). 
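// Each modexp test below evaluates all vec_len^2 combinations per map call:
// xs repeats the input vector, zs concatenates all of its rotations, and ys
// is the i-th rotation repeated, so (assuming repeat() and rotations()
// concatenate copies/rotations in order) one call covers every
// (base, modulus) pair for the i-th exponent rotation, in the order
// gentests.py emits its expected results.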
424 | template< typename fixnum > 425 | using modexp_redc = my_modexp< modnum_monty_redc >; 426 | 427 | TYPED_TEST(TypedPrimitives, modexp_redc) { 428 | typedef typename TestFixture::fixnum fixnum; 429 | typedef fixnum_array fixnum_array; 430 | 431 | fixnum_array *res, *input, *xs, *zs; 432 | vector tcases; 433 | 434 | read_tcases(tcases, input, "tests/modexp", 3); 435 | int vec_len = input->length(); 436 | int vec_len_sqr = vec_len * vec_len; 437 | 438 | res = fixnum_array::create(vec_len_sqr); 439 | xs = input->repeat(vec_len); 440 | zs = input->rotations(vec_len); 441 | 442 | auto tcase = tcases.begin(); 443 | for (int i = 0; i < vec_len; ++i) { 444 | fixnum_array *tmp = input->rotate(i); 445 | fixnum_array *ys = tmp->repeat(vec_len); 446 | 447 | fixnum_array::template map(res, xs, ys, zs); 448 | check_result(tcase, vec_len, {res}, 1, vec_len); 449 | 450 | delete ys; 451 | delete tmp; 452 | } 453 | delete res; 454 | delete input; 455 | delete xs; 456 | delete zs; 457 | } 458 | 459 | template< typename fixnum > 460 | using modexp_cios = my_modexp< modnum_monty_cios >; 461 | 462 | TYPED_TEST(TypedPrimitives, modexp_cios) { 463 | typedef typename TestFixture::fixnum fixnum; 464 | typedef fixnum_array fixnum_array; 465 | 466 | fixnum_array *res, *input, *xs, *zs; 467 | vector tcases; 468 | 469 | read_tcases(tcases, input, "tests/modexp", 3); 470 | int vec_len = input->length(); 471 | int vec_len_sqr = vec_len * vec_len; 472 | 473 | res = fixnum_array::create(vec_len_sqr); 474 | xs = input->repeat(vec_len); 475 | zs = input->rotations(vec_len); 476 | 477 | auto tcase = tcases.begin(); 478 | for (int i = 0; i < vec_len; ++i) { 479 | fixnum_array *tmp = input->rotate(i); 480 | fixnum_array *ys = tmp->repeat(vec_len); 481 | 482 | fixnum_array::template map(res, xs, ys, zs); 483 | check_result(tcase, vec_len, {res}, 1, vec_len); 484 | 485 | delete ys; 486 | delete tmp; 487 | } 488 | delete res; 489 | delete input; 490 | delete xs; 491 | delete zs; 492 | } 493 | 494 | 495 | template< typename modnum > 496 | struct my_multi_modexp { 497 | typedef typename modnum::fixnum fixnum; 498 | 499 | __device__ void operator()(fixnum &z, fixnum x, fixnum e, fixnum m) { 500 | multi_modexp mme(m); 501 | fixnum zz; 502 | mme(zz, x, e); 503 | z = zz; 504 | }; 505 | }; 506 | 507 | template< typename fixnum > 508 | using multi_modexp_redc = my_multi_modexp< modnum_monty_redc >; 509 | 510 | TYPED_TEST(TypedPrimitives, multi_modexp_redc) { 511 | typedef typename TestFixture::fixnum fixnum; 512 | typedef fixnum_array fixnum_array; 513 | 514 | fixnum_array *res, *input, *xs, *zs; 515 | vector tcases; 516 | 517 | read_tcases(tcases, input, "tests/modexp", 3); 518 | int vec_len = input->length(); 519 | int vec_len_sqr = vec_len * vec_len; 520 | 521 | res = fixnum_array::create(vec_len_sqr); 522 | xs = input->repeat(vec_len); 523 | zs = input->rotations(vec_len); 524 | 525 | auto tcase = tcases.begin(); 526 | for (int i = 0; i < vec_len; ++i) { 527 | fixnum_array *tmp = input->rotate(i); 528 | fixnum_array *ys = tmp->repeat(vec_len); 529 | 530 | fixnum_array::template map(res, xs, ys, zs); 531 | check_result(tcase, vec_len, {res}, 1, vec_len); 532 | 533 | delete ys; 534 | delete tmp; 535 | } 536 | delete res; 537 | delete input; 538 | delete xs; 539 | delete zs; 540 | } 541 | 542 | template< typename fixnum > 543 | using multi_modexp_cios = my_multi_modexp< modnum_monty_cios >; 544 | 545 | TYPED_TEST(TypedPrimitives, multi_modexp_cios) { 546 | typedef typename TestFixture::fixnum fixnum; 547 | typedef fixnum_array 
fixnum_array; 548 | 549 | fixnum_array *res, *input, *xs, *zs; 550 | vector tcases; 551 | 552 | read_tcases(tcases, input, "tests/modexp", 3); 553 | int vec_len = input->length(); 554 | int vec_len_sqr = vec_len * vec_len; 555 | 556 | res = fixnum_array::create(vec_len_sqr); 557 | xs = input->repeat(vec_len); 558 | zs = input->rotations(vec_len); 559 | 560 | auto tcase = tcases.begin(); 561 | for (int i = 0; i < vec_len; ++i) { 562 | fixnum_array *tmp = input->rotate(i); 563 | fixnum_array *ys = tmp->repeat(vec_len); 564 | 565 | fixnum_array::template map(res, xs, ys, zs); 566 | check_result(tcase, vec_len, {res}, 1, vec_len); 567 | 568 | delete ys; 569 | delete tmp; 570 | } 571 | delete res; 572 | delete input; 573 | delete xs; 574 | delete zs; 575 | } 576 | 577 | template< typename fixnum > 578 | struct pencrypt { 579 | __device__ void operator()(fixnum &z, fixnum p, fixnum q, fixnum r, fixnum m) { 580 | fixnum n, zz; 581 | fixnum::mul_lo(n, p, q); 582 | paillier_encrypt enc(n); 583 | enc(zz, m, r); 584 | z = zz; 585 | }; 586 | }; 587 | 588 | template< typename fixnum > 589 | struct pdecrypt { 590 | __device__ void operator()(fixnum &z, fixnum ct, fixnum p, fixnum q, fixnum r, fixnum m) { 591 | if (fixnum::cmp(p, q) == 0 592 | || fixnum::cmp(r, p) == 0 593 | || fixnum::cmp(r, q) == 0) { 594 | z = fixnum::zero(); 595 | return; 596 | } 597 | paillier_decrypt dec(p, q); 598 | fixnum n, zz; 599 | dec(zz, fixnum::zero(), ct); 600 | fixnum::mul_lo(n, p, q); 601 | quorem_preinv qr(n); 602 | qr(m, fixnum::zero(), m); 603 | 604 | // z = (z != m) 605 | z = fixnum::digit( !! fixnum::cmp(zz, m)); 606 | }; 607 | }; 608 | 609 | TYPED_TEST(TypedPrimitives, paillier) { 610 | typedef typename TestFixture::fixnum fixnum; 611 | 612 | typedef fixnum ctxt; 613 | // TODO: BYTES/2 only works when BYTES > 4 614 | //typedef default_fixnum ptxt; 615 | typedef fixnum ptxt; 616 | 617 | typedef fixnum_array ctxt_array; 618 | typedef fixnum_array ptxt_array; 619 | 620 | ctxt_array *ct, *pt, *p; 621 | vector tcases; 622 | read_tcases(tcases, p, "tests/paillier_encrypt", 4); 623 | 624 | int vec_len = p->length(); 625 | ct = ctxt_array::create(vec_len); 626 | pt = ctxt_array::create(vec_len); 627 | 628 | // TODO: Parallelise these tests similar to modexp above. 629 | ctxt_array *zeros = ctxt_array::create(vec_len, 0); 630 | auto tcase = tcases.begin(); 631 | for (int i = 0; i < vec_len; ++i) { 632 | ctxt_array *q = p->rotate(i); 633 | for (int j = 0; j < vec_len; ++j) { 634 | ctxt_array *r = p->rotate(j); 635 | for (int k = 0; k < vec_len; ++k) { 636 | ctxt_array *m = p->rotate(k); 637 | 638 | ctxt_array::template map(ct, p, q, r, m); 639 | check_result(tcase, vec_len, {ct}); 640 | 641 | ptxt_array::template map(pt, ct, p, q, r, m); 642 | 643 | size_t nbytes = vec_len * ctxt::BYTES; 644 | const uint8_t *zptr = reinterpret_cast(zeros->get_ptr()); 645 | const uint8_t *ptptr = reinterpret_cast(pt->get_ptr()); 646 | EXPECT_TRUE(arrays_are_equal(zptr, nbytes, ptptr, nbytes)); 647 | 648 | delete m; 649 | } 650 | delete r; 651 | } 652 | delete q; 653 | } 654 | 655 | delete p; 656 | delete ct; 657 | delete zeros; 658 | } 659 | 660 | int main(int argc, char *argv[]) 661 | { 662 | int r; 663 | 664 | testing::InitGoogleTest(&argc, argv); 665 | r = RUN_ALL_TESTS(); 666 | return r; 667 | } 668 | --------------------------------------------------------------------------------