├── LICENSE.txt ├── README.md ├── gpump.cu └── gpump.cuh /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPUMP 2 | # Multiple-Precision Arithmetic and Barrett Reduction on GPU 3 | 4 | ## Overview 5 | This project implements multiple-precision arithmetic operations with a focus on the Barrett reduction algorithm optimized for GPUs using CUDA. The goal is to facilitate operations on very large integers that standard data types cannot handle, such as those required in cryptographic computations. 6 | 7 | ## Multiple-Precision Arithmetic Operations 8 | 9 | ### A. 
Comparison, Addition, and Subtraction 10 | 11 | #### Algorithm 1: Multiple-Precision Comparison 12 | ```plaintext 13 | INPUT: non-negative integers x and y, each with n+1 radix b digits. 14 | OUTPUT: 1 if x > y; 0 if x = y; -1 if x < y. 15 | 16 | 1. i <- n; 17 | 2. while (i >= 0 && x[i] == y[i]) 18 | 3. i <- i - 1; 19 | 4. end while 20 | 5. if (i < 0) // every digit matched 21 | 6. return 0; 22 | 7. else if (x[i] > y[i]) 23 | 8. return 1; 24 | 9. else 25 | 10. return -1; 26 | ``` 27 | #### Algorithm 2: Multiple-Precision Addition 28 | ```plaintext 29 | INPUT: non-negative integers x and y, each with n+1 radix b digits. 30 | OUTPUT: x + y = (z_{n+1} z_n ... z_1 z_0)_b. 31 | 32 | 1. c <- 0; // carry digit 33 | 2. for (i from 0 to n) do 34 | 3. z_i <- (x_i + y_i + c) mod b; 35 | 4. c <- (x_i + y_i + c) / b; 36 | 5. end for 37 | 6. z_{n+1} <- c; 38 | 7. return (z_{n+1} z_n ... z_1 z_0)_b; 39 | ``` 40 | #### Algorithm 3: Multiple-Precision Subtraction 41 | 42 | ```plaintext 43 | INPUT: non-negative integers x and y, each with n+1 radix b digits, x >= y. 44 | OUTPUT: x - y = (z_n ... z_1 z_0)_b. 45 | 46 | 1. c <- 0; // borrow digit 47 | 2. for (i from 0 to n) do 48 | 3. z_i <- (x_i - y_i - c) mod b; 49 | 4. if ((x_i - y_i - c) < 0) then c <- 1; 50 | 5. else c <- 0; 51 | 6. end if 52 | 7. end for 53 | 8. return (z_n ... z_1 z_0)_b; 54 | ``` 55 | 56 | ### B. Modular Addition and Subtraction 57 | 58 | #### Algorithm 4: Multiple-Precision Modular Addition 59 | ```plaintext 60 | INPUT: modulus m and non-negative integers x and y, each with n+1 radix b digits, x < m, y < m. 61 | OUTPUT: (x + y) mod m = (z_{n+1} z_n ... z_1 z_0)_b. 62 | 63 | 1. c <- 0; // carry digit 64 | 2. for (i from 0 to n) do 65 | 3. z_i <- (x_i + y_i + c) mod b; 66 | 4. if ((x_i + y_i + c) >= b) then c <- 1; 67 | 5. else c <- 0; 68 | 6. end if 69 | 7. end for 70 | 8. z_{n+1} <- c; // final carry; m_{n+1} is 0 71 | 9. if ((z_{n+1} z_n ... z_1 z_0)_b >= (m_{n+1} m_n ... m_1 m_0)_b) then 72 | 10. return (z_{n+1} z_n ... z_1 z_0)_b - (m_{n+1} m_n ... 
m_1 m_0)_b; 73 | 11. else return (z_{n+1} z_n ... z_1 z_0)_b; 74 | ``` 75 | #### Algorithm 5: Multiple-Precision Modular Subtraction 76 | 77 | ```plaintext 78 | INPUT: non-negative integers x and y, each with n+1 radix b digits, x < m, y < m. 79 | OUTPUT: (x - y) mod m = (z_{n+1} z_n ... z_1 z_0)_b. 80 | 81 | 1. if (x >= y) 82 | 2. return (x - y); 83 | 3. else 84 | 4. t <- (m - y); 85 | 5. return (x + t) mod m; 86 | 6. end if 87 | ``` 88 | #### Algorithm 6: Multiple-Precision Multiplication 89 | ```plaintext 90 | INPUT: non-negative integers x with n+1 radix b digits and y with s+1 radix b digits. 91 | OUTPUT: x * y = (z_{n+s+1} z_{n+s} ... z_1 z_0)_b. 92 | 93 | 1. for (i from 0 to n+s+1) do 94 | 2. z_i <- 0; 95 | 3. end for 96 | 4. for (i from 0 to s) do 97 | 5. c <- 0; // carry digit 98 | 6. for (j from 0 to n) do 99 | 7. (uv)_b <- z_{i+j} + x_j * y_i + c; 100 | 8. z_{i+j} <- v; c <- u; 101 | 9. end for 102 | 10. z_{n+i+1} <- u; 103 | 11. end for 104 | 12. return (z_{n+s+1} z_{n+s} ... z_1 z_0)_b; 105 | ``` 106 | 107 | ### Barrett Reduction Algorithm 108 | 109 | ```plaintext 110 | INPUT: 111 | - Positive modulus `p`. 112 | - Radix `b`, the base in which `z` and `p` are represented. 113 | - Integer `k` such that `k = ⌊log_b(p)⌋ + 1`. 114 | - Integer `z` such that `0 ≤ z < b^(2k)`. 115 | - Precomputed `µ` as `µ = ⌊b^(2k) / p⌋`. 116 | 117 | OUTPUT: `z mod p`. 118 | 119 | 1. Compute `q̄` as `⌊⌊z / b^(k-1)⌋ * µ / b^(k+1)⌋`. 120 | 2. Compute `r` as `(z mod b^(k+1)) - (q̄ * p mod b^(k+1))`. 121 | 3. If `r < 0` then `r <- r + b^(k+1)`. 122 | 4. While `r ≥ p` do `r <- r - p`. 123 | 5. Return `r`. 124 | ``` 125 | 126 | ## Implementation Details 127 | This implementation stores each radix-\(2^{26}\) digit in a 32-bit limb, allowing efficient use of the 32-bit integer operations available on GPUs. Each function handles carry and overflow conditions to ensure correctness across all limbs. 
/*
 * README.md (tail of the listing that precedes the sources)
 *
 * ## Acknowledgements
 * Special thanks to my friends who accompanied me through this project!❤️
 */

