├── LICENSE ├── README.md ├── bigint.wgsl ├── curve.wgsl ├── curve_test.py ├── field.wgsl ├── index.html ├── main.wgsl ├── pippenger.wgsl ├── pippenger_fake.wgsl └── storage.wgsl /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sampriti Panda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # msm-webgpu 2 | 3 | Contributors: Nalin Bhardwaj (@nalinbhardwaj), Adhyyan Sekhsaria (@Adhyyan1252), Sampriti 4 | -------------------------------------------------------------------------------- /bigint.wgsl: -------------------------------------------------------------------------------- 1 | const W = 16u; 2 | const W_mask = (1 << W) - 1u; 3 | const L = 256; 4 | const N = L / W; 5 | 6 | // No overflow 7 | struct BigInt256 { 8 | limbs: array 9 | } 10 | 11 | struct BigInt512 { 12 | limbs: array 13 | } 14 | 15 | struct BigInt272 { 16 | limbs: array 17 | } 18 | 19 | // Careful, a and res may point to the same thing. 20 | fn add(a: BigInt256, b: BigInt256, res: ptr) -> u32 { 21 | var carry: u32 = 0; 22 | for (var i: u32 = 0; i < N; i = i + 1u) { 23 | let c = a.limbs[i] + b.limbs[i] + carry; 24 | (*res).limbs[i] = c & W_mask; 25 | carry = c >> W; 26 | } 27 | return carry; 28 | } 29 | 30 | // assumes a >= b 31 | fn sub(a: BigInt256, b: BigInt256, res: ptr) -> u32 { 32 | var borrow: u32 = 0; 33 | for (var i: u32 = 0; i < N; i = i + 1u) { 34 | (*res).limbs[i] = a.limbs[i] - b.limbs[i] - borrow; 35 | if (a.limbs[i] < (b.limbs[i] + borrow)) { 36 | (*res).limbs[i] += W_mask + 1; 37 | borrow = 1u; 38 | } else { 39 | borrow = 0u; 40 | } 41 | } 42 | return borrow; 43 | } 44 | 45 | // repeated code pls fix 46 | fn add_512(a: BigInt512, b: BigInt512, res: ptr) -> u32 { 47 | var carry: u32 = 0; 48 | for (var i: u32 = 0; i < (2*N); i = i + 1u) { 49 | let c = a.limbs[i] + b.limbs[i] + carry; 50 | (*res).limbs[i] = c & W_mask; 51 | carry = c >> W; 52 | } 53 | return carry; 54 | } 55 | 56 | // assumes a >= b 57 | fn sub_512(a: BigInt512, b: BigInt512, res: ptr) -> u32 { 58 | var borrow: u32 = 0; 59 | for (var i: u32 = 0; i < (2*N); i = i + 1u) { 60 | (*res).limbs[i] = a.limbs[i] - b.limbs[i] - borrow; 61 
| if (a.limbs[i] < (b.limbs[i] + borrow)) { 62 | (*res).limbs[i] += W_mask + 1; 63 | borrow = 1u; 64 | } else { 65 | borrow = 0u; 66 | } 67 | } 68 | return borrow; 69 | } 70 | 71 | // assumes a >= b 72 | fn sub_272(a: BigInt272, b: BigInt272, res: ptr) -> u32 { 73 | var borrow: u32 = 0; 74 | for (var i: u32 = 0; i < N + 1; i = i + 1u) { 75 | (*res).limbs[i] = a.limbs[i] - b.limbs[i] - borrow; 76 | if (a.limbs[i] < (b.limbs[i] + borrow)) { 77 | (*res).limbs[i] += W_mask + 1; 78 | borrow = 1u; 79 | } else { 80 | borrow = 0u; 81 | } 82 | } 83 | return borrow; 84 | } 85 | 86 | fn mul(a: BigInt256, b: BigInt256) -> BigInt512 { 87 | var res: BigInt512; 88 | for (var i = 0u; i < N; i = i + 1u) { 89 | for (var j = 0u; j < N; j = j + 1u) { 90 | let c = a.limbs[i] * b.limbs[j]; 91 | res.limbs[i+j] += c & W_mask; 92 | res.limbs[i+j+1] += c >> W; 93 | } 94 | } 95 | // start from 0 and carry the extra over to the next index 96 | for (var i = 0u; i < 2*N - 1; i = i + 1u) { 97 | res.limbs[i+1] += res.limbs[i] >> W; 98 | res.limbs[i] = res.limbs[i] & W_mask; 99 | } 100 | return res; 101 | } 102 | 103 | fn sqr(a: BigInt256) -> BigInt512 { 104 | var res: BigInt512; 105 | for (var i = 0u;i < N; i = i + 1u) { 106 | let sc = a.limbs[i] * a.limbs[i]; 107 | res.limbs[(i << 1)] += sc & W_mask; 108 | res.limbs[(i << 1)+1] += sc >> W; 109 | 110 | for (var j = i + 1;j < N;j = j + 1u) { 111 | let c = a.limbs[i] * a.limbs[j]; 112 | res.limbs[i+j] += (c & W_mask) << 1; 113 | res.limbs[i+j+1] += (c >> W) << 1; 114 | } 115 | } 116 | 117 | for (var i = 0u; i < 2*N - 1; i = i + 1u) { 118 | res.limbs[i+1] += res.limbs[i] >> W; 119 | res.limbs[i] = res.limbs[i] & W_mask; 120 | } 121 | return res; 122 | } 123 | -------------------------------------------------------------------------------- /curve.wgsl: -------------------------------------------------------------------------------- 1 | struct JacobianPoint { 2 | x: BaseField, 3 | y: BaseField, 4 | z: BaseField 5 | }; 6 | 7 | fn is_inf(p: 
JacobianPoint) -> bool { 8 | return field_eq(p.z, ZERO); 9 | } 10 | 11 | fn jacobian_double(p: JacobianPoint) -> JacobianPoint { 12 | // https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l 13 | let A = field_sqr(p.x); 14 | let B = field_sqr(p.y); 15 | let C = field_sqr(B); 16 | let X1plusB = field_add(p.x, B); 17 | let D = field_small_scalar_shift(1, field_sub(field_sqr(X1plusB), field_add(A, C))); 18 | let E = field_add(field_small_scalar_shift(1, A), A); 19 | let F = field_sqr(E); 20 | let x3 = field_sub(F, field_small_scalar_shift(1, D)); 21 | let y3 = field_sub(field_mul(E, field_sub(D, x3)), field_small_scalar_shift(3, C)); 22 | let z3 = field_mul(field_small_scalar_shift(1, p.y), p.z); 23 | return JacobianPoint(x3, y3, z3); 24 | } 25 | 26 | // double p and add q 27 | // todo: can be optimized if one of the z coordinates is 1 28 | // fn jacobian_dadd(p: JacobianPoint, q: JacobianPoint) -> JacobianPoint { 29 | // if (is_inf(p)) { 30 | // return q; 31 | // } else if (is_inf(q)) { 32 | // return jacobian_double(p); 33 | // } 34 | 35 | // let twox = field_small_scalar_shift(1, p.x); 36 | // let sqrx = field_mul(p.x, p.x); 37 | // let dblR = field_add(field_small_scalar_shift(1, sqrx), sqrx); 38 | // let dblH = field_small_scalar_shift(1, p.y); 39 | 40 | // let x3 = field_mul(q.z, q.z); 41 | // let z3 = field_mul(p.z, q.z); 42 | // let addH = field_mul(p.z, p.z); 43 | 44 | // } 45 | 46 | fn jacobian_add(p: JacobianPoint, q: JacobianPoint) -> JacobianPoint { 47 | // https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl 48 | if (field_eq(p.y, ZERO)) { 49 | return q; 50 | } 51 | if (field_eq(q.y, ZERO)) { 52 | return p; 53 | } 54 | 55 | let Z1Z1 = field_sqr(p.z); 56 | let Z2Z2 = field_sqr(q.z); 57 | let U1 = field_mul(p.x, Z2Z2); 58 | let U2 = field_mul(q.x, Z1Z1); 59 | let S1 = field_mul(p.y, field_mul(Z2Z2, q.z)); 60 | let S2 = field_mul(q.y, field_mul(Z1Z1, p.z)); 61 | if (field_eq(U1, U2)) { 62 | 
if (field_eq(S1, S2)) { 63 | return jacobian_double(p); 64 | } else { 65 | return JacobianPoint(ZERO, ZERO, ONE); 66 | } 67 | } 68 | 69 | let H = field_sub(U2, U1); 70 | let I = field_small_scalar_shift(2, field_sqr(H)); 71 | let J = field_mul(H, I); 72 | let R = field_small_scalar_shift(1, field_sub(S2, S1)); 73 | let V = field_mul(U1, I); 74 | let nx = field_sub(field_sqr(R), field_add(J, field_small_scalar_shift(1, V))); 75 | let ny = field_sub(field_mul(R, field_sub(V, nx)), field_small_scalar_shift(1, field_mul(S1, J))); 76 | let nz = field_mul(H, field_sub(field_pow(field_add(p.z, q.z), 2), field_add(Z1Z1, Z2Z2))); 77 | return JacobianPoint(nx, ny, nz); 78 | } 79 | 80 | fn jacobian_mul(p: JacobianPoint, k: ScalarField) -> JacobianPoint { 81 | var r: JacobianPoint = JacobianPoint(ZERO, ZERO, ONE); 82 | var t: JacobianPoint = p; 83 | for (var i = 0u; i < N; i = i + 1u) { 84 | var k_s = k.limbs[i]; 85 | for (var j = 0u; j < W; j = j + 1u) { 86 | if ((k_s & 1) == 1u) { 87 | r = jacobian_add(r, t); 88 | } 89 | t = jacobian_double(t); 90 | k_s = k_s >> 1; 91 | } 92 | } 93 | return r; 94 | } 95 | -------------------------------------------------------------------------------- /curve_test.py: -------------------------------------------------------------------------------- 1 | from sage.all_cmdline import * 2 | 3 | P = GF(0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001) 4 | E = EllipticCurve(P, [0, 5]) 5 | 6 | def from_jacob(res): 7 | res = (P(res[0]), P(res[1]), P(res[2])) 8 | return E(res[0]/(res[2]**2), res[1]/(res[2]**3)) 9 | 10 | G = E(22304380549750642616165107876029345325911088198117424279971154895103981677948, 14354096399413720219912473247241970521073754194408414292017996939864946211566) 11 | s = 115792089237316195423570985008687907853269984665640564039457584007913129639935 12 | print(G) 13 | print(s) 14 | 15 | a = 0x3d50eb7491f36a1c746cf044d8e97fd1e5f6d0d6e4da9633d37275b198640140 16 | b = 
0xbe564a8781cbd8fa78a8ea366e6d0a03b368ad2033cd06efa3954c0e5b05603 17 | c = 0x31e71da1d2922ce27f46dade9cd8d540ed3046ae8c4eb87c10427d10ca722637 18 | 19 | kek = from_jacob((a, b, c)) 20 | print(kek == G * s * 4096 * 64) 21 | for i in range(4096+1): 22 | if kek == G * s * 64 * i: 23 | print(i) 24 | break 25 | 26 | exit() 27 | 28 | while True: 29 | a = eval(input().split()[-1]) 30 | b = eval(input().split()[-1]) 31 | c = eval(input().split()[-1]) 32 | 33 | try: 34 | print(from_jacob((a, b, c))) 35 | except Exception as e: 36 | print(e) 37 | -------------------------------------------------------------------------------- /field.wgsl: -------------------------------------------------------------------------------- 1 | alias BaseField = BigInt256; 2 | alias ScalarField = BigInt256; 3 | 4 | const BASE_MODULUS: BigInt256 = BigInt256( 5 | array(1u, 0u, 12525u, 39213u, 63771u, 2380u, 39164u, 8774u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 16384u) 6 | ); 7 | 8 | const BASE_MODULUS_MEDIUM_WIDE: BigInt272 = BigInt272( 9 | array(1u, 0u, 12525u, 39213u, 63771u, 2380u, 39164u, 8774u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 16384u, 0u) 10 | ); 11 | 12 | const BASE_MODULUS_WIDE: BigInt512 = BigInt512( 13 | array(1u, 0u, 12525u, 39213u, 63771u, 2380u, 39164u, 8774u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 16384u, 14 | 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u) 15 | ); 16 | 17 | const BASE_NBITS = 255; 18 | 19 | const BASE_M = BigInt256( 20 | array(65532u, 65535u, 15435u, 39755u, 7057u, 56012u, 39951u, 30437u, 65535u, 65535u, 65535u, 65535u, 65535u, 65535u, 65535u, 65535u) 21 | ); 22 | 23 | const ZERO: BigInt256 = BigInt256( 24 | array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u) 25 | ); 26 | 27 | const ONE: BigInt256 = BigInt256( 28 | array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u) 29 | ); 30 | 31 | fn get_higher_with_slack(a: BigInt512) -> BaseField { 32 | var out: BaseField; 33 | const slack = L - BASE_NBITS; 34 | for (var i = 0u; i < N; i = i + 1u) { 35 | 
out.limbs[i] = ((a.limbs[i + N] << slack) + (a.limbs[i + N - 1] >> (W - slack))) & W_mask; 36 | } 37 | return out; 38 | } 39 | 40 | // reduces once: returns a - mod when that subtraction does not underflow, otherwise a unchanged (assumes that 0 <= a < 2 * mod) 41 | fn field_reduce(a: BigInt256) -> BaseField { 42 | var res: BigInt256; 43 | var underflow = sub(a, BASE_MODULUS, &res); 44 | if (underflow == 1u) { 45 | return a; 46 | } else { 47 | return res; 48 | } 49 | } 50 | // copies limbs 0 .. N - 1 of a into a BigInt256; anything above index N - 1 is dropped 51 | fn shorten(a: BigInt272) -> BigInt256 { 52 | var out: BigInt256; 53 | for (var i = 0u; i < N; i = i + 1u) { 54 | out.limbs[i] = a.limbs[i]; 55 | } 56 | return out; 57 | } 58 | 59 | // repeatedly subtracts the modulus, trying at most multi + 1 times, and returns the last value before a subtraction underflowed (assumes that 0 <= a < multi * mod); returns ZERO if no attempt underflowed 60 | fn field_reduce_272(a: BigInt272, multi: u32) -> BaseField { 61 | var res: BigInt272; 62 | var cur = a; 63 | var cur_multi = multi + 1; 64 | while (cur_multi > 0u) { 65 | var underflow = sub_272(cur, BASE_MODULUS_MEDIUM_WIDE, &res); 66 | if (underflow == 1u) { 67 | return shorten(cur); 68 | } else { 69 | cur = res; 70 | } 71 | cur_multi = cur_multi - 1u; 72 | } 73 | return ZERO; 74 | } 75 | // modular addition: limb-wise add followed by the single conditional subtraction in field_reduce; the carry returned by add is ignored 76 | fn field_add(a: BaseField, b: BaseField) -> BaseField { 77 | var res: BaseField; 78 | add(a, b, &res); 79 | return field_reduce(res); 80 | } 81 | // modular subtraction: computes a - b and adds the modulus back when the subtraction borrowed 82 | fn field_sub(a: BaseField, b: BaseField) -> BaseField { 83 | var res: BaseField; 84 | var carry = sub(a, b, &res); 85 | if (carry == 0u) { 86 | return res; 87 | } 88 | add(res, BASE_MODULUS, &res); 89 | return res; 90 | } 91 | 92 | fn field_mul(a: BaseField, b: BaseField) -> BaseField { 93 | var xy: BigInt512 = mul(a, b); 94 | var xy_hi: BaseField = get_higher_with_slack(xy); 95 | var l: BigInt512 = mul(xy_hi, BASE_M); 96 | var l_hi: BaseField = get_higher_with_slack(l); 97 | var lp: BigInt512 = mul(l_hi, BASE_MODULUS); 98 | var r_wide: BigInt512; 99 | sub_512(xy, lp, &r_wide); 100 | 101 | var r_wide_reduced: BigInt512; 102 | var underflow = sub_512(r_wide, BASE_MODULUS_WIDE, &r_wide_reduced); 103 | if (underflow == 0u) { 104 | r_wide = r_wide_reduced; 105 | } 106 | var r: BaseField; 107 | 
for (var i = 0u; i < N; i = i + 1u) { 108 | r.limbs[i] = r_wide.limbs[i]; 109 | } 110 | return field_reduce(r); 111 | } 112 | 113 | // This is slow, probably don't want to use this 114 | // fn field_small_scalar_mul(a: u32, b: BaseField) -> BaseField { 115 | // var constant: BaseField; 116 | // constant.limbs[0] = a; 117 | // return field_mul(constant, b); 118 | // } 119 | 120 | fn field_small_scalar_shift(l: u32, a: BaseField) -> BaseField { // max shift allowed is 16 121 | // assert (l < 16u); 122 | var res: BigInt272; 123 | for (var i = 0u; i < N; i = i + 1u) { 124 | let shift = a.limbs[i] << l; 125 | res.limbs[i] = res.limbs[i] | (shift & W_mask); 126 | res.limbs[i + 1] = (shift >> W); 127 | } 128 | 129 | var output = field_reduce_272(res, (1u << l)); // can probably be optimised 130 | return output; 131 | } 132 | // raises p to the power e by repeated multiplication (e - 1 calls to field_mul for e >= 1); e == 0 returns ONE, the multiplicative identity, instead of p 133 | fn field_pow(p: BaseField, e: u32) -> BaseField { 134 | if (e == 0u) { return ONE; } var res: BaseField = p; 135 | for (var i = 1u; i < e; i = i + 1u) { 136 | res = field_mul(res, p); 137 | } 138 | return res; 139 | } 140 | 141 | fn field_eq(a: BaseField, b: BaseField) -> bool { 142 | for (var i = 0u; i < N; i = i + 1u) { 143 | if (a.limbs[i] != b.limbs[i]) { 144 | return false; 145 | } 146 | } 147 | return true; 148 | } 149 | 150 | fn field_sqr(a: BaseField) -> BaseField { 151 | var xy: BigInt512 = sqr(a); 152 | var xy_hi: BaseField = get_higher_with_slack(xy); 153 | var l: BigInt512 = mul(xy_hi, BASE_M); 154 | var l_hi: BaseField = get_higher_with_slack(l); 155 | var lp: BigInt512 = mul(l_hi, BASE_MODULUS); 156 | var r_wide: BigInt512; 157 | sub_512(xy, lp, &r_wide); 158 | 159 | var r_wide_reduced: BigInt512; 160 | var underflow = sub_512(r_wide, BASE_MODULUS_WIDE, &r_wide_reduced); 161 | if (underflow == 0u) { 162 | r_wide = r_wide_reduced; 163 | } 164 | var r: BaseField; 165 | for (var i = 0u; i < N; i = i + 1u) { 166 | r.limbs[i] = r_wide.limbs[i]; 167 | } 168 | return field_reduce(r); 169 | } 170 | 171 | /* 172 | fn field_to_bits(a: BigInt256) -> array { 173 | let 
res: array = array(); 174 | for (var i = 0u;i < N;i += 1) { 175 | for (var j = 0u;j < 32u;j += 1) { 176 | var bit = (a.limbs[i] >> j) & 1u; 177 | res[i * 32u + j] = bit == 1u; 178 | } 179 | } 180 | return res; 181 | } 182 | */ 183 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 252 | 253 | 254 | 255 | 256 | -------------------------------------------------------------------------------- /main.wgsl: -------------------------------------------------------------------------------- 1 | // 3 -> 2 2 | // storage -> workgroup 3 | // second stage 4 | 5 | 6 | @compute @workgroup_size(1) 7 | fn main( 8 | @builtin(global_invocation_id) global_id: vec3, 9 | @builtin(local_invocation_id) local_id: vec3 10 | ) { 11 | let gidx = global_id.x; 12 | let lidx = local_id.x; 13 | 14 | result[gidx] = pippenger(gidx); 15 | } 16 | // Sums the partial results in `result`: first folds every 256-wide slice onto the first 256 entries, then tree-reduces those entries. NOTE(review): the two phases index by lidx and gidx respectively, so this presumably expects a single workgroup dispatch (gidx == lidx) - confirm against the host code. 17 | @compute @workgroup_size(256) 18 | fn aggregate( 19 | @builtin(global_invocation_id) global_id: vec3, 20 | @builtin(local_invocation_id) local_id: vec3 21 | ) { 22 | let gidx = global_id.x; 23 | let lidx = local_id.x; 24 | 25 | const split = NUM_INVOCATIONS / 256; 26 | // Fold slice j (entries j * 256 .. j * 256 + 255) onto the first 256 entries. The offset has to vary with j: `split * 256` equals NUM_INVOCATIONS for every iteration, so it re-added one fixed entry instead of walking the slices. 27 | for (var j = 1u; j < split; j = j + 1u) { 28 | result[lidx] = jacobian_add(result[lidx], result[lidx + j * 256u]); 29 | } 30 | 31 | storageBarrier(); 32 | 33 | for (var offset: u32 = 256 / 2u; offset > 0u; offset = offset / 2u) { 34 | if (lidx < offset) { 35 | result[gidx] = jacobian_add(result[gidx], result[gidx + offset]); 36 | } 37 | storageBarrier(); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /pippenger.wgsl: -------------------------------------------------------------------------------- 1 | const POINTS_PER_INVOCATION = 64u; 2 | const PARTITION_SIZE = 8u; 3 | const POW_PART = (1u << PARTITION_SIZE); 4 | const NUM_PARTITIONS = POINTS_PER_INVOCATION / PARTITION_SIZE; 5 | const PS_SZ = 
POW_PART; 6 | const BB_SIZE = 256; 7 | const BB_SIZE_FAKE = 20; 8 | 9 | @group(0) @binding(4) 10 | var powerset_sums: array; 11 | @group(0) @binding(5) 12 | var cur_sum: array; 13 | 14 | fn pippenger(gidx: u32) -> JacobianPoint { 15 | var ps_base = gidx * PS_SZ; 16 | var sum_base = i32(gidx) * BB_SIZE; 17 | var point_base = gidx * POINTS_PER_INVOCATION; 18 | 19 | // first calculate power set sums for each partition of points 20 | // then calculate the sets for each point 21 | 22 | for(var bb = 0; bb < BB_SIZE; bb = bb + 1) { 23 | cur_sum[sum_base + bb] = JacobianPoint(ZERO, ZERO, ONE); 24 | } 25 | for(var i = 0u; i < PS_SZ; i = i + 1) { 26 | powerset_sums[ps_base + i] = JacobianPoint(ZERO, ZERO, ONE); 27 | } 28 | 29 | 30 | for(var i = 0u; i < NUM_PARTITIONS; i = i + 1) { 31 | 32 | // compute all power sums in this partition 33 | var idx = 0u; 34 | for(var j = 1u; j < POW_PART; j = j + 1){ 35 | if((i32(j) & -i32(j)) == i32(j)) { 36 | powerset_sums[ps_base + j] = points[point_base + i * PARTITION_SIZE + idx]; 37 | idx = idx + 1; 38 | } else { 39 | let cur_point = points[point_base + i * PARTITION_SIZE + idx]; 40 | let mask = j & u32(j - 1); 41 | let other_mask = j ^ mask; 42 | powerset_sums[ps_base + j] = jacobian_add(powerset_sums[ps_base + mask], powerset_sums[ps_base + u32(other_mask)]); 43 | } 44 | } 45 | 46 | for(var bb: i32 = BB_SIZE - 1; bb >= 0; bb = bb - 1){ 47 | var b = u32(bb); 48 | 49 | var powerset_idx = 0u; 50 | let modbW = b % W; 51 | let quotbW = b / W; 52 | for(var j = 0u; j < PARTITION_SIZE; j = j + 1){ 53 | if((scalars[point_base + i * PARTITION_SIZE + j].limbs[quotbW] & (1u << modbW)) > 0) { 54 | powerset_idx = powerset_idx | (1u << j); 55 | } 56 | } 57 | cur_sum[sum_base + bb] = jacobian_add(cur_sum[sum_base + bb], powerset_sums[ps_base + powerset_idx]); 58 | } 59 | } 60 | var running_total: JacobianPoint; 61 | for(var bb = BB_SIZE - 1; bb >= 0; bb = bb - 1){ 62 | running_total = jacobian_add(jacobian_double(running_total), cur_sum[sum_base + 
bb]); 63 | } 64 | return running_total; 65 | } 66 | -------------------------------------------------------------------------------- /pippenger_fake.wgsl: -------------------------------------------------------------------------------- 1 | fn pippenger_fake(points: array, scalars: array) -> u32 { 2 | 3 | // first calculate power set sums for each partition of points 4 | // then calculate the sets for each point 5 | var powerset_sums: array; 6 | 7 | var cur_sum: array; 8 | 9 | for(var i = 0u; i < NUM_PARTITIONS; i = i + 1) { 10 | 11 | // compute all power sums in this partition 12 | var idx = 0u; 13 | for(var j = 1u; j < POW_PART; j = j + 1){ 14 | if((i32(j) & -i32(j)) == i32(j)) { 15 | powerset_sums[j] = points[i * PARTITION_SIZE + idx]; 16 | idx = idx + 1; 17 | } else { 18 | let cur_point = points[i * PARTITION_SIZE + idx]; 19 | let mask = j & u32(j - 1); 20 | let other_mask = j ^ mask; 21 | powerset_sums[j] = powerset_sums[mask] + powerset_sums[u32(other_mask)]; 22 | } 23 | } 24 | 25 | for(var bb = BB_SIZE_FAKE; bb >= 0; bb = bb - 1){ 26 | var b = u32(bb); 27 | 28 | var powerset_idx = 0u; 29 | for(var j = 0u; j < PARTITION_SIZE; j = j + 1){ 30 | if((scalars[i * PARTITION_SIZE + j] & (1u << b)) > 0){ 31 | powerset_idx = powerset_idx | (1u << j); 32 | } 33 | } 34 | cur_sum[bb] = cur_sum[bb] + powerset_sums[powerset_idx]; 35 | } 36 | } 37 | var running_total = 0u; 38 | for(var bb = BB_SIZE_FAKE; bb >= 0; bb = bb - 1){ 39 | running_total = running_total * 2u + cur_sum[bb]; 40 | } 41 | return running_total; 42 | } 43 | -------------------------------------------------------------------------------- /storage.wgsl: -------------------------------------------------------------------------------- 1 | const WORKGROUP_SIZE = 64; 2 | const NUM_INVOCATIONS = 4096; 3 | const MSM_SIZE = WORKGROUP_SIZE * NUM_INVOCATIONS; 4 | 5 | @group(0) @binding(0) 6 | var points: array; 7 | @group(0) @binding(1) 8 | var scalars: array; 9 | @group(0) @binding(2) 10 | var result: array; 11 | 
@group(0) @binding(3) 12 | var mem: array; 13 | --------------------------------------------------------------------------------