├── zorch ├── __init__.py ├── m31 │ ├── .test.py.swp │ ├── .__init__.py.swp │ ├── .m31_field.py.swp │ ├── .m31_utils.py.swp │ ├── __pycache__ │ │ ├── test.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── m31_circle.cpython-310.pyc │ │ ├── m31_field.cpython-310.pyc │ │ └── m31_utils.cpython-310.pyc │ ├── __init__.py │ ├── m31_circle.py │ ├── test.py │ ├── m31_numpy_utils.py │ ├── m31_utils.py │ └── m31_field.py ├── binary │ ├── __init__.py │ ├── test.py │ ├── utils.py │ └── binary_field.py └── koalabear │ ├── __init__.py │ ├── koalabear_numpy_utils.py │ └── koalabear_field.py ├── README.md └── setup.py /zorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zorch/m31/.test.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/.test.py.swp -------------------------------------------------------------------------------- /zorch/m31/.__init__.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/.__init__.py.swp -------------------------------------------------------------------------------- /zorch/m31/.m31_field.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/.m31_field.py.swp -------------------------------------------------------------------------------- /zorch/m31/.m31_utils.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/.m31_utils.py.swp -------------------------------------------------------------------------------- /zorch/m31/__pycache__/test.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/__pycache__/test.cpython-310.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Zorch is a package for CUDA-optimized STARK proving. 2 | 3 | All files freely released under the WTFPL: http://www.wtfpl.net/ 4 | -------------------------------------------------------------------------------- /zorch/m31/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /zorch/m31/__pycache__/m31_circle.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/__pycache__/m31_circle.cpython-310.pyc -------------------------------------------------------------------------------- /zorch/m31/__pycache__/m31_field.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/__pycache__/m31_field.cpython-310.pyc -------------------------------------------------------------------------------- /zorch/m31/__pycache__/m31_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vbuterin/zorch/HEAD/zorch/m31/__pycache__/m31_utils.cpython-310.pyc -------------------------------------------------------------------------------- /zorch/binary/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . 
# setup.py
#
# Packaging script for zorch. Reads the long description from README.md,
# which is expected to sit next to this file.
from pathlib import Path

from setuptools import setup, find_packages

# Resolve README.md relative to this file (not the current working
# directory) and read it through pathlib so the handle is closed
# deterministically and the encoding is explicit.
_README = Path(__file__).resolve().parent / 'README.md'