/* --------------------------------------------------------------------------
 * /gpump.cu
 * -------------------------------------------------------------------------- */
#include "gpump.cuh"

/*
 * Conventions used throughout this file
 *   - A number is a little-endian array of limbs: n[0] is least significant.
 *   - Each limb holds one radix-2^26 digit (GPUMP_LIMB_BITS) in a 32-bit word.
 *   - "Operand" routines work on the low GPUMP_LIMBS (25) limbs only; mpMul is
 *     the one routine that writes all GPUMP_MAX_LIMBS (50) limbs.
 *   - The arithmetic helpers are __host__ __device__ so they can be checked
 *     against a CPU reference; calculateS stays __device__ because it reads
 *     __constant__ data.
 */

/* Example modulus n, 10 limbs. NOTE(review): the limbs look like the secp256k1
 * group order, which would match the signing use in calculateS -- confirm. */
__constant__ fe_num_t n_fe = {
    {3555649, 9937716, 33799165, 60472610, 45788892,
     67108863, 67108863, 67108863, 67108863, 4194303}
};

/* Precomputed Barrett parameter for n_fe. Per the README this is
 * floor(b^(2k) / n) with b = 2^26; calculateS uses it with k = 10. */
__constant__ fe_num_t mu = {
    {29278365, 6081522, 4457178, 21159295, 22094617, 81, 0, 0, 0, 0, 16}
};

/*
 * Debug helper: prints the low GPUMP_LIMBS limbs of `num` as hexadecimal,
 * most significant byte first, with a space after every fourth byte, followed
 * by a newline.
 */
