# Repository contents (flattened dump):
#   README.md, ntt.py, generate_prime.py, helper.py, poly.py, BFV_demo.py, BFV.py
#
# ----------------------------------------------------------------------------
# /README.md
# ----------------------------------------------------------------------------
# # bfv-python
# Simple Python implementation of the Brakerski/Fan-Vercauteren (BFV)
# homomorphic encryption scheme, following the definitions in the paper
# (https://eprint.iacr.org/2012/144.pdf). The implementation is for
# educational purposes and is not optimized for performance.
#
# ----------------------------------------------------------------------------
# /ntt.py
# ----------------------------------------------------------------------------
import math
from helper import *

# --- NTT functions ---
#
# Iterative NTT (forward and inverse); standard-order input and output.
#   A       = input coefficient list (length must be a power of two)
#   W_table = powers of the root of unity: W_table[k] == W**k mod q
#             (table of W for the forward transform, of W**-1 for the inverse)
#   q       = coefficient modulus


def _ntt_butterflies(A, W_table, q):
    """Run the butterfly network shared by NTT and INTT; return a new list.

    The input list is not modified. The butterflies produce bit-reversed
    output, which is permuted back to standard order before returning.

    Raises:
        ValueError: if len(A) is not a power of two.
    """
    n = len(A)
    if n & (n - 1):
        raise ValueError("NTT length must be a power of two, got {}".format(n))
    # Exact integer log2; int(math.log(n, 2)) goes through floating point.
    v = n.bit_length() - 1
    B = list(A)

    for i in range(v):
        half = 1 << (v - i - 1)          # distance between butterfly partners
        for j in range(1 << i):
            base = j << (v - i)          # first index of the j-th group
            for k in range(half):
                s = base + k
                t = s + half
                w = W_table[k << i]      # == W ** ((2 ** i) * k) mod q
                x, y = B[s], B[t]
                B[s] = (x + y) % q
                B[t] = ((x - y) * w) % q

    return indexReverse(B, v)


def NTT(A, W_table, q):
    """Forward NTT of A modulo q using the power table W_table of W."""
    return _ntt_butterflies(A, W_table, q)


def INTT(A, W_table, q):
    """Inverse NTT of A modulo q.

    W_table must hold the powers of W**-1. The butterfly output is scaled
    by n**-1 mod q.
    """
    B = _ntt_butterflies(A, W_table, q)
    if not B:
        return B
    n_inv = modinv(len(B), q)
    return [(x * n_inv) % q for x in B]


# ----------------------------------------------------------------------------
# /generate_prime.py
# ----------------------------------------------------------------------------
# Copyright 2015 Pedro Alves

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import math
import sys

# All odd primes below 1000. Trial division by these removes most composites
# before the (more expensive) Miller-Rabin test.
_LOW_PRIMES = (
    3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
    79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157,
    163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239,
    241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331,
    337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421,
    431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509,
    521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613,
    617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709,
    719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821,
    823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919,
    929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997,
)


def miller_rabin(p, s=11):
    """Probabilistic Miller-Rabin primality test with s random bases.

    Returns False if p is certainly composite, True if p is probably prime.
    Bases come from the non-cryptographic `random` module.
    """
    if p < 2:
        return False
    if p < 4:
        return True              # 2 and 3 (randrange(2, p-1) would be empty)
    if p & 1 == 0:
        return False

    # Decompose p - 1 = 2**u * r with r odd. Integer shift, not int(r/2):
    # float division loses precision once r exceeds 2**53.
    r = p - 1
    u = 0
    while r & 1 == 0:
        u += 1
        r >>= 1

    for _ in range(s):
        a = random.randrange(2, p - 1)   # random a in {2, 3, ..., p-2}
        z = pow(a, r, p)
        if z != 1 and z != p - 1:
            for _ in range(u - 1):
                if z == p - 1:
                    break
                z = pow(z, 2, p)
                if z == 1:
                    return False
            if z != p - 1:
                return False
    return True


def is_prime(n, s=11):
    """Return True if n is (probably) prime.

    Uses trial division by the odd primes below 1000, then Miller-Rabin
    with s rounds.
    """
    if n < 2:
        return False
    if n == 2:
        return True
    if n & 1 == 0:
        return False
    for p in _LOW_PRIMES:
        if n == p:
            return True
        if n % p == 0:
            return False
    return miller_rabin(n, s)


