├── .gitignore ├── accumulator.py └── contracts └── MillerRabin.sol /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /accumulator.py: -------------------------------------------------------------------------------- 1 | from os import urandom 2 | from math import log2, ceil, floor 3 | from Crypto.Util import number 4 | from hashlib import sha256 5 | from typing import Tuple, Union, Iterable 6 | from collections import namedtuple 7 | from dataclasses import dataclass 8 | import operator 9 | 10 | # https://eprint.iacr.org/2018/1188.pdf 11 | 12 | RSA_2048 = 25195908475657893494027183240048398571429282126204032027777137836043662020707595556264018525880784406918290641249515082189298559149176184502808489120072844992687392807287776735971418347270261896375014971824691165077613379859095700097330459748808428401797429100642458691817195118746121515172654632282216869987549182422433637259085141865462043576798423387184774447920739934236584823824281198163815010674810451660377306056201619676256133844143603833904414952634432190114657544454178424020924616515723350778707749817125772467962926386356373289912154831438167899885040445364023527381951378636564391212010397122822120720357 13 | 14 | HASH_SIZE_BYTES = 32 15 | HASH_SIZE_BITS = HASH_SIZE_BYTES * 8 16 | HASH = sha256 17 | 18 | ADDITIVE_IDENTITY = 0 19 | MULTIPLICATIVE_IDENTITY = 1 20 | 21 | 22 | def egcd(a: int, b: int) -> Tuple[int, int, int]: 23 | u, u1 = 1, 0 24 | v, v1 = 0, 1 25 | g, g1 = a, b 26 | while g1: 27 | q = g // g1 28 | u, u1 = u1, u - q * u1 29 | v, v1 = v1, v - q * v1 30 | g, g1 = g1, g - q * g1 31 | return g, u, v 32 | 33 | 34 | def modinv(a: int, m: int) -> int: 35 | g, x, y = egcd(a, m) 36 | if g != 1: 37 | raise Exception('modular inverse does not exist') 38 | return x % m 39 | 40 | 41 | def number_bits_bytes_ceil(num: int, max_bits: int = None) -> Tuple[int, int]: 42 | n_bits = ceil(log2(num)) 43 | if max_bits is not None and n_bits > max_bits: 44 | n_bits = max_bits 45 | n_bytes = (n_bits + (8 - (n_bits % 8))) // 8 46 | return n_bits, n_bytes 47 | 48 | 49 | def number_bits_bytes_floor(num: int, max_bits: int = None) -> Tuple[int, int]: 50 | n_bits = floor(log2(num)) 51 | if max_bits is not None and n_bits > max_bits: 52 | n_bits = max_bits 53 | n_bytes = (n_bits - (n_bits % 8)) // 8 54 | return n_bits, n_bytes 55 | 56 | 57 | def random_prime(n_bits: int, easy_sqrt: bool = False) -> int: 58 | while True: 59 | n = number.getPrime(n_bits, urandom) 60 | if not easy_sqrt: 61 | return n 62 | # Congruency to 3 mod 4, means `sqrt(n) <- n^{{p+1}/4}` 63 | if n % 4 == 3: 64 | return n 65 | 66 | 67 | def random_rsa_modulus(n_bits: int, easy_sqrt: bool = False) -> int: 68 | p = random_prime(n_bits, easy_sqrt) 69 | q = random_prime(n_bits, easy_sqrt) 70 | return p * q 71 | 72 | 73 | def Bezout(x: int, y: int) -> Tuple[int, int]: 74 | _, a, b = egcd(x, y) 75 | return a, b 76 | 77 | 78 | def ShamirTrick(w1: int, w2: int, x: int, y: int) -> int: 79 | w1_x = pow(w1, x) 80 | w2_y = pow(w2, y) 81 | if w1_x == w2_y: 82 | raise RuntimeError("Equivalent") 83 | a, b = Bezout(x, y) 84 | return w1_x * w2_y 85 | 86 | 87 | def xor_bytes(a: bytes, b: bytes) -> bytes: 88 | assert len(a) == len(b) 89 | return bytes(x ^ y for x, y in zip(a, b)) 90 | 91 | 92 | class HashSet(object): 93 | """ 94 | Maintains a set of items, which can be succinctly represented as a single 32 byte value 95 | This is used to reduce the overhead of hashing all of the items each time you need that representation 96 | While still making it easy/fast to update the set without having to re-hash all of the items again 97 | e.g. turns an `O(n)` operation to re-hash the set into an `O(1)` operation. 98 | """ 99 | def __init__(self, items=None): 100 | self._guid = b'\0' * HASH_SIZE_BYTES 101 | if items: 102 | for item in items: 103 | self.add(item) 104 | 105 | @property 106 | def guid(self) -> bytes: 107 | return self._guid 108 | 109 | @classmethod 110 | def as_bytes(cls, item) -> bytes: 111 | if isinstance(item, bytes): 112 | return item 113 | elif isinstance(item, (int, GroupElement)): 114 | if isinstance(item, GroupElement): 115 | n_bits = ceil(log2(item.modulus)) 116 | value = item.value 117 | else: 118 | n_bits = ceil(log2(item)) 119 | value = item 120 | n_bits += 8 - (n_bits % 8) 121 | n_bytes = n_bits // 8 122 | return int.to_bytes(value, n_bytes, 'little') 123 | raise TypeError(item) 124 | 125 | @classmethod 126 | def hash_item(cls, item) -> bytes: 127 | item_bytes = cls.as_bytes(item) 128 | result = HASH(item_bytes).digest() 129 | return result 130 | 131 | def add(self, item) -> bytes: 132 | hashed_item = self.hash_item(item) 133 | self._guid = xor_bytes(self._guid, hashed_item) 134 | return hashed_item 135 | 136 | def remove(self, item) -> bytes: 137 | hashed_item = self.hash_item(item) 138 | self._guid = xor_bytes(self._guid, hashed_item) 139 | return hashed_item 140 | 141 | 142 | class PrimeSet(object): 143 | """ 144 | A set of prime numbers (derived from items), where the product of the primes 145 | is used as its representation. 146 | """ 147 | 148 | def __init__(self, items=None, n_bytes: int = HASH_SIZE_BYTES): 149 | self._n_bytes = n_bytes 150 | self._product = MULTIPLICATIVE_IDENTITY 151 | if items: 152 | for item in items: 153 | self.add(items) 154 | 155 | def __contains__(self, item): 156 | return (self._product % item) == 0 157 | 158 | def as_prime(self, item) -> int: 159 | """ 160 | Consistently convert 'raw elements' into prime numbers 161 | suitable for inclusion in an RSA accumulator. 162 | """ 163 | return hash2prime(item, output_bytes=self._n_bytes) 164 | 165 | @property 166 | def product(self) -> int: 167 | return self._product 168 | 169 | def add_prime(self, prime_item) -> int: 170 | assert prime_item not in self # ignored when optimised 171 | assert number.isPrime(prime_item) # ignored when optimised 172 | self._product *= prime_item 173 | return prime_item 174 | 175 | def add(self, item) -> int: 176 | return self.add_prime(self.as_prime(item)) 177 | 178 | def remove_prime(self, prime_item) -> int: 179 | assert prime_item in self # ignored when optimised 180 | assert number.isPrime(prime_item) # ignored when optimised 181 | self._product = self._product // prime_item 182 | return prime_item 183 | 184 | def remove(self, item) -> int: 185 | return self.remove_prime(self.as_prime(item)) 186 | 187 | 188 | class PrimeHashSet(object): 189 | def __init__(self, items=None, n_bytes=HASH_SIZE_BYTES): 190 | self._ps = PrimeSet(items, n_bytes=n_bytes) 191 | self._hs = HashSet(items) 192 | self._items = set() 193 | 194 | def __contains__(self, item): 195 | return item in self._items 196 | 197 | def as_prime(self, *args) -> int: 198 | return self._ps.as_prime(*args) 199 | 200 | def __iter__(self): 201 | return iter(self._items) 202 | 203 | @property 204 | def guid(self) -> bytes: 205 | return self._hs.guid 206 | 207 | @property 208 | def product(self) -> int: 209 | return self._ps.product 210 | 211 | def add(self, item) -> int: 212 | prime_item = self._ps.add(item) 213 | self._hs.add(item) 214 | self._items.add(item) 215 | return prime_item 216 | 217 | def remove(self, item) -> int: 218 | prime_item = self._ps.remove(item) 219 | self._hs.remove(item) 220 | self._items.remove(prime_item) 221 | return prime_item 222 | 223 | 224 | class GroupElement(object): 225 | def __init__(self, value: int, modulus: int): 226 | self.value = value 227 | self.modulus = modulus 228 | 229 | @classmethod 230 | def as_int(cls, other: Union[int, '__class__']): 231 | if isinstance(other, cls): 232 | return other.value 233 | assert isinstance(other, int) 234 | return other 235 | 236 | def hash_to_group(self, *args) -> '__class__': 237 | new_value = hash2intmod(self.modulus, *args) 238 | return __class__(new_value, self.modulus) 239 | 240 | @classmethod 241 | def generator(cls, modulus) -> '__class__': 242 | return cls(2, modulus) 243 | 244 | def as_bytes(self, endian='little') -> bytes: 245 | n_bits, n_bytes = number_bits_bytes_ceil(self.modulus) 246 | return int.to_bytes(self.value, n_bytes, endian) 247 | 248 | def __repr__(self): 249 | group_name = "RSA-2048" if self.modulus == RSA_2048 else str(self.modulus) 250 | return "%s<%d mod %s>" % (type(self).__name__, self.value, group_name) 251 | 252 | def __div__(self, other: Union[int, '__class__']) -> '__class__': 253 | value = self.as_int(other) 254 | base = modinv(self.value, self.modulus) 255 | result = pow(base, value, self.modulus) 256 | return __class__(result, modulus) 257 | 258 | def __pow__(self, other: Union[int, '__class__'], modulus: int = None) -> '__class__': 259 | if modulus is not None: 260 | if modulus != self.modulus: 261 | raise RuntimeError("Cannot exponentiate using different modulus") 262 | exponent = self.as_int(other) 263 | 264 | # Allow for negative exponentiation, exponent our inverse by its absolute value 265 | # e.g. (x ** -10) == (1/x) ** 10 266 | if exponent < 0: 267 | base = modinv(self.value, self.modulus) 268 | exponent = -exponent 269 | else: 270 | base = self.value 271 | 272 | result = pow(base, exponent, self.modulus) 273 | return __class__(result, self.modulus) 274 | 275 | def __int__(self): 276 | return self.value 277 | 278 | def __mul__(self, other: Union[int, '__class__']): 279 | value = self.as_int(other) 280 | return __class__((self.value * value) % self.modulus, self.modulus) 281 | 282 | def __add__(self, other: Union[int, '__class__']): 283 | value = self.as_int(other) 284 | return __class__((self.value + value) % self.modulus, self.modulus) 285 | 286 | def __hash__(self): 287 | return hash((type(self), self.value, self.modulus)) 288 | 289 | def __eq__(self, other): 290 | if isinstance(other, int): 291 | return self.value == other 292 | elif isinstance(other, type(self)): 293 | return self.value == other.value and self.modulus == other.modulus 294 | raise TypeError("Cannot compare %r with %r" % (self, other)) 295 | 296 | 297 | def as_bytes(*all_items): 298 | """ 299 | Convert all items into a concatenated sequence of bytes 300 | """ 301 | result = [] 302 | for item in all_items: 303 | # Convert item into hashable bytes 304 | if isinstance(item, GroupElement): 305 | item = item.as_bytes() 306 | elif isinstance(item, int): 307 | _, n_bytes = number_bits_bytes_ceil(item) 308 | item = int.to_bytes(item, n_bytes, 'little') 309 | elif not isinstance(item, bytes): 310 | raise TypeError(item) 311 | result.append(item) 312 | return b''.join(result) 313 | 314 | 315 | def hash2bytes(*items: Union[GroupElement, int, bytes], n_bytes: int = HASH_SIZE_BYTES): 316 | """ 317 | Convert input items to bytes, concatenate them together, then hash 318 | Return exactly `n_bytes` worth of hash, truncating the hash and 319 | extending it by iteratively hashing if more bytes are needed. 320 | """ 321 | item = as_bytes(*items) 322 | result = hashed_item = HASH(item).digest() 323 | while len(result) < n_bytes: 324 | hashed_item = HASH(hashed_item).digest() 325 | result += hashed_item 326 | return result[:n_bytes] 327 | 328 | 329 | def hash2int(*args, **kwargs): 330 | return int.from_bytes(hash2bytes(*args, **kwargs), 'little') 331 | 332 | 333 | def hash2intmod(modulus : int, *args): 334 | _, n_bytes = number_bits_bytes_ceil(modulus) 335 | return hash2int(*args, n_bytes=n_bytes) % modulus 336 | 337 | 338 | def hash2prime(*items: Union[GroupElement, int, bytes], output_bytes: int = HASH_SIZE_BYTES): 339 | item = as_bytes(*items) 340 | 341 | while True: 342 | item = HASH(item).digest() 343 | i = int.from_bytes(item[:output_bytes], 'little') 344 | if number.isPrime(i): 345 | return i 346 | 347 | 348 | class NI_PoE(object): 349 | """ 350 | Non-Interactive Proofs of Exponentiation (NI-PoE). See BBF'18 (pages 8 and 42) for details. 351 | """ 352 | 353 | @classmethod 354 | def prove(cls, base: GroupElement, exp: int, result: GroupElement) -> GroupElement: 355 | l = hash2prime(base, exp, result) 356 | q, _ = divmod(exp, l) 357 | return base ** q 358 | 359 | @classmethod 360 | def verify(cls, base: GroupElement, exp: int, result: GroupElement, proof: GroupElement) -> bool: 361 | l = hash2prime(base, exp, result) 362 | r = exp % l 363 | w = (proof ** l) * (base ** r) 364 | return w == result 365 | 366 | 367 | @dataclass 368 | class NI_PoKE2: 369 | """ 370 | Non-Interactive Proofs of Knowledge of Exponent (NI-PoKE2). 371 | 372 | See BBF'18: 373 | 374 | - pg 10 (§3.2, "Extending PoKE for general bases") 375 | - pg 42 (appendix D) 376 | """ 377 | 378 | z: GroupElement 379 | Q: GroupElement 380 | r: int 381 | 382 | def __iter__(self) -> Tuple[GroupElement, GroupElement, int]: 383 | return iter((self.z, self.Q, self.r)) 384 | 385 | @classmethod 386 | def prove(cls, u: GroupElement, x: int, w: GroupElement): 387 | # u = base 388 | # x = exponent 389 | # w = result 390 | # XXX: cambrian implementation differs with `g` 391 | # paper specifies hash to group, cambrian/accumulator uses generator (2) 392 | g = u.hash_to_group(u, w) 393 | z = g ** x 394 | l = hash2prime(u, w, z) 395 | alpha = hash2int(u, w, z, l) # XXX: paper doesn't specify how big \alpha is 396 | q, r = divmod(x, l) 397 | Q = (u * (g ** alpha)) ** q # (ug^a)^q 398 | return __class__(z, Q, r) 399 | 400 | @classmethod 401 | def verify(cls, u: GroupElement, w: GroupElement, proof: '__class__'): 402 | # u = base 403 | # w = result 404 | z, Q, r = proof 405 | g = u.hash_to_group(u, w) 406 | l = hash2prime(u, w, z) 407 | alpha = hash2int(u, w, z, l) 408 | rhs = w * (z ** alpha) 409 | beta = (u * (g ** alpha)) ** r # ug^a 410 | lhs = (Q ** l) * beta # Q^l(ug^a)^r 411 | return lhs == rhs 412 | 413 | 414 | @dataclass 415 | class NonMembershipProof: 416 | item: int 417 | a: int 418 | B: GroupElement 419 | 420 | def __iter__(self): 421 | return iter((self.item, self.a, self.B)) 422 | 423 | 424 | @dataclass 425 | class MembershipProof: 426 | item: int 427 | witness: GroupElement 428 | 429 | def __iter__(self): 430 | return iter((self.item, self.witness)) 431 | 432 | 433 | class Accumulator(object): 434 | def __init__(self, modulus: int): 435 | self.p_bits, self.p_bytes = number_bits_bytes_floor(modulus, HASH_SIZE_BITS) 436 | self.value = self.generator = GroupElement.generator(modulus) 437 | self._items = PrimeHashSet(n_bytes=self.p_bytes) 438 | 439 | def as_prime(self, *args): 440 | return self._items.as_prime(*args) 441 | 442 | @property 443 | def guid(self) -> bytes: 444 | return self._items.guid 445 | 446 | @property 447 | def product(self) -> int: 448 | return self._items.product 449 | 450 | def __iter__(self): 451 | return iter(self._items) 452 | 453 | def __contains__(self, item): 454 | return item in self._items 455 | 456 | def add_prime(self, prime_item: int) -> int: 457 | # We assume primality of item is guaranteed 458 | old_value = self.value 459 | self.value = self.value ** prime_item 460 | return old_value 461 | 462 | def add(self, item) -> int: 463 | if isinstance(item, type(self)): 464 | item = item.value 465 | return self.add_prime(self._items.add(item)) 466 | 467 | def DelWMem(self, proof: MembershipProof) -> int: 468 | """ 469 | BBF'18 pg 16 § 4.2 figure 2 470 | 471 | Remove an item, by providing its witness 472 | 473 | - The witness is the previous value of the accumulator 474 | - Thus, we verify the witness for the item (proof of inclusion) 475 | - Then the new value for the accumulator becomes the witness 476 | """ 477 | item, witness = proof 478 | if not self.VerMem(item, witness): 479 | raise KeyError("invalid witness") 480 | if len(self.items) == 1: 481 | if witness != int(self.generator): 482 | raise RuntimeError("With 1 item in set, witness must be generator!") 483 | prime_item = self._items.remove(item) 484 | self.value = witness 485 | return prime_item 486 | 487 | def VerMem(self, item, witness: GroupElement) -> bool: 488 | prime_item = self._items.as_prime(item) 489 | result = witness ** prime_item 490 | return result == self.value 491 | 492 | def BatchAdd(self, items: Iterable[int]): 493 | """ 494 | BBF'18 pg 16 § 4.2 figure 2 495 | 496 | Perform a batch-update of the accumulator 497 | 498 | - Returns old value, and proof of the exponentation 499 | """ 500 | x_star = reduce(operator.mul, items) 501 | old_value = self.value 502 | self.value = self.value ** x_star 503 | return old_value, NI_PoE.prove(old_value, x_star, self.value) 504 | 505 | def BatchDel(self, items: Iterable[MembershipProof]): 506 | """ 507 | BBF'18 pg 16 § 4.2 figure 2 508 | 509 | Given witnesses for many items, delete them all from the accumulator 510 | 511 | - For each of the items, we verify their inclusion. 512 | """ 513 | items = iter(items) 514 | 515 | proof = next(items) 516 | item, witness = proof 517 | if not self.VerMem(item, witness): 518 | raise RuntimeError("Invalid witness %r" % (proof,)) 519 | 520 | new_value = old_value = self.value 521 | # new_value = A_{t+1} 522 | 523 | for proof in items: 524 | x_i, witness = proof 525 | if not self.VerMem(x_i, witness): 526 | raise RuntimeError("Invalid witness %r" % (proof,)) 527 | new_value = ShamirTrick(new_value, witness, x_star, x_i) 528 | x_star *= x_i 529 | 530 | return new_value, NI_PoE.prove(old_value, x_star, new_value) 531 | 532 | def MemWitCreate(self, our_item) -> GroupElement: 533 | """ 534 | Provide a witness for an item which exists within the accumulator 535 | This is, essentially, the accumulator but without our item 536 | """ 537 | assert our_item in self 538 | witness = self.generator 539 | for item in self: 540 | if our_item == item: 541 | continue 542 | prime_item = self._items.as_prime(item) 543 | witness = witness ** prime_item 544 | return witness 545 | 546 | def NonMemWitCreateFast_cambrial(self, our_item): 547 | """ 548 | Translated from cambrian/accumulator::`prove_nonmembership` 549 | """ 550 | # proof = self.NonMemWitCreate(our_item) 551 | x = self.as_prime(our_item) 552 | gcd, a, b = egcd(self.product, x) 553 | # assert gcd != 1 554 | g = self.generator 555 | d = g ** a 556 | v = self.value ** b 557 | g_inv = g / v 558 | poke2_proof = NI_PoKE2.prove(self.value, b, v) 559 | poe_proof = NI_PoE.prove(d, x, g_inv) 560 | return (d, v, g_inv, poke2_proof, poe_proof) 561 | 562 | def NonMemWitCreateFast_paper(self, s_star, x_star): 563 | """ 564 | BBF'18 pg 16 `NonMemWitCreate*` 565 | """ 566 | a, b = Bezout(s_star, x_star) 567 | V = self.value ** a 568 | B = self.generator ** b 569 | pi_V = NI_PoKE2.prove(self.value, a, V) # V = A^a 570 | pi_g = NI_PoE.prove(x_star, B, self.generator / V) # B^x = g * (V^-1) 571 | # Is this an equivalent proof satisfiable for `VerNonMem` 572 | return (V, B, pi_V, pi_g) 573 | 574 | def NonMemWitCreate(self, our_item) -> NonMembershipProof: 575 | x = self.as_prime(our_item) 576 | g, a, b = egcd(self.product, x) 577 | """ 578 | # XXX: we allow non-membership proofs to be created 579 | # this shows than even invalid non-membership proofs fail to validate 580 | if g != 1: 581 | raise RuntimeError("Inputs not co-prime") 582 | """ 583 | B = self.generator ** b 584 | return NonMembershipProof(our_item, a, B) 585 | 586 | def VerNonMem(self, proof: NonMembershipProof): 587 | our_item, a, B = proof 588 | x = self.as_prime(our_item) 589 | c = self.value ** a # A^a 590 | d = B ** x # B^x 591 | return (c * d) == self.generator # (A^a)(B^x) == g 592 | 593 | 594 | # -------------------- 595 | 596 | 597 | if __name__ == "__main__": 598 | import sys 599 | """ 600 | P = 7919 601 | Q = 7907 602 | N = P * Q 603 | Phi = (P-1)*(Q-1) 604 | 605 | A = Accumulator(N) 606 | our_item = 10 607 | A.add(our_item) 608 | found = 0 609 | for i, x in enumerate(range(2,N-1)): 610 | wit = A.NonMemWitCreate(x) 611 | result = A.VerNonMem(wit, x) 612 | if result: 613 | found += 1 614 | elif x != our_item: 615 | print('.', log2(Phi), 1/log2(Phi), found/i, A.as_prime(x), A.as_prime(our_item), wit) 616 | sys.exit(1) 617 | """ 618 | 619 | # PoE test 620 | exp = 20 621 | base = GroupElement(2, RSA_2048) 622 | result = GroupElement(1048576, RSA_2048) 623 | proof = NI_PoE.prove(base, exp, result) 624 | assert True == NI_PoE.verify(base, exp, result, proof) 625 | 626 | # PoKE2 test 627 | proof = NI_PoKE2.prove(base, exp, result) 628 | assert True == NI_PoKE2.verify(base, result, proof) 629 | 630 | left_acc = Accumulator(RSA_2048) 631 | left_acc.add(1) 632 | left_acc.add(2) 633 | left_wit = left_acc.add(left_acc) 634 | left_a = left_acc.MemWitCreate(1) 635 | left_b = left_acc.MemWitCreate(2) 636 | 637 | # Interestingly, the witness for the accumulator is it's self... 638 | assert left_acc.VerMem(left_wit, left_wit) 639 | assert left_acc.VerMem(1, left_a) 640 | assert left_acc.VerMem(2, left_b) 641 | 642 | left_exc = left_acc.NonMemWitCreate(3) 643 | assert True == left_acc.VerNonMem(left_exc) 644 | 645 | left_exc = left_acc.NonMemWitCreate(1) 646 | assert False == left_acc.VerNonMem(left_exc) 647 | 648 | 649 | # -------------------- 650 | 651 | 652 | right_acc = Accumulator(RSA_2048) 653 | right_acc.add(3) 654 | right_acc.add(4) 655 | right_wit = right_acc.add(right_acc) 656 | right_a = right_acc.MemWitCreate(3) 657 | right_b = right_acc.MemWitCreate(4) 658 | 659 | assert right_acc.VerMem(right_wit, right_wit) 660 | assert right_acc.VerMem(3, right_a) 661 | assert right_acc.VerMem(4, right_b) 662 | 663 | 664 | # -------------------- 665 | 666 | 667 | root_acc = Accumulator(RSA_2048) 668 | root_acc.add(left_acc) 669 | root_acc.add(right_acc) 670 | 671 | # TODO: prove 'left' exists in 'root' 672 | # TODO: prove 'right' exists in 'root' 673 | 674 | -------------------------------------------------------------------------------- /contracts/MillerRabin.sol: -------------------------------------------------------------------------------- 1 | pragma solidity ^0.5.0; 2 | 3 | contract MillerRabin 4 | { 5 | function modexp_rsa2048(uint256[8] memory b, uint256 e) 6 | public view returns(uint256[8] memory result) 7 | { 8 | bool success; 9 | assembly { 10 | let freemem := mload(0x40) 11 | 12 | // Length of base, exponent and modulus 13 | mstore(freemem, 0x100) // base (2048) 14 | mstore(add(freemem,0x20), 0x20) // exponent (256) 15 | mstore(add(freemem,0x40), 0x100) // modulus (2048) 16 | 17 | // 2048bit base 18 | success := staticcall(sub(gas, 2000), 4, b, 0x100, add(freemem,0x60), 0x100) 19 | 20 | // 256bit exponent 21 | mstore(add(freemem,0x160), e) 22 | 23 | // Hard-coded RSA-2048 modulus 24 | mstore(add(freemem,0x180), 0xc7970ceedcc3b0754490201a7aa613cd73911081c790f5f1a8726f463550bb5b) 25 | mstore(add(freemem,0x1A0), 0x7ff0db8e1ea1189ec72f93d1650011bd721aeeacc2acde32a04107f0648c2813) 26 | mstore(add(freemem,0x1C0), 0xa31f5b0b7765ff8b44b4b6ffc93384b646eb09c7cf5e8592d40ea33c80039f35) 27 | mstore(add(freemem,0x1E0), 0xb4f14a04b51f7bfd781be4d1673164ba8eb991c2c4d730bbbe35f592bdef524a) 28 | mstore(add(freemem,0x200), 0xf7e8daefd26c66fc02c479af89d64d373f442709439de66ceb955f3ea37d5159) 29 | mstore(add(freemem,0x220), 0xf6135809f85334b5cb1813addc80cd05609f10ac6a95ad65872c909525bdad32) 30 | mstore(add(freemem,0x240), 0xbc729592642920f24c61dc5b3c3b7923e56b16a4d9d373d8721f24a3fc0f1b31) 31 | mstore(add(freemem,0x260), 0x31f55615172866bccc30f95054c824e733a5eb6817f7bc16399d48c6361cc7e5) 32 | 33 | success := staticcall(sub(gas, 2000), 0x0000000000000000000000000000000000000005, freemem, 0x280, result, 0x100) 34 | } 35 | require( success ); 36 | } 37 | 38 | function modexp(uint256 b, uint256 e, uint256 m) 39 | public view returns(uint256 result) 40 | { 41 | bool success; 42 | assembly { 43 | let freemem := mload(0x40) 44 | mstore(freemem, 0x20) 45 | mstore(add(freemem,0x20), 0x20) 46 | mstore(add(freemem,0x40), 0x20) 47 | mstore(add(freemem,0x60), b) 48 | mstore(add(freemem,0x80), e) 49 | mstore(add(freemem,0xA0), m) 50 | success := staticcall(sub(gas, 2000), 5, freemem, 0xC0, freemem, 0x20) 51 | result := mload(freemem) 52 | } 53 | require(success); 54 | } 55 | 56 | function IsPrime(uint256 n, uint32 k, uint256 entropy) 57 | public view returns (bool) 58 | { 59 | if(n == 2) 60 | return true; 61 | 62 | if( n < 2 || n % 2 == 0 ) 63 | return false; 64 | 65 | uint256 d = n - 1; 66 | uint256 s = 0; 67 | 68 | while( d % 2 == 0 ) { 69 | d = d / 2; 70 | s += 1; 71 | } 72 | 73 | while( k-- != 0 ) { 74 | // XXX: This is supposed to replace `rand(2, n-1)` but has small probabilitiy of incorrect value 75 | entropy = uint256(keccak256(abi.encodePacked(entropy, n))) % n; 76 | 77 | uint256 x = modexp(entropy, d, n); 78 | if (x == 1 || x == n-1) 79 | continue; 80 | 81 | bool ok = false; 82 | 83 | for( uint j = 1; j < s; j++ ) { 84 | x = mulmod(x, x, n); 85 | if( x == 1 ) 86 | return false; 87 | if(x == n-1) { 88 | ok = true; 89 | break; 90 | } 91 | } 92 | if ( false == ok ) { 93 | return false; 94 | } 95 | } 96 | return true; 97 | } 98 | } --------------------------------------------------------------------------------