__host__ __device__ void printNumberAsHex(const fe_num_t* num) {
    uint8_t bytes[GPUMP_LIMBS * GPUMP_LIMB_BITS / 8 + 1] = { 0 };
    unsigned long long acc = 0;   /* bit accumulator, at most 7 + 26 live bits */
    int pending = 0;              /* number of live bits in acc */
    int count = 0;                /* bytes emitted so far */

    for (int i = 0; i < GPUMP_LIMBS; i++) {
        acc |= ((unsigned long long)num->n[i]) << pending;
        pending += GPUMP_LIMB_BITS;

        while (pending >= 8) {
            bytes[count++] = (uint8_t)(acc & 0xff);
            acc >>= 8;
            pending -= 8;
        }
    }

    /* 25 * 26 bits is not a whole number of bytes: flush the remainder. */
    if (pending > 0) {
        bytes[count] = (uint8_t)(acc & 0xff);
    }

    for (int i = count; i >= 0; i--) {
        printf("%02x", bytes[i]);
        if (i % 4 == 0 && i != 0) printf(" ");
    }
    printf("\n");
}

/* Converters */

/*
 * Loads a 32-byte big-endian integer `a` into `r`.
 *
 * All GPUMP_MAX_LIMBS limbs of `r` are cleared first so that later whole-struct
 * copies never move indeterminate data. Bits are transferred two at a time;
 * because both 8 and 26 are even, a 2-bit group never straddles a limb.
 */
__host__ __device__ void fe_num_set_b32(fe_num_t* r, const unsigned char* a) {
    for (int i = 0; i < GPUMP_MAX_LIMBS; i++) {
        r->n[i] = 0;
    }

    for (int i = 0; i < 32; i++) {
        for (int j = 0; j < 4; j++) {
            /* Highest bit index is 254, i.e. limb 9: always in range. */
            const int bit = 8 * i + 2 * j;
            const uint32_t pair = (uint32_t)((a[31 - i] >> (2 * j)) & 0x3);
            r->n[bit / GPUMP_LIMB_BITS] |= pair << (bit % GPUMP_LIMB_BITS);
        }
    }
}

/*
 * Stores the low 256 bits of `a` into `r` as 32 big-endian bytes.
 * Anything at or above bit 256 in `a` is ignored.
 */
__host__ __device__ void fe_num_get_b32(unsigned char* r, const fe_num_t* a) {
    for (int i = 0; i < 32; i++) {
        unsigned int byte = 0;
        for (int j = 0; j < 4; j++) {
            const int bit = 8 * i + 2 * j;
            const unsigned int pair =
                (a->n[bit / GPUMP_LIMB_BITS] >> (bit % GPUMP_LIMB_BITS)) & 0x3u;
            byte |= pair << (2 * j);
        }
        r[31 - i] = (unsigned char)byte;
    }
}

/*
 * Compares the low GPUMP_LIMBS limbs of x and y.
 * Returns 1 if x > y, -1 if x < y, 0 if they are equal.
 */
__host__ __device__ int mpCompare(const fe_num_t* x, const fe_num_t* y) {
    for (int i = GPUMP_LIMBS - 1; i >= 0; i--) {
        if (x->n[i] > y->n[i]) return 1;
        else if (x->n[i] < y->n[i]) return -1;
    }
    return 0;
}

/*
 * result = x + y over GPUMP_LIMBS limbs.
 * A carry out of the top limb is discarded, so the sum is taken modulo
 * b^GPUMP_LIMBS. `result` may alias `x` or `y`.
 */
__host__ __device__ void mpAdd(fe_num_t* result, const fe_num_t* x, const fe_num_t* y) {
    unsigned int carry = 0;
    for (int i = 0; i < GPUMP_LIMBS; i++) {
        const unsigned long long sum = (unsigned long long)x->n[i] + y->n[i] + carry;
        result->n[i] = (unsigned int)(sum & GPUMP_LIMB_MASK);   /* sum mod 2^26 */
        carry = (unsigned int)(sum >> GPUMP_LIMB_BITS);         /* sum / 2^26  */
    }
}