def generate_large_prime(k, s=11):
    """Return a random (probable) prime with exactly k bits.

    Candidates come from the Mersenne-Twister `random` module, which is
    deterministic and unsuitable for serious cryptographic use.

    Raises:
        Exception: if no prime is found within the attempt budget.
    """
    # With s=11 Miller-Rabin rounds the error probability is below 2**-80.
    max_attempts = int(100 * (math.log(k, 2) + 1))
    for _ in range(max_attempts):
        n = random.randrange(2 ** (k - 1), 2 ** k)
        if is_prime(n, s):
            return n
    raise Exception("Failure after %d tries." % max_attempts)


# ----------------------------------------------------------------------------
# /helper.py
# ----------------------------------------------------------------------------

from generate_prime import *
from random import randint


def egcd(a, b):
    """Extended Euclid: return (g, x, y) with g = gcd(a, b) and a*x + b*y == g."""
    if a == 0:
        return (b, 0, 1)
    g, y, x = egcd(b % a, a)
    return (g, x - (b // a) * y, y)


def modinv(a, m):
    """Return the inverse of a modulo m, in [0, m).

    Raises:
        Exception: if gcd(a, m) != 1.
    """
    g, x, _ = egcd(a, m)
    if g != 1:
        raise Exception('Modular inverse does not exist')
    return x % m


def gcd(n1, n2):
    """Greatest common divisor of two integers (Euclid's algorithm)."""
    a, b = n1, n2
    while b != 0:
        a, b = b, a % b
    return a


def intReverse(a, n):
    """Reverse the n-bit binary representation of a (zero-padded to n bits)."""
    return int('{:0{}b}'.format(a, n)[::-1], 2)


def indexReverse(a, r):
    """Return a new list with a[i] moved to the bit-reversed (r bits) index of i."""
    b = [0] * len(a)
    for i, value in enumerate(a):
        b[intReverse(i, r)] = value
    return b


def RefPolMulv2(A, B):
    """Reference negacyclic product A*B in Z[x]/(x^n + 1), n = len(A), no modulus.

    Schoolbook O(n^2) multiplication; assumes len(B) == len(A).
    """
    n = len(A)
    C = [0] * (2 * n)
    for indexA, elemA in enumerate(A):
        for indexB, elemB in enumerate(B):
            C[indexA + indexB] += elemA * elemB
    # x^n == -1: fold the upper half back with a sign flip.
    return [C[i] - C[i + n] for i in range(n)]


def RefPolMul(A, B, M):
    """Reference negacyclic product A*B in Z_M[x]/(x^n + 1); coefficients in [0, M)."""
    return [x % M for x in RefPolMulv2(A, B)]


def isrootofunity(w, m, q):
    """Check w**(m/2) == -1 (mod q); for m a power of two this means w is a
    primitive m-th root of unity mod q (m is n or 2n in this code base)."""
    if w == 0:
        return False
    return pow(w, m // 2, q) == (q - 1)


def GetProperPrime(n, logq):
    """Return an NTT-friendly prime q with q == 1 (mod 2n) and 2**(logq-1) < q < 2**logq.

    Candidates are scanned downward from 2**logq - 2n + 1 in steps of 2n, so
    (assuming 2n divides 2**logq) the first prime found is the largest such
    prime.

    Raises:
        Exception: if no candidate above 2**(logq-1) is prime.
    """
    # NOTE(review): this body was garbled in the source dump
    # ("value = (1< lbound):"); reconstructed from the surrounding lines.
    factor = 2 * n
    value = (1 << logq) - factor + 1
    lbound = 1 << (logq - 1)
    while value > lbound:
        if is_prime(value):
            return value
        value -= factor
    raise Exception("Failed to find a proper prime.")


def FindPrimitiveRoot(m, q):
    """Try to find a primitive m-th root of unity modulo q.

    Returns:
        (True, root) on success, or (False, 0) when m does not divide q - 1
        or no root was found within 100 random attempts. Both failure paths
        now return a tuple, so callers can always unpack the result.
    """
    g = (q - 1) // m
    if (q - 1) != g * m:
        return False, 0

    for _ in range(100):
        a = randint(2, q - 1)
        b = pow(a, g, q)          # b has order dividing m
        if isrootofunity(b, m, q):
            return True, b
    return False, 0


def ParamGen(n, logq):
    """Generate BFV/NTT parameters for ring size n and a logq-bit modulus.

    Returns:
        (q, psi, psiv, w, wv): prime q, primitive 2n-th root psi, its inverse,
        w = psi**2 mod q and its inverse.
    """
    # GetProperPrime is deterministic, so it only needs to be called once.
    q = GetProperPrime(n, logq)
    pfound, psi = False, 0
    while not pfound:
        pfound, psi = FindPrimitiveRoot(2 * n, q)
    # Only invert once a root was actually found (modinv(0, q) would raise).
    psiv = modinv(psi, q)
    w = pow(psi, 2, q)
    wv = modinv(w, q)
    return q, psi, psiv, w, wv


# ----------------------------------------------------------------------------
# /poly.py
# ----------------------------------------------------------------------------
# Poly

from random import randint, gauss
from ntt import *


class Poly:
    """Polynomial in Z_q[x]/(x^n + 1) stored as a coefficient list.

    Attributes:
        n     -- ring size (number of coefficients)
        q     -- coefficient modulus
        np    -- NTT tables [w_table, wv_table, psi_table, psiv_table]: powers
                 of w, w**-1, psi, psi**-1 mod q indexed by exponent. They are
                 needed for coefficient-domain multiplication and toNTT/toPOL.
        F     -- coefficient list of length n
        inNTT -- True if F holds NTT-domain values, False for coefficient form
    """

    def __init__(self, n, q, np=None):
        self.n = n
        self.q = q
        # Fresh placeholder per instance instead of a shared mutable default.
        self.np = [0, 0, 0, 0] if np is None else np
        self.F = [0] * n
        self.inNTT = False

    def randomize(self, B, domain=False, type=0, mu=0, sigma=0):
        """Fill F with random coefficients reduced mod q.

        type 0: uniform integers in [-(B//2), B//2]
        type 1: int(gauss(mu, sigma)); B is ignored
        domain: value stored in inNTT (no transform is applied)
        """
        if type == 0:
            half = B // 2
            self.F = [randint(-half, half) % self.q for _ in range(self.n)]
        else:
            self.F = [int(gauss(mu, sigma)) % self.q for _ in range(self.n)]
        self.inNTT = domain

    def __str__(self):
        # Show at most the first 8 terms.
        pstr = str(self.F[0])
        for i in range(1, min(self.n, 8)):
            pstr = pstr + " + " + str(self.F[i]) + "*x^" + str(i)
        if self.n > 8:
            pstr = pstr + " + ..."
        return pstr

    def _check_compatible(self, b, op_name):
        """Raise if self and b are in different domains or use different moduli."""
        if self.inNTT != b.inNTT:
            raise Exception("Polynomial {}: Inputs must be in the same domain.".format(op_name))
        if self.q != b.q:
            raise Exception("Polynomial {}: Inputs must have the same modulus".format(op_name))

    def __add__(self, b):
        self._check_compatible(b, "Addition")
        c = Poly(self.n, self.q, self.np)
        c.F = [(x + y) % self.q for x, y in zip(self.F, b.F)]
        c.inNTT = self.inNTT
        return c

    def __sub__(self, b):
        self._check_compatible(b, "Subtraction")
        c = Poly(self.n, self.q, self.np)
        c.F = [(x - y) % self.q for x, y in zip(self.F, b.F)]
        c.inNTT = self.inNTT
        return c

    def __mul__(self, b):
        """Multiply two polynomials given in the same domain.

        NTT domain: coefficient-wise product.
        POL domain: negacyclic product via psi-twisting and the NTT:
            x1 = self*psi, x2 = b*psi (coefficient i times psi**i)
            x3 = INTT(NTT(x1) * NTT(x2)), result = x3*psi_inv
        """
        self._check_compatible(b, "Multiplication")
        c = Poly(self.n, self.q, self.np)
        if self.inNTT and b.inNTT:
            c.F = [(x * y) % self.q for x, y in zip(self.F, b.F)]
            c.inNTT = True
            return c

        w_table, wv_table, psi_table, psiv_table = self.np
        q = self.q
        s_p = [(x * psi_table[pwr]) % q for pwr, x in enumerate(self.F)]
        b_p = [(x * psi_table[pwr]) % q for pwr, x in enumerate(b.F)]
        s_n = NTT(s_p, w_table, q)
        b_n = NTT(b_p, w_table, q)
        sb_n = [(x * y) % q for x, y in zip(s_n, b_n)]
        sb_p = INTT(sb_n, wv_table, q)
        c.F = [(x * psiv_table[pwr]) % q for pwr, x in enumerate(sb_p)]
        c.inNTT = False
        return c

    def __mod__(self, base):
        """Reduce every coefficient modulo base (the stored modulus q is kept)."""
        b = Poly(self.n, self.q, self.np)
        b.F = [x % base for x in self.F]
        b.inNTT = self.inNTT
        return b

    def __round__(self):
        """Round every coefficient to the nearest integer."""
        b = Poly(self.n, self.q, self.np)
        b.F = [round(x) for x in self.F]
        b.inNTT = self.inNTT
        return b

    def __eq__(self, b):
        """Equal when n, q and the (zipped) coefficients match; inNTT is not compared."""
        if self.n != b.n or self.q != b.q:
            return False
        return all(i == j for i, j in zip(self.F, b.F))

    def __neg__(self):
        b = Poly(self.n, self.q, self.np)
        b.F = [(-x) % self.q for x in self.F]
        b.inNTT = self.inNTT
        return b

    def toNTT(self):
        """Return a copy in the NTT domain (transforming only if needed)."""
        b = Poly(self.n, self.q, self.np)
        b.F = list(self.F) if self.inNTT else NTT(self.F, self.np[0], self.q)
        b.inNTT = True
        return b

    def toPOL(self):
        """Return a copy in coefficient form (transforming only if needed)."""
        b = Poly(self.n, self.q, self.np)
        b.F = INTT(self.F, self.np[1], self.q) if self.inNTT else list(self.F)
        b.inNTT = False
        return b


# ----------------------------------------------------------------------------
# /BFV_demo.py
# ----------------------------------------------------------------------------
from BFV import *
from helper import *

from random import randint
from math import log, ceil

# This implementation follows the description at https://eprint.iacr.org/2012/144.pdf
# Brakerski/Fan-Vercauteren (BFV) somewhat homomorphic encryption scheme
#
# Polynomial arithmetic on ciphertext domain is performed in Z[x]_q/x^n+1
# Polynomial arithmetic on plaintext domain is performed in Z[x]_t/x^n+1
# * n: ring size
# * q: ciphertext coefficient modulus
# * t: plaintext coefficient modulus (if t is equal to 2, no negative values are accepted)
# * psi,psiv,w,wv: polynomial arithmetic parameters
#
# Note that n,q,t parameters together determine the multiplicative depth.

# Parameter selection.
PD = 0  # 0: use a pre-defined parameter set -- 1: generate parameters with ParamGen

if PD == 0:
    # Select one of the parameter sets below
    t, n, q, psi = 16, 1024, 132120577, 73993  # log(q) = 27
    # t, n, q, psi = 256, 2048, 137438691329, 22157790  # log(q) = 37
    # t, n, q, psi = 1024, 4096, 288230376135196673, 60193018759093  # log(q) = 58

    # other necessary parameters
    psiv = modinv(psi, q)
    w = pow(psi, 2, q)
    wv = modinv(w, q)
else:
    # Enter proper parameters below
    t, n, logq = 16, 1024, 27
    # t, n, logq = 256, 2048, 37
    # t, n, logq = 1024, 4096, 58

    # other necessary parameters (derived from n and log(q))
    q, psi, psiv, w, wv = ParamGen(n, logq)

# Determine mu, sigma (for discrete gaussian distribution)
mu = 0
sigma = 0.5 * 3.2

# Determine T, p (for relinearization and galois keys) based on noise analysis
T = 256
p = q**3 + 1

# Generate polynomial arithmetic tables: entry i holds the i-th power mod q
w_table = [1] * n
wv_table = [1] * n
psi_table = [1] * n
psiv_table = [1] * n
for i in range(1, n):
    w_table[i] = (w_table[i-1] * w) % q
    wv_table[i] = (wv_table[i-1] * wv) % q
    psi_table[i] = (psi_table[i-1] * psi) % q
    psiv_table[i] = (psiv_table[i-1] * psiv) % q

qnp = [w_table, wv_table, psi_table, psiv_table]

print("--- Starting BFV Demo")

# Generate BFV evaluator
Evaluator = BFV(n, q, t, mu, sigma, qnp)

# Generate Keys
Evaluator.SecretKeyGen()
Evaluator.PublicKeyGen()
Evaluator.EvalKeyGenV1(T)
Evaluator.EvalKeyGenV2(p) 72 | 73 | # print system parameters 74 | print(Evaluator) 75 | 76 | # Generate random message 77 | # n1, n2 = 15, -5 78 | n1, n2 = randint(-(2**15),2**15-1), randint(-(2**15),2**15-1) 79 | 80 | print("--- Random integers n1 and n2 are generated.") 81 | print("* n1: {}".format(n1)) 82 | print("* n2: {}".format(n2)) 83 | print("* n1+n2: {}".format(n1+n2)) 84 | print("* n1-n2: {}".format(n1-n2)) 85 | print("* n1*n2: {}".format(n1*n2)) 86 | print("") 87 | 88 | # Encode random messages into plaintext polynomials 89 | print("--- n1 and n2 are encoded as polynomials m1(x) and m2(x).") 90 | m1 = Evaluator.IntEncode(n1) 91 | m2 = Evaluator.IntEncode(n2) 92 | 93 | print("* m1(x): {}".format(m1)) 94 | print("* m2(x): {}".format(m2)) 95 | print("") 96 | 97 | # Encrypt message 98 | ct1 = Evaluator.Encryption(m1) 99 | ct2 = Evaluator.Encryption(m2) 100 | 101 | print("--- m1 and m2 are encrypted as ct1 and ct2.") 102 | print("* ct1[0]: {}".format(ct1[0])) 103 | print("* ct1[1]: {}".format(ct1[1])) 104 | print("* ct2[0]: {}".format(ct2[0])) 105 | print("* ct2[1]: {}".format(ct2[1])) 106 | print("") 107 | 108 | # Homomorphic Addition 109 | ct = Evaluator.HomomorphicAddition(ct1,ct2) 110 | mt = Evaluator.Decryption(ct) 111 | 112 | nr = Evaluator.IntDecode(mt) 113 | ne = (n1+n2) 114 | 115 | print("--- Performing ct_add = Enc(m1) + Enc(m2)") 116 | print("* ct_add[0] :{}".format(ct[0])) 117 | print("* ct_add[1] :{}".format(ct[1])) 118 | print("--- Performing ct_dec = Dec(ct_add)") 119 | print("* ct_dec :{}".format(mt)) 120 | print("--- Performing ct_dcd = Decode(ct_dec)") 121 | print("* ct_dcd :{}".format(nr)) 122 | 123 | if nr == ne: 124 | print("* Homomorphic addition works.") 125 | else: 126 | print("* Homomorphic addition does not work.") 127 | print("") 128 | 129 | # Homomorphic Subtraction 130 | ct = Evaluator.HomomorphicSubtraction(ct1,ct2) 131 | mt = Evaluator.Decryption(ct) 132 | 133 | nr = Evaluator.IntDecode(mt) 134 | ne = (n1-n2) 135 | 136 | 
print("--- Performing ct_sub = Enc(m1) - Enc(m2)") 137 | print("* ct_sub[0] :{}".format(ct[0])) 138 | print("* ct_sub[1] :{}".format(ct[1])) 139 | print("--- Performing ct_dec = Dec(ct_sub)") 140 | print("* ct_dec :{}".format(mt)) 141 | print("--- Performing ct_dcd = Decode(ct_dec)") 142 | print("* ct_dcd :{}".format(nr)) 143 | 144 | if nr == ne: 145 | print("* Homomorphic subtraction works.") 146 | else: 147 | print("* Homomorphic subtraction does not work.") 148 | print("") 149 | 150 | # Multiply two message (no relinearization) 151 | ct = Evaluator.HomomorphicMultiplication(ct1,ct2) 152 | mt = Evaluator.DecryptionV2(ct) 153 | 154 | nr = Evaluator.IntDecode(mt) 155 | ne = (n1*n2) 156 | 157 | print("--- Performing ct_mul = Enc(m1) * Enc(m2) (no relinearization)") 158 | print("* ct_mul[0] :{}".format(ct[0])) 159 | print("* ct_mul[1] :{}".format(ct[1])) 160 | print("--- Performing ct_dec = Dec(ct_sub)") 161 | print("* ct_dec :{}".format(mt)) 162 | print("--- Performing ct_dcd = Decode(ct_dec)") 163 | print("* ct_dcd :{}".format(nr)) 164 | 165 | if nr == ne: 166 | print("* Homomorphic multiplication works.") 167 | else: 168 | print("* Homomorphic multiplication does not work.") 169 | print("") 170 | 171 | # Multiply two message (relinearization v1) 172 | ct = Evaluator.HomomorphicMultiplication(ct1,ct2) 173 | ct = Evaluator.RelinearizationV1(ct) 174 | mt = Evaluator.Decryption(ct) 175 | 176 | nr = Evaluator.IntDecode(mt) 177 | ne = (n1*n2) 178 | 179 | print("--- Performing ct_mul = Enc(m1) * Enc(m2) (with relinearization v1)") 180 | print("* ct_mul[0] :{}".format(ct[0])) 181 | print("* ct_mul[1] :{}".format(ct[1])) 182 | print("--- Performing ct_dec = Dec(ct_sub)") 183 | print("* ct_dec :{}".format(mt)) 184 | print("--- Performing ct_dcd = Decode(ct_dec)") 185 | print("* ct_dcd :{}".format(nr)) 186 | 187 | if nr == ne: 188 | print("* Homomorphic multiplication works.") 189 | else: 190 | print("* Homomorphic multiplication does not work.") 191 | print("") 192 | 193 | 
""" 194 | # Multiply two message (relinearization v2) 195 | ct = Evaluator.HomomorphicMultiplication(ct1,ct2) 196 | ct = Evaluator.RelinearizationV2(ct) 197 | mt = Evaluator.Decryption(ct) 198 | 199 | nr = Evaluator.IntDecode(mt) 200 | ne = (n1*n2) 201 | 202 | print("--- Performing ct_mul = Enc(m1) * Enc(m2) (with relinearization v2)") 203 | print("* ct_mul[0] :{}".format(ct[0])) 204 | print("* ct_mul[1] :{}".format(ct[1])) 205 | print("--- Performing ct_dec = Dec(ct_sub)") 206 | print("* ct_dec :{}".format(mt)) 207 | print("--- Performing ct_dcd = Decode(ct_dec)") 208 | print("* ct_dcd :{}".format(nr)) 209 | 210 | if nr == ne: 211 | print("* Homomorphic multiplication works.") 212 | else: 213 | print("* Homomorphic multiplication does not work.") 214 | """ 215 | # 216 | -------------------------------------------------------------------------------- /BFV.py: -------------------------------------------------------------------------------- 1 | # BFV 2 | 3 | from poly import * 4 | 5 | class BFV: 6 | # Definitions 7 | # Z_q[x]/f(x) = x^n + 1 where n=power-of-two 8 | 9 | # Operations 10 | # -- SecretKeyGen 11 | # -- PublicKeyGen 12 | # -- Encryption 13 | # -- Decryption 14 | # -- EvaluationKeyGenV1 15 | # -- EvaluationKeyGenV2 (need to be fixed) 16 | # -- HomAdd 17 | # -- HomMult 18 | # -- RelinV1 19 | # -- RelinV2 (need to be fixed) 20 | 21 | # Parameters 22 | # (From outside) 23 | # -- n (ring size) 24 | # -- q (ciphertext modulus) 25 | # -- t (plaintext modulus) 26 | # -- mu (distribution mean) 27 | # -- sigma (distribution std. dev.) 
28 | # -- qnp (NTT parameters: [w,w_inv,psi,psi_inv]) 29 | # (Generated with parameters) 30 | # -- sk 31 | # -- pk 32 | # -- rlk1, rlk2 33 | 34 | def __init__(self, n, q, t, mu, sigma, qnp): 35 | self.n = n 36 | self.q = q 37 | self.t = t 38 | self.T = 0 39 | self.l = 0 40 | self.p = 0 41 | self.mu = mu 42 | self.sigma = sigma 43 | self.qnp= qnp # array NTT parameters: [w,w_inv,psi,psi_inv] 44 | # 45 | self.sk = [] 46 | self.pk = [] 47 | self.rlk1 = [] 48 | self.rlk2 = [] 49 | # 50 | def __str__(self): 51 | str = "\n--- Parameters:\n" 52 | str = str + "n : {}\n".format(self.n) 53 | str = str + "q : {}\n".format(self.q) 54 | str = str + "t : {}\n".format(self.t) 55 | str = str + "T : {}\n".format(self.T) 56 | str = str + "l : {}\n".format(self.l) 57 | str = str + "p : {}\n".format(self.p) 58 | str = str + "mu : {}\n".format(self.mu) 59 | str = str + "sigma: {}\n".format(self.sigma) 60 | return str 61 | # 62 | def SecretKeyGen(self): 63 | """ 64 | sk <- R_2 65 | """ 66 | s = Poly(self.n,self.q,self.qnp) 67 | s.randomize(2) 68 | self.sk = s 69 | # 70 | def PublicKeyGen(self): 71 | """ 72 | a <- R_q 73 | e <- X 74 | pk[0] <- (-(a*sk)+e) mod q 75 | pk[1] <- a 76 | """ 77 | a, e = Poly(self.n,self.q,self.qnp), Poly(self.n,self.q,self.qnp) 78 | a.randomize(self.q) 79 | e.randomize(0, domain=False, type=1, mu=self.mu, sigma=self.sigma) 80 | pk0 = -(a*self.sk + e) 81 | pk1 = a 82 | self.pk = [pk0,pk1] 83 | # 84 | def EvalKeyGenV1(self, T): 85 | self.T = T 86 | self.l = int(math.floor(math.log(self.q,self.T))) 87 | 88 | rlk1 = [] 89 | 90 | sk2 = (self.sk * self.sk) 91 | 92 | for i in range(self.l+1): 93 | ai , ei = Poly(self.n,self.q,self.qnp), Poly(self.n,self.q,self.qnp) 94 | ai.randomize(self.q) 95 | ei.randomize(0, domain=False, type=1, mu=self.mu, sigma=self.sigma) 96 | 97 | Ts2 = Poly(self.n,self.q,self.qnp) 98 | Ts2.F = [((self.T**i)*j) % self.q for j in sk2.F] 99 | 100 | rlki0 = Ts2 - (ai*self.sk + ei) 101 | rlki1 = ai 102 | 103 | rlk1.append([rlki0,rlki1]) 104 | 105 
| self.rlk1 = rlk1 106 | # 107 | def EvalKeyGenV2(self, p): 108 | """ 109 | a <- R_p*q 110 | e <- X' 111 | rlk[0] = [-(a*sk+e)+p*s^2]_p*q 112 | rlk[1] = a 113 | """ 114 | self.p = p 115 | 116 | rlk2 = [] 117 | 118 | a, e = Poly(self.n,self.p*self.q), Poly(self.n,self.p*self.q) 119 | a.randomize(self.p*self.q) 120 | e.randomize(0, domain=False, type=1, mu=self.mu, sigma=self.sigma) 121 | 122 | c0 = RefPolMulv2(a.F,self.sk.F) 123 | c0 = [c0_+e_ for c0_,e_ in zip(c0,e.F)] 124 | c1 = RefPolMulv2(self.sk.F,self.sk.F) 125 | c1 = [self.p*c1_ for c1_ in c1] 126 | c2 = [(c1_-c0_)%(self.p*self.q) for c0_,c1_ in zip(c0,c1)] 127 | 128 | c = Poly(self.n,self.p*self.q) 129 | c.F = c2 130 | 131 | rlk2.append(c) 132 | rlk2.append(a) 133 | 134 | self.rlk2 = rlk2 135 | # 136 | def Encryption(self, m): 137 | """ 138 | delta = floor(q/t) 139 | 140 | u <- random polynomial from R_2 141 | e1 <- random polynomial from R_B 142 | e2 <- random polynomial from R_B 143 | 144 | c0 <- pk0*u + e1 + m*delta 145 | c1 <- pk1*u + e2 146 | """ 147 | delta = int(math.floor(self.q/self.t)) 148 | 149 | u, e1, e2 = Poly(self.n,self.q,self.qnp), Poly(self.n,self.q,self.qnp), Poly(self.n,self.q,self.qnp) 150 | 151 | u.randomize(2) 152 | e1.randomize(0, domain=False, type=1, mu=self.mu, sigma=self.sigma) 153 | e2.randomize(0, domain=False, type=1, mu=self.mu, sigma=self.sigma) 154 | 155 | md = Poly(self.n,self.q,self.qnp) 156 | md.F = [(delta*x) % self.q for x in m.F] 157 | 158 | c0 = self.pk[0]*u + e1 159 | c0 = c0 + md 160 | c1 = self.pk[1]*u + e2 161 | 162 | return [c0,c1] 163 | # 164 | def Decryption(self, ct): 165 | """ 166 | ct <- c1*s + c0 167 | ct <- floot(ct*(t/q)) 168 | m <- [ct]_t 169 | """ 170 | m = ct[1]*self.sk + ct[0] 171 | m.F = [((self.t*x)/self.q) for x in m.F] 172 | m = round(m) 173 | m = m % self.t 174 | mr = Poly(self.n,self.t,self.qnp) 175 | mr.F = m.F 176 | mr.inNTT = m.inNTT 177 | return mr 178 | # 179 | def DecryptionV2(self, ct): 180 | """ 181 | ct <- c2*s^2 + c1*s + c0 182 | ct <- 
floot(ct*(t/q)) 183 | m <- [ct]_t 184 | """ 185 | sk2 = (self.sk * self.sk) 186 | m = ct[0] 187 | m = (m + (ct[1]*self.sk)) 188 | m = (m + (ct[2]*sk2)) 189 | m.F = [((self.t * x) / self.q) for x in m.F] 190 | m = round(m) 191 | m = m % self.t 192 | mr = Poly(self.n,self.t,self.qnp) 193 | mr.F = m.F 194 | mr.inNTT = m.inNTT 195 | return mr 196 | # 197 | def RelinearizationV1(self,ct): 198 | c0 = ct[0] 199 | c1 = ct[1] 200 | c2 = ct[2] 201 | 202 | # divide c2 into base T 203 | c2i = [] 204 | 205 | c2q = Poly(self.n,self.q,self.qnp) 206 | c2q.F = [x for x in c2.F] 207 | 208 | for i in range(self.l+1): 209 | c2r = Poly(self.n,self.q,self.qnp) 210 | 211 | for j in range(self.n): 212 | qt = int(c2q.F[j]/self.T) 213 | rt = c2q.F[j] - qt*self.T 214 | 215 | c2q.F[j] = qt 216 | c2r.F[j] = rt 217 | 218 | c2i.append(c2r) 219 | 220 | c0r = Poly(self.n,self.q,self.qnp) 221 | c1r = Poly(self.n,self.q,self.qnp) 222 | c0r.F = [x for x in c0.F] 223 | c1r.F = [x for x in c1.F] 224 | 225 | for i in range(self.l+1): 226 | c0r = c0r + (self.rlk1[i][0] * c2i[i]) 227 | c1r = c1r + (self.rlk1[i][1] * c2i[i]) 228 | 229 | return [c0r,c1r] 230 | # 231 | def RelinearizationV2(self,ct): 232 | c0 = ct[0] 233 | c1 = ct[1] 234 | c2 = ct[2] 235 | 236 | c2_0 = RefPolMulv2(c2.F,self.rlk2[0].F) 237 | c2_0 = [round(c/self.p) for c in c2_0] 238 | c2_0 = [(c % self.q) for c in c2_0] 239 | 240 | c2_1 = RefPolMulv2(c2.F,self.rlk2[1].F) 241 | c2_1 = [round(c/self.p) for c in c2_1] 242 | c2_1 = [(c % self.q) for c in c2_1] 243 | 244 | c0e = Poly(self.n,self.q,self.qnp); c0e.F = c2_0 245 | c1e = Poly(self.n,self.q,self.qnp); c1e.F = c2_1 246 | 247 | c0r = c0e + c0 248 | c1r = c1e + c1 249 | 250 | return [c0r,c1r] 251 | # 252 | def IntEncode(self,m): # integer encode 253 | mr = Poly(self.n,self.t) 254 | if m >0: 255 | mt = m 256 | for i in range(self.n): 257 | mr.F[i] = (mt % 2) 258 | mt = (mt // 2) 259 | elif m<0: 260 | mt = -m 261 | for i in range(self.n): 262 | mr.F[i] = (self.t-(mt % 2)) % self.t 263 | mt 
= (mt // 2) 264 | else: 265 | mr = mr 266 | return mr 267 | # 268 | def IntDecode(self,m): # integer decode 269 | mr = 0 270 | thr_ = 2 if(self.t == 2) else ((self.t+1)>>1) 271 | for i,c in enumerate(m.F): 272 | if c >= thr_: 273 | c_ = -(self.t-c) 274 | else: 275 | c_ = c 276 | mr = (mr + (c_ * pow(2,i))) 277 | return mr 278 | # 279 | def HomomorphicAddition(self, ct0, ct1): 280 | ct0_b = ct0[0] + ct1[0] 281 | ct1_b = ct0[1] + ct1[1] 282 | return [ct0_b,ct1_b] 283 | # 284 | def HomomorphicSubtraction(self, ct0, ct1): 285 | ct0_b = ct0[0] - ct1[0] 286 | ct1_b = ct0[1] - ct1[1] 287 | return [ct0_b,ct1_b] 288 | # 289 | def HomomorphicMultiplication(self, ct0, ct1): 290 | ct00 = ct0[0] 291 | ct01 = ct0[1] 292 | ct10 = ct1[0] 293 | ct11 = ct1[1] 294 | 295 | r0 = RefPolMulv2(ct00.F,ct10.F) 296 | r1 = RefPolMulv2(ct00.F,ct11.F) 297 | r2 = RefPolMulv2(ct01.F,ct10.F) 298 | r3 = RefPolMulv2(ct01.F,ct11.F) 299 | 300 | c0 = [x for x in r0] 301 | c1 = [x+y for x,y in zip(r1,r2)] 302 | c2 = [x for x in r3] 303 | 304 | c0 = [((self.t * x) / self.q) for x in c0] 305 | c1 = [((self.t * x) / self.q) for x in c1] 306 | c2 = [((self.t * x) / self.q) for x in c2] 307 | 308 | c0 = [(round(x) % self.q) for x in c0] 309 | c1 = [(round(x) % self.q) for x in c1] 310 | c2 = [(round(x) % self.q) for x in c2] 311 | 312 | # Move to regular modulus 313 | r0 = Poly(self.n,self.q,self.qnp) 314 | r1 = Poly(self.n,self.q,self.qnp) 315 | r2 = Poly(self.n,self.q,self.qnp) 316 | 317 | r0.F = c0 318 | r1.F = c1 319 | r2.F = c2 320 | 321 | return [r0,r1,r2] 322 | --------------------------------------------------------------------------------