setup(
    name='zorch',
    version='0.1.12',
    packages=find_packages(),
    description='Cupy-based tools for STARK proving',
    long_description=_README.read_text(encoding='utf-8'),
    long_description_content_type='text/markdown',
    author='Vitalik Buterin',
    author_email='v@buterin.com',
    url='https://github.com/vbuterin/zorch',
    classifiers=[
        'Programming Language :: Python :: 3',
        # NOTE(review): README.md says the files are released under the
        # WTFPL, which disagrees with this MIT classifier -- confirm which
        # licence is intended before the next release.
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
)
-------------------------------------------------------------------------------- 1 | from .utils import cp, arange 2 | from .binary_field import Binary, ExtendedBinary 3 | 4 | def test(): 5 | x_orig = Binary(3 ** arange(10**4)) 6 | x = x_orig.copy() 7 | for i in range(16): 8 | x = x * x 9 | assert x_orig == x 10 | assert ( 11 | x[:10] * (x[10:20] ^ x[20:30]) == 12 | (x[:10] * x[10:20]) ^ (x[:10] * x[20:30]) 13 | ) 14 | assert (x ^ x[::-1]) ^ x[::-1] == x 15 | print("Basic arithmetic tests passed") 16 | x4_orig = ExtendedBinary( 17 | 3 ** arange(8 * 10**5).reshape((10**5, 8)) 18 | ) 19 | x4 = ExtendedBinary(cp.copy(x4_orig.value)) 20 | x5 = ExtendedBinary(cp.copy(x4_orig.value)) 21 | for i in range(4): 22 | x4 = x4 * x4 23 | for i in range(15): 24 | x5 *= x4_orig 25 | assert x4 == x5 26 | assert ( 27 | x4[:10] * (x4[10:20] ^ x4[20:30]) == 28 | (x4[:10] * x4[10:20]) ^ (x4[:10] * x4[20:30]) 29 | ) 30 | assert ( 31 | x4[:10] * (x[10:20] ^ x4[20:30]) == 32 | (x4[:10] * x[10:20]) ^ (x4[:10] * x4[20:30]) 33 | ) 34 | print("Extended arithmetic tests passed") 35 | x6 = Binary(3 ** arange(10**4)) 36 | x7 = x6.inv() 37 | x8 = x6 * x7 38 | assert x8 == 1 39 | assert x8 ^ 1 == 0 40 | print("Basic modinv tests passed") 41 | assert (x4.inv() * x4_orig).inv() * x4_orig == x4 42 | assert (x4[:10000].inv() * x).inv() * x == x4[:10000] 43 | x9 = ExtendedBinary(x6.value.reshape((2500, 4))) 44 | x10 = x9.inv() 45 | assert x9 * x10 == 1 46 | assert x9 * x10 ^ 1 == 0 47 | print("Extended modinv tests passed") 48 | x11 = ExtendedBinary([1,2,3,4,0,0,0,0]) 49 | x11_short = ExtendedBinary([1,2,3,4]) 50 | example = ExtendedBinary([5,6,7,8,9,10,11,12]) 51 | assert x11 * example == x11_short * example 52 | assert x11 ^ example == x11_short ^ example 53 | print("Different-length tests passed") 54 | -------------------------------------------------------------------------------- /zorch/m31/m31_circle.py: -------------------------------------------------------------------------------- 1 | try: 2 | 
class Point():
    """
    A point (or array of points) on the unit circle x^2 + y^2 = 1.

    ``x`` and ``y`` are coordinate containers of identical shape. They are
    expected to behave like the field types from ``m31_field`` (supporting
    ``shape``, ``ndim``, ``reshape``, ``swapaxes``, indexing and arithmetic).
    ``+`` applies the circle group law, whose identity is (1, 0).
    """

    def __init__(self, x, y):
        assert x.shape == y.shape
        self.x = x
        self.y = y

    @classmethod
    def zeros(cls, shape):
        """Return a Point of the given shape filled with the identity (1, 0)."""
        return cls(M31.zeros(shape) + 1, M31.zeros(shape))

    @classmethod
    def append(cls, *args, axis=0):
        """Concatenate several Points along ``axis``.

        Delegates to the ``append`` of the first argument's coordinate class.
        """
        return cls(
            args[0].x.__class__.append(*(v.x for v in args), axis=axis),
            args[0].x.__class__.append(*(v.y for v in args), axis=axis)
        )

    @property
    def shape(self):
        """Shape shared by both coordinates."""
        return self.x.shape

    def reshape(self, shape):
        """Return a Point whose coordinates are both reshaped to ``shape``."""
        return Point(
            self.x.reshape(shape),
            self.y.reshape(shape)
        )

    def swapaxes(self, ax1, ax2):
        """Return a Point with axes ``ax1`` and ``ax2`` swapped."""
        return Point(self.x.swapaxes(ax1, ax2), self.y.swapaxes(ax1, ax2))

    def copy(self):
        """Return an independent copy of this Point.

        The previous implementation called ``cp.copy`` although ``cp`` is
        never bound in this module, so it always raised NameError.
        ``deepcopy`` works for any coordinate type, array-backed or not.
        """
        from copy import deepcopy
        return Point(deepcopy(self.x), deepcopy(self.y))

    @property
    def ndim(self):
        """Number of dimensions of the coordinate containers."""
        return self.x.ndim

    def to_extended(self):
        """Return a Point with both coordinates converted via ``to_extended``."""
        return Point(self.x.to_extended(), self.y.to_extended())

    def __getitem__(self, index):
        return Point(self.x[index], self.y[index])

    def __setitem__(self, index, value):
        assert self.__class__ == value.__class__
        self.x[index] = value.x
        self.y[index] = value.y

    def __add__(self, other):
        """Circle group law: (x1*x2 - y1*y2, x1*y2 + y1*x2)."""
        assert self.__class__ == other.__class__
        return Point(
            self.x * other.x - self.y * other.y,
            self.x * other.y + self.y * other.x
        )

    def double(self):
        """Return self + self using the doubling formula (2x^2 - 1, 2xy)."""
        return Point(
            self.x * self.x * 2 - 1,
            self.x * self.y * 2
        )

    def __repr__(self):
        return f'(x={self.x}, y={self.y})'

    def __len__(self):
        """Length of the first axis.

        The old code read ``self.value``, which Point never defines. The
        length is taken from ``shape`` because that is the one attribute the
        coordinates are known to provide.
        """
        shape = self.shape
        if not shape:
            raise TypeError('len() of unsized Point')
        return shape[0]

    def tobytes(self):
        """Serialize as the bytes of x followed by the bytes of y.

        If either coordinate is an ExtendedM31, both are promoted first so
        the two halves use the same encoding. The promotion is done on local
        values; the Point itself is left unchanged.
        """
        # Imported here because the module header only imports M31.
        from .m31_field import ExtendedM31
        x, y = self.x, self.y
        if isinstance(x, ExtendedM31) or isinstance(y, ExtendedM31):
            x, y = x.to_extended(), y.to_extended()
        return x.tobytes() + y.tobytes()

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on other.x.
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y
Point.zeros(x.shape[0] * 2) 56 | x[::2] = double_x 57 | x[1::2] = double_x + G 58 | for i in range(15): 59 | assert x[i+1] == x[i] + G 60 | ext_point = Point( 61 | ExtendedM31([968417241, 1522700037, 1711331479, 520782658]), 62 | ExtendedM31([950082908, 1835034903, 1779185035, 1647796460]) 63 | ) 64 | x = Point.zeros(1).to_extended() 65 | coeff = ext_point 66 | for i in range(4): 67 | double_x = x.double() 68 | x = Point.zeros(x.shape[0] * 2).to_extended() 69 | x[::2] = double_x 70 | x[1::2] = double_x + ext_point 71 | for i in range(15): 72 | assert x[i+1] == x[i] + ext_point 73 | assert x[i+1] + G == (x[i] + G) + ext_point 74 | print("Point arithmetic tests passed") 75 | a = M31([123, 456000]) 76 | m1 = M31([[3, 4], [2, 3]]) 77 | m2 = M31([[3, -4], [-2, 3]]) 78 | med = matmul(a, m1) * 10 79 | o = matmul(med, m2) * 10 80 | assert o == M31([12300, 45600000]) 81 | a2 = ExtendedM31([[1,2,3,4],[5,6,7,800000]]) 82 | med2 = matmul(a2, m1, assume_second_input_small=True) * 10 83 | o2 = matmul(med2, m2) * 10 84 | assert o2 == ExtendedM31([[100,200,300,400],[500,600,700,80000000]]) 85 | print("Matrix multiplication tests passed") 86 | 87 | if __name__ == '__main__': 88 | test() 89 | -------------------------------------------------------------------------------- /zorch/binary/utils.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | MAX_SIZE = 32768 4 | 5 | # Multiply v1 * v2 in the binary tower field 6 | # See https://blog.lambdaclass.com/snarks-on-binary-fields-binius/ 7 | # for introduction to how binary tower fields work 8 | # 9 | # The general rule is that if i = b0b1...bk in binary, then the 10 | # index-i bit is the product of all x_i where b_i=1, eg. 
def binmul(v1, v2, length=None):
    """
    Multiply two integers interpreted as binary tower field elements.

    v1, v2: non-negative ints holding the bit representation of the elements.
    length: bit width of the field the operands live in. When omitted it is
        derived from the larger operand as the smallest power-of-two width
        that holds it.

    Products of operands below 256 are served from ``rawmulcache`` once the
    corresponding entry has been filled in.
    """
    if v1 < 256 and v2 < 256:
        cached = rawmulcache[v1][v2]
        if cached is not None:
            return cached
    # 0 and 1 behave like ordinary integers
    if v1 < 2 or v2 < 2:
        return v1 * v2
    if length is None:
        widest = max(v1, v2)
        length = 1 << (widest.bit_length() - 1).bit_length()
    half = length // 2
    quarter = length // 4
    low_mask = (1 << half) - 1

    # Split each operand into its low and high halves
    lo1, hi1 = v1 & low_mask, v1 >> half
    lo2, hi2 = v2 & low_mask, v2 >> half
    x_prev = 1 << quarter

    # Fast path when v1 is exactly the top-level generator (low half 0,
    # high half 1), sec III of https://ieeexplore.ieee.org/document/612935
    if lo1 == 0 and hi1 == 1:
        shifted_high = binmul(x_prev, hi2, half) ^ lo2
        return hi2 ^ (shifted_high << half)

    # Karatsuba: three half-size products plus one cheap multiplication of
    # hi*hi by the previous-level generator (handled by the fast path above)
    lo_prod = binmul(lo1, lo2, half)
    hi_prod = binmul(hi1, hi2, half)
    hi_prod_times_x = binmul(x_prev, hi_prod, half)
    cross = binmul(lo1 ^ hi1, lo2 ^ hi2, half)
    upper = cross ^ lo_prod ^ hi_prod ^ hi_prod_times_x
    return lo_prod ^ hi_prod ^ (upper << half)


# 256 x 256 lookup table of products, filled once at import time. Entries are
# None until computed, which lets binmul() run while the table is being built.
rawmulcache = [[None for _ in range(256)] for _ in range(256)]

for i, row in enumerate(rawmulcache):
    for j in range(256):
        row[j] = binmul(i, j)
def bigbin_to_int(value):
    """
    Convert a limb sequence into a single Python int.

    ``value`` is iterated in order; the limb at position i contributes
    ``int(limb) << (16 * i)``, so the first limb is the least significant.
    """
    total = 0
    for position, limb in enumerate(value):
        total += int(limb) << (16 * position)
    return total
See sec III of https://ieeexplore.ieee.org/document/612935 115 | def mul_by_Xi(x, N): 116 | assert x.shape[-1] == N 117 | if N == 1: 118 | return mul(x, 256) 119 | L, R = x[..., :N//2], x[..., N//2:] 120 | outR = mul_by_Xi(R, N//2) ^ L 121 | return cp.concatenate((R, outR), axis=-1) 122 | 123 | # Multiplies together two field elements, using the Karatsuba algorithm 124 | def big_mul(x1, x2): 125 | N = x1.shape[-1] 126 | if N == 1: 127 | return mul(x1, x2) 128 | L1, L2 = x1[..., :N//2], x2[..., :N//2] 129 | R1, R2 = x1[..., N//2:], x2[..., N//2:] 130 | L1L2 = big_mul(L1, L2) 131 | R1R2 = big_mul(R1, R2) 132 | R1R2_high = mul_by_Xi(R1R2, N//2) 133 | Z3 = big_mul(L1 ^ R1, L2 ^ R2) 134 | o = cp.concatenate(( 135 | L1L2 ^ R1R2, 136 | (Z3 ^ L1L2 ^ R1R2 ^ R1R2_high) 137 | ), axis=-1) 138 | return o 139 | 140 | def zeros(shape): 141 | return cp.zeros(shape, dtype=cp.uint16) 142 | 143 | def array(x): 144 | return cp.array(x, dtype=cp.uint16) 145 | 146 | def arange(*args): 147 | return cp.arange(*args, dtype=cp.uint16) 148 | 149 | def append(*args, axis=0): 150 | return cp.concatenate((*args,), axis=axis) 151 | 152 | def tobytes(x): 153 | return x.tobytes() 154 | 155 | xor = cp.ReductionKernel( 156 | 'uint16 x', # input params 157 | 'uint16 y', # output params 158 | 'x', # map 159 | '(a ^ b)', # reduce 160 | 'y = a', # post-reduction map 161 | '0', # identity value 162 | 'xor' # kernel name 163 | ) 164 | 165 | def zeros_like(obj): 166 | if obj.__class__ == cp.ndarray: 167 | return cp.zeros_like(obj) 168 | else: 169 | return obj.__class__.zeros(obj.shape) 170 | -------------------------------------------------------------------------------- /zorch/m31/m31_numpy_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | modulus = 2**31 - 1 4 | _M31_u32 = np.uint32(modulus) 5 | _M31_u64 = np.uint64(modulus) 6 | 7 | 8 | # ---------- low-level helpers (vectorized) ---------- 9 | 10 | def _as_u32(x): 11 | return 
np.asarray(x, dtype=np.uint32) 12 | 13 | def _mod31(x): 14 | """ 15 | Fast Mersenne reduction for arrays/ints using uint64 intermediates, 16 | then fold twice and final % to be safe. 17 | """ 18 | x64 = x.astype(np.uint64, copy=False) 19 | t = (x64 & _M31_u64) + (x64 >> 31) 20 | t = (t & _M31_u64) + (t >> 31) 21 | t = t % _M31_u64 22 | return t.astype(np.uint32) 23 | 24 | def _addmod(a, b): 25 | return _mod31(a.astype(np.uint64) + b.astype(np.uint64)) 26 | 27 | def _submod(a, b): 28 | # (a - b) mod p == a + p - b (all uint64 to avoid wrap) 29 | return _mod31(a.astype(np.uint64) + _M31_u64 - b.astype(np.uint64)) 30 | 31 | def _mulmod(a, b): 32 | # exact via 64-bit product, then % p 33 | prod = a.astype(np.uint64) * b.astype(np.uint64) 34 | return (prod % _M31_u64).astype(np.uint32) 35 | 36 | def _shl1_mod(a): 37 | # (a << 1) mod p 38 | return ((a.astype(np.uint64) << 1) % _M31_u64).astype(np.uint32) 39 | 40 | 41 | # ---------- public elementwise ops ---------- 42 | 43 | def add(x, y): 44 | """ 45 | Elementwise (x + y) mod (2^31-1) 46 | """ 47 | x = _as_u32(x); y = _as_u32(y) 48 | return _addmod(x, y) 49 | 50 | def sub(x, y): 51 | """ 52 | Elementwise (x - y) mod (2^31-1) 53 | """ 54 | x = _as_u32(x); y = _as_u32(y) 55 | return _submod(x, y) 56 | 57 | def pow5(x): 58 | """ 59 | Elementwise x^5 mod (2^31-1) 60 | """ 61 | x = _as_u32(x) 62 | x2 = _mulmod(x, x) # x^2 63 | x4 = _mulmod(x2, x2) # x^4 64 | return _mulmod(x4, x) # x^5 65 | 66 | def sum(x, axis=0): 67 | """ 68 | Reduction sum(x) mod (2^31-1) -> uint32 scalar 69 | """ 70 | x = _as_u32(x) 71 | s = np.sum(x.astype(np.uint64) % _M31_u64, dtype=np.uint64, axis=axis) % _M31_u64 72 | return np.uint32(s) 73 | 74 | def mul(x, y): 75 | """ 76 | Elementwise (x * y) mod (2^31-1) 77 | """ 78 | x = _as_u32(x); y = _as_u32(y) 79 | return _mulmod(x, y) 80 | 81 | 82 | # ---------- extension-field style ops ---------- 83 | 84 | def _multiply_complex(A0, A1, B0, B1): 85 | """ 86 | Multiply (A0 + i*A1) * (B0 + i*B1) in 
def mul_ext(x, y):
    """
    Vectorized 'extension' multiply, operating on groups of 4 uint32s.

    The layout matches the original cupy kernel:
        x = [x0, x1, x2, x3, x0, x1, x2, x3, ...]
        y = [y0, y1, y2, y3, y0, y1, y2, y3, ...]

    The operands are broadcast against each other before being split into
    groups, so shapes such as (A, B, 4) and (B, 4) can be combined.
    Returns a uint32 array with the broadcast shape.
    """
    x = _as_u32(x)
    y = _as_u32(y)
    assert x.size % 4 == 0 and y.size % 4 == 0

    # Broadcast first, then group. Reshaping each operand to (-1, 4) on its
    # own only lines the rows up when one side has a single group.
    new_shape = np.broadcast_shapes(x.shape, y.shape)
    X = np.broadcast_to(x, new_shape).reshape(-1, 4)
    Y = np.broadcast_to(y, new_shape).reshape(-1, 4)

    x0, x1, x2, x3 = (X[:, 0], X[:, 1], X[:, 2], X[:, 3])
    y0, y1, y2, y3 = (Y[:, 0], Y[:, 1], Y[:, 2], Y[:, 3])

    # LL part: product of the two low halves
    LL_r, LL_i = _multiply_complex(x0, x1, y0, y1)

    # combined (Karatsuba middle): (low + high) * (low + high)
    cA0 = _addmod(x0, x2)
    cA1 = _addmod(x1, x3)
    cB0 = _addmod(y0, y2)
    cB1 = _addmod(y1, y3)
    C_r, C_i = _multiply_complex(cA0, cA1, cB0, cB1)

    # RR part: product of the two high halves
    RR_r, RR_i = _multiply_complex(x2, x3, y2, y3)

    # z0, z1 = LL + RR * (-1 - 2i); z2, z3 = C - LL - RR
    z0 = _addmod(_submod(LL_r, RR_r), _mod31(RR_i.astype(np.uint64) * 2))
    z1 = _submod(_submod(LL_i, RR_i), _mod31(RR_r.astype(np.uint64) * 2))
    z2 = _submod(C_r, _addmod(LL_r, RR_r))
    z3 = _submod(C_i, _addmod(LL_i, RR_i))

    Z = np.stack([z0, z1, z2, z3], axis=1).reshape(new_shape)
    return Z.astype(np.uint32)
150 | Returns uint32 array (0 maps to 0). 151 | """ 152 | x = _as_u32(x) 153 | o = x.copy() 154 | pow_x = _mulmod(x, x) # x^2 155 | for _ in range(29): 156 | pow_x = _mulmod(pow_x, pow_x) # square 157 | o = _mulmod(o, pow_x) # multiply 158 | return o 159 | 160 | def modinv(x): 161 | """ 162 | Elementwise modular inverse in F_p (p = 2^31-1). 163 | Matches the CUDA kernel behavior (0 -> 0). 164 | """ 165 | x = _as_u32(x) 166 | return _modinv_vectorized_base(x) 167 | 168 | 169 | # ---------- vectorized "extension" inverse on groups of 4 ---------- 170 | 171 | def modinv_ext(x): 172 | """ 173 | Vectorized inverse on 4-limb groups, matching the CUDA kernel layout. 174 | x is uint32 array with size % 4 == 0. 175 | """ 176 | x = _as_u32(x) 177 | assert x.size % 4 == 0, "Input length must be a multiple of 4" 178 | X = x.reshape(-1, 4) 179 | 180 | x0 = X[:, 0]; x1 = X[:, 1]; x2 = X[:, 2]; x3 = X[:, 3] 181 | 182 | x0_sq = _mulmod(x0, x0) 183 | x1_sq = _mulmod(x1, x1) 184 | x0x1 = _mulmod(x0, x1) 185 | 186 | x2_sq = _mulmod(x2, x2) 187 | x3_sq = _mulmod(x3, x3) 188 | x2x3 = _mulmod(x2, x3) 189 | 190 | r20 = _submod(x2_sq, x3_sq) 191 | r21 = _mod31((x2x3.astype(np.uint64) << 1)) # 2*x2x3 mod p 192 | 193 | # denom0 = (x0^2 - x1^2) + (r20 - 2*r21) (all mod p) 194 | t1 = _submod(x0_sq, x1_sq) 195 | t2 = _submod(r20, _shl1_mod(r21)) # r20 - 2*r21 196 | denom0 = _addmod(t1, t2) 197 | 198 | # denom1 = 2*x0x1 + r21 + 2*r20 (all mod p) 199 | denom1 = _addmod(_addmod(_shl1_mod(x0x1), r21), _shl1_mod(r20)) 200 | 201 | inv_denom_norm = modinv(_addmod(_mulmod(denom0, denom0), 202 | _mulmod(denom1, denom1))) 203 | inv_denom0 = _mulmod(denom0, inv_denom_norm) 204 | inv_denom1 = _mulmod(_submod(np.uint32(0), denom1), inv_denom_norm) # -denom1 * inv_norm 205 | 206 | z0 = _submod(_mulmod(x0, inv_denom0), _mulmod(x1, inv_denom1)) 207 | z1 = _addmod(_mulmod(x0, inv_denom1), _mulmod(x1, inv_denom0)) 208 | z2 = _submod(_mulmod(x3, inv_denom1), _mulmod(x2, inv_denom0)) 209 | z3 = 
def zeros_like(obj):
    """
    Return an all-zero object shaped like ``obj``.

    NumPy arrays (including ndarray subclasses) go through ``np.zeros_like``.
    Any other object must provide a ``shape`` attribute and a ``zeros(shape)``
    constructor on its class, which is called to build the result.
    """
    # isinstance (not an exact class comparison) so ndarray subclasses are
    # handled as arrays; they have no ``zeros`` constructor of their own.
    if isinstance(obj, np.ndarray):
        return np.zeros_like(obj)
    return obj.__class__.zeros(obj.shape)
def modinv(x):
    """
    Elementwise modular inverse in F_p, p = 2^31 - 2^24 + 1 (0 maps to 0).

    Computes x^(p-2) with vectorized square-and-multiply on uint64
    intermediates. Inputs are reduced mod p first, so every product stays
    below p^2 < 2^62 and cannot overflow.

    Returns a uint32 array with the same shape as ``x``.
    """
    # The earlier per-element loop wrote through ``out.ravel()``. For
    # non-C-contiguous input (e.g. a transposed array) ravel() returns a
    # copy, so the writes were lost. Whole-array arithmetic avoids that.
    p = np.uint64(2**31 - 2**24 + 1)
    base = np.asarray(x).astype(np.uint64) % p
    shape = base.shape
    result = np.ones_like(base)
    exponent = int(p) - 2
    while True:
        if exponent & 1:
            result = (result * base) % p
        exponent >>= 1
        if not exponent:
            break
        base = (base * base) % p
    return np.asarray(result).astype(np.uint32).reshape(shape)
t31m1 = np.uint64(2**31-1)
overflow = np.uint64(2**24 - 1)
p = 2**31 - 2**24 + 1


def weakmod(x):
    """
    Partially reduce x modulo p = 2^31 - 2^24 + 1.

    Uses 2^31 = 2^24 - 1 (mod p): the bits above position 30 are folded back
    in, scaled by 2^24 - 1. The result is congruent to x mod p but is not
    guaranteed to be below p.
    """
    carry = x >> 31
    return (x & t31m1) + overflow * carry


def mul_ext(x, y):
    """
    Multiply in F_p[X]/(X^4 - 3).

    Operands are packed as (..., 4): a0 + a1*X + a2*X^2 + a3*X^3.
    Returns a uint32 array whose shape is the broadcast of both input shapes.
    """
    x = np.asarray(x, dtype=np.uint32)
    y = np.asarray(y, dtype=np.uint32)

    prime = np.uint64(2**31 - 2**24 + 1)
    non_residue = np.uint64(3)  # X^4 = 3

    lhs = x.astype(np.uint64, copy=False)
    rhs = y.astype(np.uint64, copy=False)
    a = [lhs[..., i] for i in range(4)]
    b = [rhs[..., j] for j in range(4)]

    limbs = []
    for k in range(4):
        # terms a_i * b_j with i + j == k land directly on X^k
        acc = a[0] * b[k]
        for i in range(1, k + 1):
            acc = acc + a[i] * b[k - i]
        if k < 3:
            # terms with i + j == k + 4 wrap around through X^4 = 3
            wrapped = a[k + 1] * b[3]
            for i in range(k + 2, 4):
                wrapped = wrapped + a[i] * b[k + 4 - i]
            acc = acc + non_residue * weakmod(wrapped)
        limbs.append(acc)

    out = np.empty(np.broadcast_shapes(x.shape, y.shape), dtype=np.uint32)
    for k, limb in enumerate(limbs):
        out[..., k] = limb % prime
    return out
139 | """ 140 | x = np.asarray(x, dtype=np.uint32) 141 | assert x.dtype == np.uint32 and x.size % 4 == 0 142 | 143 | p = np.uint64(2**31 - 2**24 + 1) 144 | NR = np.uint64(3) # Y^2 = 3 where Y = X^2 145 | 146 | X = x.reshape(-1, 4).astype(np.uint64, copy=False) 147 | a0, a1, a2, a3 = X[:,0], X[:,1], X[:,2], X[:,3] 148 | 149 | # K helpers: K = F_p[Y]/(Y^2=3) 150 | def k_add(r1, i1, r2, i2): 151 | srr = r1 + r2 152 | sii = i1 + i2 153 | srr -= p * (srr >= p) 154 | sii -= p * (sii >= p) 155 | return srr, sii 156 | 157 | def k_sub(r1, i1, r2, i2): 158 | drr = r1 + p - r2 159 | dii = i1 + p - i2 160 | drr -= p * (drr >= p) 161 | dii -= p * (dii >= p) 162 | return drr, dii 163 | 164 | def k_mul(r1, i1, r2, i2): 165 | rr = (r1*r2 + NR*i1*i2) % p 166 | ii = (r1*i2 + i1*r2) % p 167 | return rr, ii 168 | 169 | def k_sqr(r, i): 170 | # (r + iY)^2 = (r^2 + 3 i^2) + (2 r i) Y 171 | rr = (r*r + NR*i*i) % p 172 | ii = ((r << 1) * i) % p 173 | return rr, ii 174 | 175 | def k_mulY(rr, ii): 176 | return (NR*ii) % p, rr 177 | 178 | def inv_fp(z): 179 | return modinv(z) 180 | 181 | def k_inv(r, i): 182 | # (r + iY)^{-1} = (r - iY) / (r^2 - 3 i^2) 183 | denom = weakmod(r*r + p - (NR * (i*i) % p)) 184 | denom_inv = inv_fp(denom) 185 | cr = (r * denom_inv) % p 186 | ci = ((p - i) * denom_inv) % p 187 | return cr.astype(np.uint64), ci.astype(np.uint64) 188 | 189 | # Build A = a0 + a2 Y ; B = a1 + a3 Y in K 190 | Ar, Ai = a0, a2 191 | Br, Bi = a1, a3 192 | 193 | # denom = A^2 - Y*B^2 in K 194 | A2_r, A2_i = k_sqr(Ar, Ai) 195 | B2_r, B2_i = k_sqr(Br, Bi) 196 | YB2_r, YB2_i = k_mulY(B2_r, B2_i) 197 | denom_r, denom_i = k_sub(A2_r, A2_i, YB2_r, YB2_i) 198 | 199 | invd_r, invd_i = k_inv(denom_r, denom_i) 200 | 201 | # (A - X B) * inv_d in the outer quadratic 202 | # real = A * inv_d 203 | real_r, real_i = k_mul(Ar, Ai, invd_r, invd_i) 204 | # imag = (-B) * inv_d 205 | nBr = (p - Br) % p 206 | nBi = (p - Bi) % p 207 | imag_r, imag_i = k_mul(nBr, nBi, invd_r, invd_i) 208 | 209 | Z = 
np.empty_like(X, dtype=np.uint32) 210 | Z[:,0] = (real_r % p).astype(np.uint32) 211 | Z[:,2] = (real_i % p).astype(np.uint32) 212 | Z[:,1] = (imag_r % p).astype(np.uint32) 213 | Z[:,3] = (imag_i % p).astype(np.uint32) 214 | return Z.reshape(x.shape) 215 | 216 | # ------------------------ 217 | # Convenience wrappers (match the original API surface) 218 | # ------------------------ 219 | def zeros(shape): 220 | return np.zeros(shape, dtype=np.uint32) 221 | 222 | def array(x): 223 | return np.array(x, dtype=np.uint32) 224 | 225 | def arange(*args): 226 | return np.arange(*args, dtype=np.uint32) 227 | 228 | def append(*args, axis=0): 229 | return np.concatenate(tuple(args), axis=axis).astype(np.uint32, copy=False) 230 | 231 | def tobytes(x): 232 | return np.asarray(x, dtype=np.uint32).tobytes() 233 | 234 | def eq(x, y): 235 | return np.array_equal(_modp(_u64(x)), _modp(_u64(y))) 236 | 237 | def iszero(x): 238 | return not np.any(_modp(_u64(x))) 239 | 240 | def zeros_like(obj): 241 | if isinstance(obj, np.ndarray): 242 | return np.zeros_like(obj, dtype=np.uint32) 243 | else: 244 | return obj.__class__.zeros(obj.shape) 245 | -------------------------------------------------------------------------------- /zorch/m31/m31_utils.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | modulus = 2**31-1 4 | 5 | add = cp.ElementwiseKernel( 6 | 'uint32 x, uint32 y', # input argument list 7 | 'uint32 z', # output argument list 8 | 'z = (x + y); z = (z & 2147483647) + (z >> 31)', # loop body code 9 | 'add') # kernel name 10 | 11 | sub = cp.ElementwiseKernel( 12 | 'uint32 x, uint32 y', # input argument list 13 | 'uint32 z', # output argument list 14 | 'const unsigned int M31 = 2147483647; z = (x + M31 - y); z = (z & M31) + (z >> 31)', # loop body code 15 | 'sub') # kernel name 16 | 17 | pow5 = cp.ElementwiseKernel( 18 | 'uint32 x', # input argument list 19 | 'uint32 o', # output argument list 20 | preamble=''' 21 | const 
# Element-wise product of two degree-4 extension elements.  Each element is
# four uint32 limbs (16 bytes), which is why the caller views the data as
# complex128: one complex128 slot carries exactly one extension element.
_mul_ext = cp.ElementwiseKernel(
    'complex128 x, complex128 y',
    'complex128 z',
    preamble='''

    const unsigned int M31 = 2147483647;

    // (x - y) folded once modulo 2^31 - 1.
    __device__ unsigned int submod(unsigned int x, unsigned int y) {
        unsigned int z = (x + M31 - y);
        return (z & M31) + (z >> 31);
    }

    // x * y folded modulo 2^31 - 1: the low 32 bits are folded once and the
    // high 32 bits (from __umulhi) are added in doubled, since 2^32 = 2 (mod M31).
    __device__ unsigned int mulmod(unsigned int x, unsigned int y) {
        unsigned int z1 = (x * y);
        unsigned int z = (z1 & M31) + (z1 >> 31) + __umulhi(x, y) * 2;
        return (z & M31) + (z >> 31);
    }

    // Single fold of a 32-bit value modulo 2^31 - 1.
    __device__ unsigned int mod31(unsigned int x) {
        return (x & M31) + (x >> 31);
    }

    // (A0 + A1*i) * (B0 + B1*i) with i^2 = -1, using three multiplications
    // (Karatsuba): real = low - high, imag = med - (low + high).
    __device__ void multiply_complex(unsigned int* o_r,
                                     unsigned int* o_i,
                                     unsigned int A0,
                                     unsigned int A1,
                                     unsigned int B0,
                                     unsigned int B1) {
        unsigned int low = mulmod(A0, B0);
        unsigned int high = mulmod(A1, B1);
        unsigned int med = mulmod(mod31(A0 + A1), mod31(B0 + B1));
        *o_r = submod(low, high);
        *o_i = submod(med, mod31(low + high));
    }

    ''',
    operation='''

    // Local copies so the 16 bytes of each operand can be read as 4 limbs.
    // The explicit template arguments are required: `thrust::complex` and
    // `reinterpret_cast` without them do not compile.
    thrust::complex<double> _x = x;
    thrust::complex<double> _y = y;

    unsigned int x0 = reinterpret_cast<unsigned int*>(&_x)[0];
    unsigned int x1 = reinterpret_cast<unsigned int*>(&_x)[1];
    unsigned int x2 = reinterpret_cast<unsigned int*>(&_x)[2];
    unsigned int x3 = reinterpret_cast<unsigned int*>(&_x)[3];

    unsigned int y0 = reinterpret_cast<unsigned int*>(&_y)[0];
    unsigned int y1 = reinterpret_cast<unsigned int*>(&_y)[1];
    unsigned int y2 = reinterpret_cast<unsigned int*>(&_y)[2];
    unsigned int y3 = reinterpret_cast<unsigned int*>(&_y)[3];

    // Outer Karatsuba over the halves L = (limb0, limb1), R = (limb2, limb3):
    // LL = Lx*Ly, RR = Rx*Ry, comb = (Lx+Rx)*(Ly+Ry).
    unsigned int o_LL_r, o_LL_i;
    multiply_complex(
        &o_LL_r, &o_LL_i,
        x0, x1,
        y0, y1
    );

    unsigned int o_comb_r, o_comb_i;
    multiply_complex(
        &o_comb_r, &o_comb_i,
        mod31(x0 + x2), mod31(x1 + x3),
        mod31(y0 + y2), mod31(y1 + y3)
    );

    unsigned int o_RR_r, o_RR_i;
    multiply_complex(
        &o_RR_r, &o_RR_i,
        x2, x3,
        y2, y3
    );

    // Low half: LL + RR * (-1 - 2i); high half: comb - LL - RR.
    reinterpret_cast<unsigned int*>(&z)[0] = mod31(submod(o_LL_r, o_RR_r) + mod31(o_RR_i * 2));
    reinterpret_cast<unsigned int*>(&z)[1] = submod(submod(o_LL_i, o_RR_i), mod31(o_RR_r * 2));
    reinterpret_cast<unsigned int*>(&z)[2] = submod(o_comb_r, mod31(o_LL_r + o_RR_r));
    reinterpret_cast<unsigned int*>(&z)[3] = submod(o_comb_i, mod31(o_LL_i + o_RR_i));

    '''
)
// Inverts groups of four consecutive 32-bit limbs: thread `idx` reads
// x[4*idx .. 4*idx+3] and writes z[4*idx .. 4*idx+3].  `num_blocks` is the
// number of such groups (threads with idx >= num_blocks do nothing).
//
// Writing L = x0 + x1*i and R = x2 + x3*i (i^2 = -1), the code computes
//   denom = L^2 + (1 + 2i) * R^2
// and returns (L - R*w) / denom.
// NOTE(review): this matches the norm L^2 - w^2 * R^2 for w^2 = -1 - 2i, the
// reduction the `_mul_ext` kernel in this file applies -- confirm.
extern "C" __global__
void vectorized_modinv(const unsigned int* x,
                       unsigned int* z,
                       int num_blocks) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_blocks) {
        int base_idx = idx * 4;

        // Squares and cross products of the four limbs.
        unsigned int x0_sq = mulmod(x[base_idx], x[base_idx]);
        unsigned int x1_sq = mulmod(x[base_idx + 1], x[base_idx + 1]);
        unsigned int x0x1 = mulmod(x[base_idx], x[base_idx + 1]);
        unsigned int x2_sq = mulmod(x[base_idx + 2], x[base_idx + 2]);
        unsigned int x3_sq = mulmod(x[base_idx + 3], x[base_idx + 3]);
        unsigned int x2x3 = mulmod(x[base_idx + 2], x[base_idx + 3]);
        // R^2 = r20 + r21*i
        unsigned int r20 = submod(x2_sq, x3_sq);
        unsigned int r21 = mod31(x2x3 << 1);
        // denom0 = Re(L^2) + (r20 - 2*r21)
        unsigned int denom0 = mod31(
            submod(x0_sq, x1_sq)
            + submod(r20, mod31(r21 << 1))
        );
        // denom1 = Im(L^2) + (r21 + 2*r20)
        unsigned int denom1 = mod31(
            mod31(mod31(x0x1 << 1) + r21)
            + mod31(r20 << 1)
        );
        // 1 / (denom0^2 + denom1^2): base-field inverse of the complex norm.
        unsigned int inv_denom_norm = modinv(mod31(
            mulmod(denom0, denom0) + mulmod(denom1, denom1)
        ));
        // 1/denom = conj(denom) / norm = inv_denom0 + inv_denom1*i
        unsigned int inv_denom0 = mulmod(denom0, inv_denom_norm);
        unsigned int inv_denom1 = mulmod(M31 - denom1, inv_denom_norm);

        // Low half: L * (1/denom).
        z[base_idx] = submod(
            mulmod(x[base_idx], inv_denom0),
            mulmod(x[base_idx + 1], inv_denom1)
        );
        z[base_idx + 1] = mod31(
            mulmod(x[base_idx], inv_denom1)
            + mulmod(x[base_idx + 1], inv_denom0)
        );
        // High half: -R * (1/denom), real part then imaginary part.
        z[base_idx + 2] = submod(
            mulmod(x[base_idx + 3], inv_denom1),
            mulmod(x[base_idx + 2], inv_denom0)
        );
        z[base_idx + 3] = M31 - mod31(
            mulmod(x[base_idx + 2], inv_denom1)
            + mulmod(x[base_idx + 3], inv_denom0)
        );
    }
}
"C" __global__ 221 | void vectorized_basic_modinv(const unsigned int* x, 222 | unsigned int* z, 223 | int num_blocks) { 224 | 225 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 226 | if (idx < num_blocks) { 227 | z[idx] = modinv(x[idx]); 228 | } 229 | } 230 | ''' 231 | 232 | # Load the kernel 233 | modinv_kernel = cp.RawKernel(kernel_code, 'vectorized_basic_modinv') 234 | modinv_ext_kernel = cp.RawKernel(kernel_code, 'vectorized_modinv') 235 | 236 | # Wrapper function 237 | def modinv(x): 238 | assert x.dtype == cp.uint32 239 | 240 | x_flat = x.ravel() 241 | z = cp.zeros_like(x_flat) 242 | 243 | num_blocks = x_flat.size 244 | threads_per_block = 256 245 | blocks_per_grid = (num_blocks + threads_per_block - 1) // threads_per_block 246 | 247 | modinv_kernel((blocks_per_grid,), (threads_per_block,), 248 | (x_flat, z, num_blocks)) 249 | 250 | return z.reshape(x.shape) 251 | 252 | # Wrapper function 253 | def modinv_ext(x): 254 | assert x.dtype == cp.uint32 255 | 256 | x_flat = x.ravel() 257 | z = cp.zeros_like(x_flat) 258 | 259 | num_blocks = x_flat.size // 4 260 | threads_per_block = 256 261 | blocks_per_grid = (num_blocks + threads_per_block - 1) // threads_per_block 262 | 263 | modinv_ext_kernel((blocks_per_grid,), (threads_per_block,), 264 | (x_flat, z, num_blocks)) 265 | 266 | return z.reshape(x.shape) 267 | 268 | def zeros(shape): 269 | return cp.zeros(shape, dtype=cp.uint32) 270 | 271 | def array(x): 272 | return cp.array(x, dtype=cp.uint32) 273 | 274 | def arange(*args): 275 | return cp.arange(*args, dtype=cp.uint32) 276 | 277 | def append(*args, axis=0): 278 | return cp.concatenate((*args,), axis=axis) 279 | 280 | def tobytes(x): 281 | return x.tobytes() 282 | 283 | def eq(x, y): 284 | return cp.array_equal(x % modulus, y % modulus) 285 | 286 | def iszero(x): 287 | return not cp.any(x % modulus) 288 | 289 | def zeros_like(obj): 290 | if obj.__class__ == cp.ndarray: 291 | return cp.zeros_like(obj) 292 | else: 293 | return obj.__class__.zeros(obj.shape) 294 
| -------------------------------------------------------------------------------- /zorch/binary/binary_field.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | mul, zeros, arange, array, xor, inv, cp, big_mul 3 | ) 4 | 5 | class Binary(): 6 | def __init__(self, x): 7 | if isinstance(x, (int, list)): 8 | x = array(x) 9 | elif isinstance(x, Binary): 10 | x = x.value 11 | self.value = x 12 | assert self.value.dtype == cp.uint16 13 | 14 | @classmethod 15 | def zeros(cls, shape): 16 | return cls(zeros(shape)) 17 | 18 | @classmethod 19 | def arange(cls, *args): 20 | return cls(arange(*args)) 21 | 22 | @classmethod 23 | def append(cls, *args, axis=0): 24 | return cls(append(*(x.value for x in args), axis=axis)) 25 | 26 | @classmethod 27 | def sum(cls, arg, axis): 28 | return cls(xor(arg.value, axis=axis)) 29 | 30 | @property 31 | def shape(self): 32 | return self.value.shape 33 | 34 | def reshape(self, shape): 35 | return Binary(self.value.reshape(shape)) 36 | 37 | def swapaxes(self, ax1, ax2): 38 | return Binary(self.value.swapaxes(ax1, ax2)) 39 | 40 | def copy(self): 41 | return Binary(cp.copy(self.value)) 42 | 43 | @property 44 | def ndim(self): 45 | return self.value.ndim 46 | 47 | def __getitem__(self, index): 48 | return Binary(self.value[index]) 49 | 50 | def __setitem__(self, index, value): 51 | if isinstance(value, int): 52 | self.value[index] = value 53 | elif isinstance(value, Binary): 54 | self.value[index] = value.value 55 | else: 56 | raise Exception(f"Bad input for setitem: {value}") 57 | 58 | def to_extended(self, limbs): 59 | o = zeros(self.value.shape + (limbs,)) 60 | o[...,0] = self.value 61 | return ExtendedBinary(o) 62 | 63 | def __add__(self, other): 64 | raise Exception("Use xor with binary fields") 65 | 66 | def __neg__(self): 67 | raise Exception("x = -x in binary fields") 68 | 69 | def __sub__(self, other): 70 | raise Exception("Use xor with binary fields") 71 | 72 | def 
__xor__(self, other): 73 | if isinstance(other, ExtendedBinary): 74 | return self.to_extended(other.limbs) ^ other 75 | elif isinstance(other, int): 76 | other = Binary(other) 77 | return Binary(self.value ^ other.value) 78 | 79 | def __mul__(self, other): 80 | if isinstance(other, ExtendedBinary): 81 | return ExtendedBinary(mul( 82 | self.value.reshape(self.value.shape + (1,)), 83 | other.value 84 | )) 85 | elif isinstance(other, int): 86 | other = Binary(other) 87 | return Binary(mul(self.value, other.value)) 88 | 89 | def __pow__(self, other): 90 | assert isinstance(other, int) 91 | if other == 0: 92 | return Binary(cp.ones(self.value.shape)) 93 | elif other == 1: 94 | return self 95 | elif other % 2 == 1: 96 | sub = self ** (other // 2) 97 | return sub * sub * self 98 | else: 99 | sub = self ** (other // 2) 100 | return sub * sub 101 | 102 | def inv(self): 103 | return Binary(inv[self.value]) 104 | 105 | def __truediv__(self, other): 106 | if isinstance(other, int): 107 | other = Binary(other) 108 | return self * other.inv() 109 | 110 | def __rtruediv__(self, other): 111 | return self.inv() * other 112 | 113 | __rxor__ = __xor__ 114 | __radd__ = __add__ 115 | __rmul__ = __mul__ 116 | __rsub__ = __sub__ 117 | 118 | def __repr__(self): 119 | return f'{self.value}' 120 | 121 | def __int__(self): 122 | return int(self.value) 123 | 124 | def __len__(self): 125 | return len(self.value) 126 | 127 | def tobytes(self): 128 | return self.value.tobytes() 129 | 130 | def __eq__(self, other): 131 | if isinstance(other, int): 132 | other = Binary(other) 133 | elif isinstance(other, ExtendedBinary): 134 | return self.to_extended(other.limbs) == other 135 | shape = cp.broadcast_shapes(self.value.shape, other.value.shape) 136 | return cp.array_equal( 137 | cp.broadcast_to(self.value, shape), 138 | cp.broadcast_to(other.value, shape) 139 | ) 140 | 141 | def match_limbs(a, b): 142 | if a.shape[-1] == b.shape[-1]: 143 | return a, b 144 | elif a.shape[-1] < b.shape[-1]: 145 | 
padding = b.shape[-1] - a.shape[-1] 146 | pad_params = [(0,0)] * (a.ndim - 1) + [(0, padding)] 147 | return cp.pad(a, pad_params, mode='constant', constant_values=0), b 148 | else: 149 | padding = a.shape[-1] - b.shape[-1] 150 | pad_params = [(0,0)] * (b.ndim - 1) + [(0, padding)] 151 | return a, cp.pad(b, pad_params, mode='constant', constant_values=0) 152 | 153 | class ExtendedBinary(): 154 | def __init__(self, x): 155 | if isinstance(x, int): 156 | raise Exception("Initializing directly from int not supported") 157 | elif isinstance(x, list): 158 | x = array(x) 159 | elif isinstance(x, Binary): 160 | raise Exception("Initializing directly from Binary not supported") 161 | elif isinstance(x, ExtendedBinary): 162 | x = x.value 163 | assert x.shape[-1] & (x.shape[-1] - 1) == 0 164 | self.value = x 165 | assert self.value.dtype == cp.uint16 166 | 167 | @classmethod 168 | def zeros(cls, shape): 169 | return cls(zeros(shape + (4,))) 170 | 171 | @classmethod 172 | def append(cls, *args, axis=0): 173 | return cls(append(*(x.value for x in args), axis=axis)) 174 | 175 | @classmethod 176 | def sum(cls, arg, axis): 177 | adjusted_axis = axis if axis >= 0 else axis-1 178 | return cls(xor(arg.value, axis=adjusted_axis)) 179 | 180 | @property 181 | def shape(self): 182 | return self.value.shape[:-1] 183 | 184 | @property 185 | def limbs(self): 186 | return self.value.shape[-1] 187 | 188 | def reshape(self, shape): 189 | return ExtendedBinary(self.value.reshape(shape + self.shape[-1:])) 190 | 191 | def swapaxes(self, ax1, ax2): 192 | adjusted_ax1 = ax1 if ax1 >= 0 else ax1-1 193 | adjusted_ax2 = ax2 if ax2 >= 0 else ax2-1 194 | return ExtendedBinary(self.value.swapaxes(adjusted_ax1, adjusted_ax2)) 195 | 196 | def copy(self): 197 | return ExtendedBinary(cp.copy(self.value)) 198 | 199 | @property 200 | def ndim(self): 201 | return self.value.ndim - 1 202 | 203 | def to_extended(self): 204 | return self 205 | 206 | def __getitem__(self, index): 207 | return 
ExtendedBinary(self.value[index]) 208 | 209 | def __setitem__(self, index, value): 210 | if isinstance(value, int): 211 | self.value[index] = value 212 | elif isinstance(value, Binary): 213 | self.value[index] = value.to_extended(self.limbs).value 214 | elif isinstance(value, ExtendedBinary): 215 | self.value[index] = value.value 216 | else: 217 | raise Exception(f"Bad input for setitem: {value}") 218 | 219 | def __add__(self, other): 220 | raise Exception("Use xor with binary fields") 221 | 222 | def __neg__(self): 223 | raise Exception("x = -x in binary fields") 224 | 225 | def __sub__(self, other): 226 | raise Exception("Use xor with binary fields") 227 | 228 | def __xor__(self, other): 229 | if isinstance(other, Binary): 230 | other = other.to_extended(self.limbs) 231 | elif isinstance(other, int): 232 | other = Binary(other).to_extended(self.limbs) 233 | a, b = match_limbs(self.value, other.value) 234 | return ExtendedBinary(a ^ b) 235 | 236 | def __mul__(self, other): 237 | if isinstance(other, Binary): 238 | return ExtendedBinary(mul( 239 | self.value, 240 | other.value.reshape(other.value.shape + (1,)) 241 | )) 242 | elif isinstance(other, int): 243 | return ExtendedBinary(mul( 244 | self.value, 245 | cp.array(other, dtype=cp.uint16) 246 | )) 247 | a, b = match_limbs(self.value, other.value) 248 | return ExtendedBinary(big_mul(a, b)) 249 | 250 | def __pow__(self, other): 251 | assert isinstance(other, int) 252 | if other == 0: 253 | return ExtendedBinary(cp.ones(self.shape)) 254 | elif other == 1: 255 | return self 256 | elif other % 2 == 1: 257 | sub = self ** (other // 2) 258 | return sub * sub * self 259 | else: 260 | sub = self ** (other // 2) 261 | return sub * sub 262 | 263 | def inv(self): 264 | # Waaaaay under-optimized, todo fix 265 | power = 2**(16 * self.limbs) - 2 266 | return self ** power 267 | 268 | def __truediv__(self, other): 269 | if isinstance(other, int): 270 | other = Binary(other) 271 | return self * other.inv() 272 | 273 | def 
def mod31_py_obj(inp):
    """Reduce a Python int, or an arbitrarily nested iterable of ints, modulo ``modulus``.

    Ints are returned reduced; anything else is iterated and rebuilt as a
    list with the same nesting.
    """
    if not isinstance(inp, int):
        return [mod31_py_obj(item) for item in inp]
    return inp % modulus
cls(m31_sum(arg.value, axis=axis)) 43 | 44 | @property 45 | def shape(self): 46 | return self.value.shape 47 | 48 | def reshape(self, shape): 49 | return M31(self.value.reshape(shape)) 50 | 51 | def swapaxes(self, ax1, ax2): 52 | return M31(self.value.swapaxes(ax1, ax2)) 53 | 54 | def copy(self): 55 | return M31(np.copy(self.value)) 56 | 57 | @property 58 | def ndim(self): 59 | return self.value.ndim 60 | 61 | def to_extended(self): 62 | o = zeros(self.value.shape + (4,)) 63 | o[...,0] = self.value 64 | return ExtendedM31(o) 65 | 66 | def __getitem__(self, index): 67 | return M31(self.value[index]) 68 | 69 | def __setitem__(self, index, value): 70 | if isinstance(value, int): 71 | self.value[index] = value 72 | elif isinstance(value, M31): 73 | self.value[index] = value.value 74 | else: 75 | raise Exception(f"Bad input for setitem: {value}") 76 | 77 | def __add__(self, other): 78 | if isinstance(other, ExtendedM31): 79 | return self.to_extended() + other 80 | elif isinstance(other, int): 81 | other = M31(other) 82 | return M31(add(self.value, other.value)) 83 | 84 | def __neg__(self): 85 | return M31(modulus - self.value) 86 | 87 | def __sub__(self, other): 88 | if isinstance(other, ExtendedM31): 89 | return self.to_extended() - other 90 | elif isinstance(other, int): 91 | other = M31(other) 92 | return M31(sub(self.value, other.value)) 93 | 94 | def __mul__(self, other): 95 | if isinstance(other, ExtendedM31): 96 | return ExtendedM31(mul( 97 | self.value.reshape(self.value.shape + (1,)), 98 | other.value 99 | )) 100 | elif isinstance(other, int): 101 | other = M31(other) 102 | return M31(mul(self.value, other.value)) 103 | 104 | def __pow__(self, other): 105 | assert isinstance(other, int) 106 | if other == 5: 107 | # Optimize common special case 108 | return M31(pow5(self.value)) 109 | elif other == 0: 110 | return M31(np.ones(self.value.shape)) 111 | elif other == 1: 112 | return self 113 | elif other % 2 == 1: 114 | sub = self ** (other // 2) 115 | return sub * 
sub * self 116 | else: 117 | sub = self ** (other // 2) 118 | return sub * sub 119 | 120 | def inv(self): 121 | return M31(modinv(self.value)) 122 | 123 | def __truediv__(self, other): 124 | if isinstance(other, int): 125 | other = M31(other) 126 | return self * other.inv() 127 | 128 | def __rtruediv__(self, other): 129 | return self.inv() * other 130 | 131 | __radd__ = __add__ 132 | __rmul__ = __mul__ 133 | __rsub__ = lambda self, other: -(self - other) 134 | 135 | def __repr__(self): 136 | return f'{self.value}' 137 | 138 | def __int__(self): 139 | return int(self.value) 140 | 141 | def __len__(self): 142 | return len(self.value) 143 | 144 | def tobytes(self): 145 | return (self.value % modulus).tobytes() 146 | 147 | def __eq__(self, other): 148 | if isinstance(other, int): 149 | other = M31(other) 150 | elif isinstance(other, ExtendedM31): 151 | return self.to_extended() == other 152 | shape = np.broadcast_shapes(self.value.shape, other.value.shape) 153 | return np.array_equal( 154 | np.broadcast_to(self.value, shape) % modulus, 155 | np.broadcast_to(other.value, shape) % modulus 156 | ) 157 | 158 | class ExtendedM31(): 159 | def __init__(self, x): 160 | if isinstance(x, int): 161 | x = array([x % modulus, 0, 0, 0]) 162 | elif isinstance(x, list): 163 | x = array(mod31_py_obj(x)) 164 | elif isinstance(x, M31): 165 | x = x.to_extended().value 166 | elif isinstance(x, ExtendedM31): 167 | x = x.value 168 | assert x.shape[-1] == 4 169 | self.value = x 170 | assert self.value.dtype == np.uint32 171 | 172 | @classmethod 173 | def zeros(cls, shape): 174 | return cls(zeros(shape + (4,))) 175 | 176 | @classmethod 177 | def append(cls, *args, axis=0): 178 | return cls(append(*(x.value for x in args), axis=axis)) 179 | 180 | @classmethod 181 | def sum(cls, arg, axis): 182 | adjusted_axis = axis if axis >= 0 else axis-1 183 | return cls(m31_sum(arg.value, axis=adjusted_axis)) 184 | 185 | @property 186 | def shape(self): 187 | return self.value.shape[:-1] 188 | 189 | def 
reshape(self, shape): 190 | return ExtendedM31(self.value.reshape(shape + (4,))) 191 | 192 | def swapaxes(self, ax1, ax2): 193 | adjusted_ax1 = ax1 if ax1 >= 0 else ax1-1 194 | adjusted_ax2 = ax2 if ax2 >= 0 else ax2-1 195 | return ExtendedM31(self.value.swapaxes(adjusted_ax1, adjusted_ax2)) 196 | 197 | def copy(self): 198 | return ExtendedM31(np.copy(self.value)) 199 | 200 | @property 201 | def ndim(self): 202 | return self.value.ndim - 1 203 | 204 | def to_extended(self): 205 | return self 206 | 207 | def __getitem__(self, index): 208 | return ExtendedM31(self.value[index]) 209 | 210 | def __setitem__(self, index, value): 211 | if isinstance(value, int): 212 | self.value[index] = value 213 | elif isinstance(value, M31): 214 | self.value[index] = value.to_extended().value 215 | elif isinstance(value, ExtendedM31): 216 | self.value[index] = value.value 217 | else: 218 | raise Exception(f"Bad input for setitem: {value}") 219 | 220 | def __add__(self, other): 221 | if isinstance(other, M31): 222 | other = other.to_extended() 223 | elif isinstance(other, int): 224 | other = ExtendedM31(other) 225 | return ExtendedM31(add(self.value, other.value)) 226 | 227 | def __neg__(self): 228 | return ExtendedM31(modulus - self.value) 229 | 230 | def __sub__(self, other): 231 | if isinstance(other, M31): 232 | other = other.to_extended() 233 | elif isinstance(other, int): 234 | other = ExtendedM31(other) 235 | return ExtendedM31(sub(self.value, other.value)) 236 | 237 | def __mul__(self, other): 238 | if isinstance(other, int): 239 | other = M31(other) 240 | if isinstance(other, M31): 241 | return ExtendedM31(mul( 242 | self.value, 243 | other.value.reshape(other.value.shape + (1,)) 244 | )) 245 | return ExtendedM31(mul_ext(self.value, other.value)) 246 | 247 | def __pow__(self, other): 248 | assert isinstance(other, int) 249 | if other == 0: 250 | return M31(np.ones(self.shape)).to_extended() 251 | elif other == 1: 252 | return self 253 | elif other % 2 == 1: 254 | sub = self ** 
(other // 2) 255 | return sub * sub * self 256 | else: 257 | sub = self ** (other // 2) 258 | return sub * sub 259 | 260 | def inv(self): 261 | return ExtendedM31(modinv_ext(self.value)) 262 | 263 | def __truediv__(self, other): 264 | if isinstance(other, int): 265 | other = M31(other) 266 | return self * other.inv() 267 | 268 | def __rtruediv__(self, other): 269 | return self.inv() * other 270 | 271 | __radd__ = __add__ 272 | __rmul__ = __mul__ 273 | __rsub__ = lambda self, other: -(self - other) 274 | 275 | def __repr__(self): 276 | return f'{self.value}' 277 | 278 | def tobytes(self): 279 | return (self.value % modulus).tobytes() 280 | 281 | def __len__(self): 282 | return len(self.value) 283 | 284 | def __eq__(self, other): 285 | if isinstance(other, int): 286 | other = M31(other) 287 | if isinstance(other, M31): 288 | other = other.to_extended() 289 | shape = np.broadcast_shapes(self.value.shape, other.value.shape) 290 | return np.array_equal( 291 | np.broadcast_to(self.value, shape) % modulus, 292 | np.broadcast_to(other.value, shape) % modulus 293 | ) 294 | 295 | def matmul(a, b, assume_second_input_small=False): 296 | if not isinstance(a, (M31, ExtendedM31)): 297 | raise Exception("First input must be M31 or extended M31") 298 | elif not isinstance(b, M31): 299 | raise Exception("Second input must be M31") 300 | a_value = a.value if isinstance(a, M31) else a.value.swapaxes(-2, -1) 301 | if assume_second_input_small: 302 | data1 = np.matmul(a_value & 65535, b.value) 303 | data2 = np.matmul(a_value >> 16, b.value) 304 | o = add(data1, mul(data2, array(65536))) 305 | else: 306 | data1 = a_value.astype(np.uint64) 307 | data2 = b.value.astype(np.uint64) 308 | o1 = np.matmul(data1 & 65535, data2) 309 | o2 = np.matmul(data1 >> 16, data2) 310 | o = ((o1 + ((o2 % modulus) << 16)) % modulus).astype(np.uint32) 311 | if isinstance(a, M31): 312 | return M31(o) 313 | else: 314 | return ExtendedM31(o.swapaxes(-2, -1)) 315 | 
-------------------------------------------------------------------------------- /zorch/koalabear/koalabear_field.py: -------------------------------------------------------------------------------- 1 | from .koalabear_numpy_utils import ( 2 | zeros, arange, array, append, add, sub, mul, np, pow3, modinv, mul_ext, 3 | modinv_ext, modulus, sum as m31_sum 4 | ) 5 | 6 | def mod31_py_obj(inp): 7 | if isinstance(inp, int): 8 | return inp % modulus 9 | else: 10 | return [mod31_py_obj(x) for x in inp] 11 | 12 | class KoalaBear(): 13 | def __init__(self, x): 14 | if isinstance(x, (int, list)): 15 | x = array(mod31_py_obj(x)) 16 | elif isinstance(x, KoalaBear): 17 | x = x.value 18 | self.value = x 19 | assert self.value.dtype == np.uint32 20 | 21 | @classmethod 22 | def zeros(cls, shape): 23 | return cls(zeros(shape)) 24 | 25 | @classmethod 26 | def arange(cls, *args): 27 | return cls(arange(*args)) 28 | 29 | @classmethod 30 | def append(cls, *args, axis=0): 31 | return cls(append(*(x.value for x in args), axis=axis)) 32 | 33 | @classmethod 34 | def sum(cls, arg, axis): 35 | return cls(m31_sum(arg.value, axis=axis)) 36 | 37 | @property 38 | def shape(self): 39 | return self.value.shape 40 | 41 | def reshape(self, shape): 42 | return KoalaBear(self.value.reshape(shape)) 43 | 44 | def swapaxes(self, ax1, ax2): 45 | return KoalaBear(self.value.swapaxes(ax1, ax2)) 46 | 47 | def copy(self): 48 | return KoalaBear(np.copy(self.value)) 49 | 50 | @property 51 | def ndim(self): 52 | return self.value.ndim 53 | 54 | def to_extended(self): 55 | o = zeros(self.value.shape + (4,)) 56 | o[...,0] = self.value 57 | return ExtendedKoalaBear(o) 58 | 59 | def __getitem__(self, index): 60 | return KoalaBear(self.value[index]) 61 | 62 | def __setitem__(self, index, value): 63 | if isinstance(value, int): 64 | self.value[index] = value 65 | elif isinstance(value, KoalaBear): 66 | self.value[index] = value.value 67 | else: 68 | raise Exception(f"Bad input for setitem: {value}") 69 | 70 | def 
__add__(self, other): 71 | if isinstance(other, ExtendedKoalaBear): 72 | return self.to_extended() + other 73 | elif isinstance(other, int): 74 | other = KoalaBear(other) 75 | return KoalaBear(add(self.value, other.value)) 76 | 77 | def __neg__(self): 78 | return KoalaBear(modulus - self.value) 79 | 80 | def __sub__(self, other): 81 | if isinstance(other, ExtendedKoalaBear): 82 | return self.to_extended() - other 83 | elif isinstance(other, int): 84 | other = KoalaBear(other) 85 | return KoalaBear(sub(self.value, other.value)) 86 | 87 | def __mul__(self, other): 88 | if isinstance(other, ExtendedKoalaBear): 89 | return ExtendedKoalaBear(mul( 90 | self.value.reshape(self.value.shape + (1,)), 91 | other.value 92 | )) 93 | elif isinstance(other, int): 94 | other = KoalaBear(other) 95 | return KoalaBear(mul(self.value, other.value)) 96 | 97 | def __pow__(self, other): 98 | assert isinstance(other, int) 99 | if other == 3: 100 | # Optimize common special case 101 | return KoalaBear(pow3(self.value)) 102 | elif other == 0: 103 | return KoalaBear(np.ones(self.value.shape)) 104 | elif other == 1: 105 | return self 106 | elif other % 2 == 1: 107 | sub = self ** (other // 2) 108 | return sub * sub * self 109 | else: 110 | sub = self ** (other // 2) 111 | return sub * sub 112 | 113 | def inv(self): 114 | return KoalaBear(modinv(self.value)) 115 | 116 | def __truediv__(self, other): 117 | if isinstance(other, int): 118 | other = KoalaBear(other) 119 | return self * other.inv() 120 | 121 | def __rtruediv__(self, other): 122 | return self.inv() * other 123 | 124 | __radd__ = __add__ 125 | __rmul__ = __mul__ 126 | __rsub__ = lambda self, other: -(self - other) 127 | 128 | def __repr__(self): 129 | return f'{self.value}' 130 | 131 | def __int__(self): 132 | return int(self.value) 133 | 134 | def __len__(self): 135 | return len(self.value) 136 | 137 | def tobytes(self): 138 | return (self.value % modulus).tobytes() 139 | 140 | def __eq__(self, other): 141 | if isinstance(other, 
int): 142 | other = KoalaBear(other) 143 | elif isinstance(other, ExtendedKoalaBear): 144 | return self.to_extended() == other 145 | shape = np.broadcast_shapes(self.value.shape, other.value.shape) 146 | return np.array_equal( 147 | np.broadcast_to(self.value, shape) % modulus, 148 | np.broadcast_to(other.value, shape) % modulus 149 | ) 150 | 151 | class ExtendedKoalaBear(): 152 | def __init__(self, x): 153 | if isinstance(x, int): 154 | x = array([x % modulus, 0, 0, 0]) 155 | elif isinstance(x, list): 156 | x = array(mod31_py_obj(x)) 157 | elif isinstance(x, KoalaBear): 158 | x = x.to_extended().value 159 | elif isinstance(x, ExtendedKoalaBear): 160 | x = x.value 161 | assert x.shape[-1] == 4 162 | self.value = x 163 | assert self.value.dtype == np.uint32 164 | 165 | @classmethod 166 | def zeros(cls, shape): 167 | return cls(zeros(shape + (4,))) 168 | 169 | @classmethod 170 | def append(cls, *args, axis=0): 171 | return cls(append(*(x.value for x in args), axis=axis)) 172 | 173 | @classmethod 174 | def sum(cls, arg, axis): 175 | adjusted_axis = axis if axis >= 0 else axis-1 176 | return cls(m31_sum(arg.value, axis=adjusted_axis)) 177 | 178 | @property 179 | def shape(self): 180 | return self.value.shape[:-1] 181 | 182 | def reshape(self, shape): 183 | return ExtendedKoalaBear(self.value.reshape(shape + (4,))) 184 | 185 | def swapaxes(self, ax1, ax2): 186 | adjusted_ax1 = ax1 if ax1 >= 0 else ax1-1 187 | adjusted_ax2 = ax2 if ax2 >= 0 else ax2-1 188 | return ExtendedKoalaBear(self.value.swapaxes(adjusted_ax1, adjusted_ax2)) 189 | 190 | def copy(self): 191 | return ExtendedKoalaBear(np.copy(self.value)) 192 | 193 | @property 194 | def ndim(self): 195 | return self.value.ndim - 1 196 | 197 | def to_extended(self): 198 | return self 199 | 200 | def __getitem__(self, index): 201 | return ExtendedKoalaBear(self.value[index]) 202 | 203 | def __setitem__(self, index, value): 204 | if isinstance(value, int): 205 | self.value[index] = value 206 | elif isinstance(value, 
KoalaBear): 207 | self.value[index] = value.to_extended().value 208 | elif isinstance(value, ExtendedKoalaBear): 209 | self.value[index] = value.value 210 | else: 211 | raise Exception(f"Bad input for setitem: {value}") 212 | 213 | def __add__(self, other): 214 | if isinstance(other, KoalaBear): 215 | other = other.to_extended() 216 | elif isinstance(other, int): 217 | other = ExtendedKoalaBear(other) 218 | return ExtendedKoalaBear(add(self.value, other.value)) 219 | 220 | def __neg__(self): 221 | return ExtendedKoalaBear(modulus - self.value) 222 | 223 | def __sub__(self, other): 224 | if isinstance(other, KoalaBear): 225 | other = other.to_extended() 226 | elif isinstance(other, int): 227 | other = ExtendedKoalaBear(other) 228 | return ExtendedKoalaBear(sub(self.value, other.value)) 229 | 230 | def __mul__(self, other): 231 | if isinstance(other, int): 232 | other = KoalaBear(other) 233 | if isinstance(other, KoalaBear): 234 | return ExtendedKoalaBear(mul( 235 | self.value, 236 | other.value.reshape(other.value.shape + (1,)) 237 | )) 238 | return ExtendedKoalaBear(mul_ext(self.value, other.value)) 239 | 240 | def __pow__(self, other): 241 | assert isinstance(other, int) 242 | if other == 0: 243 | return KoalaBear(np.ones(self.shape)).to_extended() 244 | elif other == 1: 245 | return self 246 | elif other % 2 == 1: 247 | sub = self ** (other // 2) 248 | return sub * sub * self 249 | else: 250 | sub = self ** (other // 2) 251 | return sub * sub 252 | 253 | def inv(self): 254 | return ExtendedKoalaBear(modinv_ext(self.value)) 255 | 256 | def __truediv__(self, other): 257 | if isinstance(other, int): 258 | other = KoalaBear(other) 259 | return self * other.inv() 260 | 261 | def __rtruediv__(self, other): 262 | return self.inv() * other 263 | 264 | __radd__ = __add__ 265 | __rmul__ = __mul__ 266 | __rsub__ = lambda self, other: -(self - other) 267 | 268 | def __repr__(self): 269 | return f'{self.value}' 270 | 271 | def tobytes(self): 272 | return (self.value % 
def matmul(a, b, assume_second_input_small=False):
    """Matrix-multiply two field arrays, at most one of which is extended.

    Args:
        a: ``KoalaBear`` or ``ExtendedKoalaBear`` left operand.
        b: ``KoalaBear`` or ``ExtendedKoalaBear`` right operand.
        assume_second_input_small: when True, stay in 32-bit arithmetic by
            splitting ``a`` into 16-bit halves; the caller is responsible for
            ``b`` being small enough that the partial sums do not overflow.

    Returns:
        ``KoalaBear`` when both inputs are base-field, otherwise
        ``ExtendedKoalaBear``.

    Raises:
        Exception: on unsupported operand types or two extended operands.
    """
    if not isinstance(a, (KoalaBear, ExtendedKoalaBear)):
        raise Exception("First input must be KoalaBear or extended KoalaBear")
    if not isinstance(b, (KoalaBear, ExtendedKoalaBear)):
        raise Exception("Second input must be KoalaBear or extended KoalaBear")
    if isinstance(a, ExtendedKoalaBear) and isinstance(b, ExtendedKoalaBear):
        raise Exception("inputs cannot both be extended")

    # Extended left operand: move the limb axis in front of the contraction
    # axis so each limb is multiplied as its own row.
    a_value = a.value if isinstance(a, KoalaBear) else a.value.swapaxes(-2, -1)

    # Extended right operand with >= 2 logical dims: np.matmul would treat
    # (cols, 4) as the matrix axes, so fold the limbs into the column axis
    # and unfold afterwards.  A 1-D extended vector (raw shape (k, 4)) is
    # already a valid (k, 4) matrix and is left as is.
    b_value = b.value
    fold_limbs = isinstance(b, ExtendedKoalaBear) and b_value.ndim > 2
    if fold_limbs:
        cols = b_value.shape[-2]
        b_value = b_value.reshape(b_value.shape[:-2] + (cols * 4,))

    if assume_second_input_small:
        data1 = np.matmul(a_value & 65535, b_value)
        data2 = np.matmul(a_value >> 16, b_value)
        o = add(data1, mul(data2, array(65536)))
    else:
        # Split `a` into 16-bit halves so the uint64 products cannot overflow.
        data1 = a_value.astype(np.uint64)
        data2 = b_value.astype(np.uint64)
        o1 = np.matmul(data1 & 65535, data2)
        o2 = np.matmul(data1 >> 16, data2)
        o = ((o1 + ((o2 % modulus) << 16)) % modulus).astype(np.uint32)

    if fold_limbs:
        o = o.reshape(o.shape[:-1] + (cols, 4))

    if isinstance(a, ExtendedKoalaBear):
        return ExtendedKoalaBear(o.swapaxes(-2, -1))
    if isinstance(b, ExtendedKoalaBear):
        return ExtendedKoalaBear(o)
    return KoalaBear(o)