/*
 * num += b^(k+1), i.e. adds one to limb k+1 and ripples the carry upward.
 *
 * The carry test is done in radix 2^26: a limb that reaches 2^26 is reset to 0
 * and the next limb is incremented. (Testing the 32-bit word for zero, as an
 * earlier version did, never detects that overflow.) A carry out of the top
 * operand limb is discarded. Does nothing when k+1 is outside [0, GPUMP_LIMBS).
 */
__host__ __device__ void addBkPlusOne(fe_num_t* num, int k) {
    if (k + 1 < 0) {
        return;
    }
    for (int i = k + 1; i < GPUMP_LIMBS; i++) {
        const unsigned int bumped = num->n[i] + 1u;
        num->n[i] = bumped & GPUMP_LIMB_MASK;
        if ((bumped >> GPUMP_LIMB_BITS) == 0) {
            break;   /* no overflow out of this limb: carry chain ends */
        }
    }
}

/*
 * result = x - y over GPUMP_LIMBS limbs.
 * Intended for x >= y; otherwise the result wraps modulo b^GPUMP_LIMBS.
 * `result` may alias `x` or `y` (each limb is read before it is written).
 */
__host__ __device__ void mpSubtract(fe_num_t* result, const fe_num_t* x, const fe_num_t* y) {
    int borrow = 0;
    for (int i = 0; i < GPUMP_LIMBS; i++) {
        /* Signed 64-bit intermediate: no reliance on unsigned wrap-around
         * followed by an implementation-defined narrowing conversion. */
        long long diff = (long long)x->n[i] - (long long)y->n[i] - borrow;
        if (diff < 0) {
            diff += (1LL << GPUMP_LIMB_BITS);
            borrow = 1;
        }
        else {
            borrow = 0;
        }
        result->n[i] = (unsigned int)diff;
    }
}

/*
 * Subtraction used by Barrett reduction (README steps 2-3):
 * result = x - y, after adding b^(k+1) to x when x < y so the difference is
 * non-negative.
 */
__host__ __device__ void mpSubtractSafe(fe_num_t* result, const fe_num_t* x, const fe_num_t* y, int k) {
    fe_num_t adjusted_x;
    for (int i = 0; i < GPUMP_LIMBS; i++) {
        adjusted_x.n[i] = x->n[i];
    }
    if (mpCompare(x, y) < 0) {
        addBkPlusOne(&adjusted_x, k);
    }
    mpSubtract(result, &adjusted_x, y);
}

/*
 * result = (x + y) mod `mod`, assuming x < mod and y < mod: the sum is reduced
 * by at most one subtraction of `mod`. `result` may alias `x` or `y`.
 */
__host__ __device__ void mpModularAdd(fe_num_t* result, const fe_num_t* x, const fe_num_t* y, const fe_num_t* mod) {
    fe_num_t temp_sum;
    mpAdd(&temp_sum, x, y);
    if (mpCompare(&temp_sum, mod) >= 0) {
        mpSubtract(result, &temp_sum, mod);
    }
    else {
        /* Only the operand limbs are defined in temp_sum; copy just those. */
        for (int i = 0; i < GPUMP_LIMBS; i++) {
            result->n[i] = temp_sum.n[i];
        }
    }
}

/*
 * result = x * y (schoolbook multiplication).
 *
 * Reads the low GPUMP_LIMBS limbs of x and y and writes all GPUMP_MAX_LIMBS
 * limbs of result. `result` must not alias `x` or `y`: it is cleared before
 * the operands are read.
 */
__host__ __device__ void mpMul(fe_num_t* result, const fe_num_t* x, const fe_num_t* y) {
    for (int i = 0; i < GPUMP_MAX_LIMBS; i++) {
        result->n[i] = 0;
    }

    for (int i = 0; i < GPUMP_LIMBS; i++) {          /* each limb of y */
        unsigned long long carry = 0;
        for (int j = 0; j < GPUMP_LIMBS; j++) {      /* each limb of x */
            /* < 2^52 + 2^26 + 2^27: fits comfortably in 64 bits. */
            const unsigned long long uv =
                (unsigned long long)x->n[j] * (unsigned long long)y->n[i]
                + (unsigned long long)result->n[i + j] + carry;
            result->n[i + j] = (unsigned int)(uv & GPUMP_LIMB_MASK);
            carry = uv >> GPUMP_LIMB_BITS;
        }
        /* Slot i + GPUMP_LIMBS has not been touched by this or any earlier
         * row, so the row's final carry can simply be stored. */
        result->n[i + GPUMP_LIMBS] = (unsigned int)carry;
    }
}

/*
 * result = num >> shift_bits, over GPUMP_LIMBS limbs.
 *
 * Whole-limb and sub-limb shifts are both honoured (earlier versions computed
 * the sub-limb part and then ignored it). Only the low GPUMP_LIMBS limbs of
 * `num` are read; vacated high limbs are zero. Negative shift counts are
 * clamped to 0 rather than indexing before the array. `result` may alias
 * `num`: every write goes to an index no greater than the ones still to be
 * read.
 */
__host__ __device__ void rightShift(const fe_num_t* num, fe_num_t* result, int shift_bits) {
    if (shift_bits < 0) {
        shift_bits = 0;
    }
    const int limb_shift = shift_bits / GPUMP_LIMB_BITS;
    const int bit_shift = shift_bits % GPUMP_LIMB_BITS;

    for (int i = 0; i < GPUMP_LIMBS; i++) {
        const int src = i + limb_shift;
        unsigned int value = 0;
        if (src < GPUMP_LIMBS) {
            value = num->n[src] >> bit_shift;
            if (bit_shift != 0 && src + 1 < GPUMP_LIMBS) {
                /* Pull in the low bits of the next limb up. */
                value |= (num->n[src + 1] << (GPUMP_LIMB_BITS - bit_shift)) & GPUMP_LIMB_MASK;
            }
        }
        result->n[i] = value;
    }
}

/*
 * Barrett step 1: q_bar = floor( floor(z / b^(k-1)) * mu / b^(k+1) ).
 */
__host__ __device__ void calculateQBar(fe_num_t* q_bar, const fe_num_t* z, const fe_num_t* mu, int k) {
    /* q = z with its k-1 lowest limbs dropped. */
    fe_num_t q;
    rightShift(z, &q, (k - 1) * GPUMP_LIMB_BITS);

    /* q * mu */
    fe_num_t q_mul_mu;
    mpMul(&q_mul_mu, &q, mu);

    /* q_bar = (q * mu) with its k+1 lowest limbs dropped. */
    rightShift(&q_mul_mu, q_bar, (k + 1) * GPUMP_LIMB_BITS);
}

/*
 * result = num mod b^(k+1): keeps limbs 0..k of `num` and zeroes the remaining
 * operand limbs. `result` may alias `num`.
 */
__host__ __device__ void modBkPlus1(fe_num_t* result, const fe_num_t* num, int k) {
    for (int i = 0; i < GPUMP_LIMBS; i++) {
        result->n[i] = (i <= k) ? num->n[i] : 0u;
    }
}

/*
 * Barrett steps 2-3:
 * r = (z mod b^(k+1)) - ((q_bar * p) mod b^(k+1)), plus b^(k+1) if negative.
 */
__host__ __device__ void calculateR(fe_num_t* r, const fe_num_t* z, const fe_num_t* q_bar, const fe_num_t* p, int k) {
    fe_num_t z_mod_bk_plus_1;
    modBkPlus1(&z_mod_bk_plus_1, z, k);

    fe_num_t q_bar_mul_p;
    mpMul(&q_bar_mul_p, q_bar, p);
    fe_num_t q_bar_mul_p_mod_bk_plus_1;
    modBkPlus1(&q_bar_mul_p_mod_bk_plus_1, &q_bar_mul_p, k);

    mpSubtractSafe(r, &z_mod_bk_plus_1, &q_bar_mul_p_mod_bk_plus_1, k);
}

/*
 * r = z mod p by Barrett reduction (see the README for the algorithm).
 *
 *   z   value to reduce, 0 <= z < b^(2k)
 *   p   modulus with exactly k radix-2^26 digits; must be non-zero, otherwise
 *       the final loop never terminates
 *   mu  floor(b^(2k) / p), precomputed by the caller
 *   k   digit count of p; must be >= 1
 *
 * Limitation: the intermediate q * mu can have up to 2k + 2 limbs and
 * rightShift only reads GPUMP_LIMBS (25) of them, so k must not exceed 11.
 * Only limbs 0..GPUMP_LIMBS-1 of `r` are written.
 */
__host__ __device__ void barrettReduction(fe_num_t* r, const fe_num_t* z, const fe_num_t* p, const fe_num_t* mu, int k) {
    fe_num_t q_bar;
    calculateQBar(&q_bar, z, mu, k);

    calculateR(r, z, &q_bar, p, k);

    /* Final correction: subtract p until r < p. */
    while (mpCompare(r, p) >= 0) {
        mpSubtract(r, r, p);
    }
}

/*
 * Example use of GPUMP: computes the `s` value of a signature for an Ethereum
 * EIP-1559 type transaction,
 *
 *     s = k_inv * (h + rd) mod n
 *
 * All four buffers are 32-byte big-endian integers. n and its Barrett
 * parameter come from the __constant__ symbols n_fe and mu (k = 10 limbs).
 */
__device__ void calculateS(uint8_t* s, const uint8_t* k_inv, const uint8_t* h, const uint8_t* rd) {
    fe_num_t k_inv_fe, h_fe, rd_fe;
    fe_num_set_b32(&k_inv_fe, k_inv);
    fe_num_set_b32(&h_fe, h);
    fe_num_set_b32(&rd_fe, rd);

    /* (h + rd) mod n */
    fe_num_t h_plus_rd;
    mpModularAdd(&h_plus_rd, &h_fe, &rd_fe, &n_fe);

    /* k_inv * (h + rd) */
    fe_num_t product;
    mpMul(&product, &k_inv_fe, &h_plus_rd);

    fe_num_t r;
    barrettReduction(&r, &product, &n_fe, &mu, 10);

    fe_num_get_b32(s, &r);
}

/* --------------------------------------------------------------------------
 * /gpump.cuh
 * -------------------------------------------------------------------------- */
#ifndef GPUMP_CUH
#define GPUMP_CUH
/* NOTE(review): the header names on the original #include lines were lost in
 * this listing. The three below are the ones the code in gpump.cu needs;
 * restore any others the project relies on. */
#include <stdint.h>
#include <stdio.h>
#include <cuda_runtime.h>

/* Bits per limb: numbers are stored in radix 2^26. */
#define GPUMP_LIMB_BITS 26
/* Mask selecting the value bits of one limb. */
#define GPUMP_LIMB_MASK ((1u << GPUMP_LIMB_BITS) - 1u)
/* Limbs processed by the operand routines (compare/add/subtract/shift). */
#define GPUMP_LIMBS 25
/* Limbs of storage: room for the full product of two operands. */
#define GPUMP_MAX_LIMBS 50

/* Multiple-precision integer: little-endian array of radix-2^26 limbs. */
typedef struct {
    unsigned int n[GPUMP_MAX_LIMBS];
} fe_num_t;

__host__ __device__ int mpCompare(const fe_num_t* x, const fe_num_t* y);
__host__ __device__ void mpAdd(fe_num_t* result, const fe_num_t* x, const fe_num_t* y);
__host__ __device__ void mpSubtract(fe_num_t* result, const fe_num_t* x, const fe_num_t* y);
__host__ __device__ void mpModularAdd(fe_num_t* result, const fe_num_t* x, const fe_num_t* y, const fe_num_t* mod);
__host__ __device__ void mpMul(fe_num_t* result, const fe_num_t* x, const fe_num_t* y);
__host__ __device__ void barrettReduction(fe_num_t* r, const fe_num_t* z, const fe_num_t* p, const fe_num_t* mu, int k);
#endif