├── .gitignore ├── CRTPoly.py ├── Ctxt.py ├── FHE.py ├── NTT.py ├── README.md ├── Test_CRTPoly.py ├── Test_Ctxt.py ├── Test_FHE.py ├── Test_NTT.py └── numTh.py /.gitignore: -------------------------------------------------------------------------------- 1 | misc/ 2 | __pycache__/ 3 | profile/ 4 | -------------------------------------------------------------------------------- /CRTPoly.py: -------------------------------------------------------------------------------- 1 | from NTT import NTT 2 | from sympy.ntheory.modular import crt 3 | import numpy as np 4 | 5 | 6 | class CRTPoly: 7 | """ 8 | Data structure: 9 | crt_poly, prime_set 10 | """ 11 | 12 | def __init__(self, poly=None, primes=None, fft=True, crt=False, N=None): 13 | self.do_fft = fft 14 | if crt: 15 | self.N = N 16 | self.initial_w_crt(poly, primes) 17 | else: 18 | self.N = len(poly) 19 | self.initial_wo_crt(poly, primes) 20 | 21 | def initial_wo_crt(self, poly, primes): 22 | self.prime_set = list(primes) 23 | self.crt_poly = self.crtPoly(poly) 24 | 25 | def initial_w_crt(self, poly, primes): 26 | self.prime_set = list(primes) 27 | self.crt_poly = poly 28 | 29 | def __name__(self): 30 | return "CRTPoly" 31 | 32 | def __str__(self): 33 | return str(self.crt_poly) 34 | 35 | def __add__(self, other): 36 | 37 | assert self.do_fft == other.do_fft 38 | assert self.N == other.N 39 | 40 | add_result = [] 41 | same_size = False 42 | 43 | # check if two poly prime sets are equal 44 | if len(self.crt_poly) > len(other.crt_poly): 45 | small_obj = other 46 | large_obj = self 47 | elif len(self.crt_poly) < len(other.crt_poly): 48 | small_obj = self 49 | large_obj = other 50 | else: 51 | small_obj = self 52 | same_size = True 53 | 54 | for i, poly in enumerate(small_obj.crt_poly): 55 | if self.do_fft: 56 | add_result.append(poly + other.crt_poly[i]) 57 | 58 | else: 59 | _result = np.asarray(poly) + np.asarray(other.crt_poly[i]) 60 | result = np.fmod(_result, self.prime_set[i]) 61 | add_result.append(result.tolist()) 
62 | 63 | if not same_size: 64 | prime_set = list(small_obj.prime_set) 65 | N = small_obj.N 66 | else: 67 | prime_set = list(self.prime_set) 68 | N = self.N 69 | 70 | return CRTPoly(add_result, prime_set, self.do_fft, True, N=N) 71 | 72 | def __sub__(self, other): 73 | 74 | assert self.do_fft == other.do_fft 75 | assert self.N == other.N 76 | assert len(self.prime_set) >= len(other.prime_set) 77 | 78 | sub_result = [] 79 | 80 | for i, poly in enumerate(other.crt_poly): 81 | if self.do_fft: 82 | sub_result.append(self.crt_poly[i] - poly) 83 | else: 84 | _result = np.asarray(self.crt_poly[i]) - np.asarray(poly) 85 | result = np.fmod(_result, self.prime_set[i]) 86 | sub_result.append(result.tolist()) 87 | 88 | prime_set = list(other.prime_set) 89 | return CRTPoly(sub_result, prime_set, self.do_fft, True, N=other.N) 90 | 91 | def __mul__(self, other): 92 | 93 | mul_result = [] 94 | 95 | if type(other).__name__ == 'int': 96 | for poly in self.crt_poly: 97 | if self.do_fft: 98 | result = poly * other 99 | else: 100 | result = (np.asarray(poly) * other).tolist() 101 | mul_result.append(result) 102 | prime_set = self.prime_set 103 | else: 104 | assert self.do_fft and other.do_fft 105 | 106 | for i, poly in enumerate(self.crt_poly): 107 | result = poly * other.crt_poly[i] 108 | mul_result.append(result) 109 | if len(self.prime_set) <= len(other.prime_set): 110 | prime_set = self.prime_set 111 | else: 112 | prime_set = self.prime_set 113 | 114 | return CRTPoly(mul_result, prime_set, self.do_fft, True, N=self.N) 115 | 116 | def crtPoly(self, poly, primes=None): 117 | """ 118 | Transform poly to CRT form, then transform each CRT poly to frequency domain, 119 | if do_fft is true. 
120 | """ 121 | if primes is None: 122 | primes = self.prime_set 123 | result = [] 124 | for prime in primes: 125 | crt_poly = np.remainder(poly, prime).tolist() 126 | if self.do_fft: 127 | fft_crt_poly = NTT(crt_poly, prime, self.N) 128 | result.append(fft_crt_poly) 129 | else: 130 | result.append(crt_poly) 131 | return result 132 | 133 | def toPoly(self): 134 | """ 135 | Convert back to a coefficient list: undo the NTT (if do_fft), then recombine each coefficient's residues with the CRT over prime_set. 136 | """ 137 | if self.do_fft: 138 | polys = [] 139 | for fft_poly in self.crt_poly: 140 | polys.append(fft_poly.intt()) 141 | 142 | else: 143 | polys = self.crt_poly 144 | 145 | poly = [] 146 | residue_array = np.asarray(polys).T # transpose: row j holds the residues of coefficient j 147 | for residues in residue_array: 148 | coeff, M = crt(self.prime_set, residues) # second value M is unused here 149 | poly.append(coeff) 150 | 151 | return poly 152 | -------------------------------------------------------------------------------- /Ctxt.py: -------------------------------------------------------------------------------- 1 | from FHE import FHE 2 | from CRTPoly import CRTPoly 3 | import numpy as np 4 | 5 | 6 | class Ctxt(): 7 | """ BGV ciphertext: holds c (a list of coefficient lists), its current level L, and an FHE instance for that level. 8 | """ 9 | 10 | def __init__(self, d, stdev, primes, P, L, cur_level=0, c=None): 11 | self.prime_set = list(primes) 12 | self.prime_set.sort(reverse=True) 13 | del self.prime_set[:cur_level] # drop the cur_level largest primes 14 | self.L = cur_level # NOTE: self.L is the current level, not the parameter L 15 | self.P = P 16 | self.f = FHE(d, stdev, primes, P, L, cur_level=cur_level) 17 | self.c = c 18 | 19 | def __name__(self): 20 | return 'Ctxt' 21 | 22 | def __add__(self, other): 23 | """ 24 | Homomorphic addition.
25 | Return a new Ctxt whose ciphertext is c + c' = (c0 + c'0, c1 + c'1), reduced mod the current modulus. 26 | """ 27 | 28 | assert other.__name__() == 'Ctxt' 29 | # NOTE: levels are aligned in place; scaleDown() mutates whichever operand has the smaller L 30 | while self.L != other.L: 31 | if self.L < other.L: 32 | self.scaleDown() 33 | else: 34 | other.scaleDown() 35 | # modulus = product of the primes left at this level 36 | modulus = 1 37 | for prime in self.prime_set: 38 | modulus *= prime 39 | 40 | result0 = (np.asarray(self.c[0]) + np.asarray(other.c[0])) % modulus 41 | result1 = (np.asarray(self.c[1]) + np.asarray(other.c[1])) % modulus 42 | 43 | result = [] 44 | result.append(result0.tolist()) 45 | result.append(result1.tolist()) 46 | 47 | return Ctxt(self.f.d, self.f.stdev, self.f.prime_set, self.P, self.f.L, self.L, result) 48 | 49 | def __mul__(self, other): 50 | """ 51 | Homomorphic multiplication. 52 | Return a new Ctxt whose ciphertext is c * c' = (c0 * c'0, c0 * c'1 + c'0 * c1, c1 * c'1); it has three components, which relinearize() reduces back to two. 53 | """ 54 | assert other.__name__() == 'Ctxt' 55 | # NOTE: levels are aligned in place; scaleDown() mutates whichever operand has the smaller L 56 | while self.L != other.L: 57 | if self.L < other.L: 58 | self.scaleDown() 59 | else: 60 | other.scaleDown() 61 | 62 | crt_c10 = CRTPoly(self.c[0], self.prime_set) 63 | crt_c11 = CRTPoly(self.c[1], self.prime_set) 64 | crt_c20 = CRTPoly(other.c[0], other.prime_set) 65 | crt_c21 = CRTPoly(other.c[1], other.prime_set) 66 | 67 | crt_result0 = crt_c10 * crt_c20 68 | crt_result1 = crt_c10 * crt_c21 + crt_c11 * crt_c20 69 | crt_result2 = crt_c11 * crt_c21 70 | 71 | result = [] 72 | result.append(crt_result0.toPoly()) 73 | result.append(crt_result1.toPoly()) 74 | result.append(crt_result2.toPoly()) 75 | 76 | return Ctxt(self.f.d, self.f.stdev, self.f.prime_set, self.P, self.f.L, self.L, result) 77 | 78 | def enc(self, m, pk): 79 | """ 80 | FHE encryption of plaintext m under public key pk; stores the ciphertext in self.c and returns it. 81 | """ 82 | self.c = self.f.homoEnc(m, pk) 83 | return self.c 84 | 85 | def dec(self, sk): 86 | """ 87 | FHE decryption with secret key sk; returns the plaintext coefficients reduced mod 2. 88 | """ 89 | m = self.f.homoDec(self.c, sk) 90 | return m 91 | 92 | def scaleDown(self): 93 | """ 94 | Modulus switching down to next level.
95 | """ 96 | self.c = self.f.modSwitch(self.c, self.L) 97 | self.L += 1 98 | del self.prime_set[0] 99 | 100 | def relinearize(self, switch_key): 101 | """ 102 | Do key switching and modulus switching. 103 | """ 104 | self.c = self.f.keySwitch(self.c, switch_key) 105 | self.scaleDown() 106 | -------------------------------------------------------------------------------- /FHE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from CRTPoly import CRTPoly 3 | from numTh import uniform_sample, gauss_sample, hamming_sample, small_sample 4 | 5 | 6 | class FHE: 7 | """ 8 | Implementation of BGV-FHE scheme. 9 | """ 10 | 11 | def __init__(self, d, stdev, primes, P, L, cur_level=0): 12 | """ 13 | Initialize parameters. 14 | L : limitation of homomorphic multiplications 15 | cur_level : homomorphic multiplication times 16 | d : polynomial degree 17 | stdev : standard deviation of gaussian distribution 18 | prime_set : total primes 19 | modulus : product of primes 20 | """ 21 | self.L = L # levels 22 | self.cur_level = cur_level # current level 23 | self.d = d # polynomial degree 24 | self.stdev = stdev # standard deviation 25 | self.prime_set = list(primes) # primes 26 | self.prime_set.sort(reverse=True) 27 | self.special_prime = P 28 | self.modulus = 1 # modulus 29 | for i in range(cur_level, L): 30 | self.modulus *= primes[i] 31 | 32 | def setCoeffs(self, poly, q=None): 33 | """ 34 | Let each coefficient in the polynomial in the range of [-q/2,q/2]. 35 | """ 36 | if q is None: 37 | q = self.modulus 38 | for i, coeff in enumerate(poly): 39 | if coeff > q // 2: 40 | poly[i] -= q 41 | 42 | def secretKeyGen(self, h): 43 | """ 44 | Generate secret key. 
45 | sk = (1, s'), where s' has exactly h nonzero coefficients, each -1 or +1 46 | """ 47 | secret_key = [] 48 | sk0 = [0] * self.d 49 | sk0[0] = 1 50 | sk1 = hamming_sample(self.d, h) 51 | secret_key.append(sk0) 52 | secret_key.append(sk1) 53 | return secret_key 54 | 55 | def publicKeyGen(self, sk, modulus=None): 56 | """ 57 | Generate public key. 58 | pk = (b, -A'), b = A's'+2e. If modulus is given (switchKeyGen does this), the special prime is added to the CRT primes. 59 | """ 60 | prime_set = list(self.prime_set) 61 | if modulus is None: 62 | modulus = self.modulus 63 | else: 64 | prime_set.append(self.special_prime) 65 | 66 | public_key = [] 67 | e = gauss_sample(self.d, self.stdev) 68 | A = uniform_sample(modulus, self.d) # uniform in [0, modulus) 69 | # set coefficients range [-q/2,q/2] 70 | self.setCoeffs(A, modulus) 71 | # CRT-FFT representation 72 | fft_sk1 = CRTPoly(sk[1], prime_set) 73 | fft_A = CRTPoly(A, prime_set) 74 | fft_2e = CRTPoly((2 * np.asarray(e)).tolist(), prime_set) 75 | # b = A's'+2e 76 | fft_b = fft_A * fft_sk1 + fft_2e 77 | # polynomial representation 78 | b = fft_b.toPoly() 79 | # set coefficients range [-q/2,q/2] 80 | self.setCoeffs(b, modulus) 81 | neg_A = (-(np.asarray(A))).tolist() 82 | public_key.append(b) 83 | public_key.append(neg_A) 84 | return public_key 85 | 86 | def switchKeyGen(self, sk): 87 | """ 88 | Generate L-1 switch keys. 89 | Each switch key is in R_Qi, where Qi = P * modulus_i and i is level. 90 | And each switch key is (b + P * s^2, -a), 91 | where b = a * s + 2e, and a is sampled uniformly in [-Qi/2,Qi/2].
92 | """ 93 | modulus = self.modulus * self.special_prime 94 | prime_set = list(self.prime_set) 95 | prime_set.append(self.special_prime) 96 | switch_keys = [] 97 | switch_key = [] 98 | for i in range(0, self.L - 1): 99 | switch_key = [] 100 | if i != 0: 101 | modulus //= self.prime_set[i - 1] 102 | pk = self.publicKeyGen(sk, modulus) # pk = (a * s + 2e, -a) 103 | # CRT-FFT representation 104 | crt_b = CRTPoly(pk[0], prime_set[i:]) 105 | crt_sk1 = CRTPoly(sk[1], prime_set[i:]) 106 | # key0 = a * s + 2e + P * s^2 107 | crt_switch_key0 = crt_b + crt_sk1 * crt_sk1 * self.special_prime 108 | # polynomial representation 109 | key0 = crt_switch_key0.toPoly() 110 | # set coefficients in [-Q/2, Q/2] 111 | self.setCoeffs(key0, modulus) 112 | # switch key = (b + P * s^2, -a) 113 | switch_key.append(key0) 114 | switch_key.append(pk[1]) 115 | switch_keys.append(switch_key) 116 | return switch_keys 117 | 118 | def homoEnc(self, m, pk): 119 | """ 120 | FHE encryption: 121 | c = (c0, c1) 122 | c0 = pk0 * r + 2e0 + m 123 | c1 = pk1 * r + 2e1 124 | """ 125 | r = small_sample(self.d) 126 | e0 = gauss_sample(self.d, self.stdev) 127 | e1 = gauss_sample(self.d, self.stdev) 128 | if len(m) < self.d: 129 | m += [0] * (self.d - len(m)) 130 | # CRT-FFT representation 131 | crt_m = CRTPoly(m, self.prime_set) 132 | crt_pk0 = CRTPoly(pk[0], self.prime_set) 133 | crt_pk1 = CRTPoly(pk[1], self.prime_set) 134 | crt_r = CRTPoly(r, self.prime_set) 135 | crt_2e0 = CRTPoly((2 * np.asarray(e0)).tolist(), self.prime_set) 136 | crt_2e1 = CRTPoly((2 * np.asarray(e1)).tolist(), self.prime_set) 137 | # c0 = pk0 * r + 2e0 + m, c1 = pk1 * r + 2e1 138 | crt_c0 = crt_m + crt_2e0 139 | crt_c1 = crt_2e1 140 | crt_c0 += crt_pk0 * crt_r 141 | crt_c1 += crt_pk1 * crt_r 142 | # polynomial representation 143 | c0 = crt_c0.toPoly() 144 | c1 = crt_c1.toPoly() 145 | # set coefficients in [-q/2,q/2] 146 | self.setCoeffs(c0) 147 | self.setCoeffs(c1) 148 | c = [] 149 | c.append(c0) 150 | c.append(c1) 151 | return c 152 | 
153 | def homoDec(self, c, sk): 154 | """ 155 | FHE decryption: 156 | m = (c0 + c1 * s') mod 2 157 | """ 158 | # CRT-FFT representation 159 | crt_c0 = CRTPoly(c[0], self.prime_set[self.cur_level:]) 160 | crt_c1 = CRTPoly(c[1], self.prime_set[self.cur_level:]) 161 | crt_sk1 = CRTPoly(sk[1], self.prime_set[self.cur_level:]) 162 | # m = [[c0 + c1 * s'] mod q] mod 2, with the mod-q value centered by setCoeffs before the mod 2 163 | crt_m = crt_c0 + crt_c1 * crt_sk1 164 | # polynomial representation 165 | m = crt_m.toPoly() 166 | # set coefficients range [-q/2,q/2] 167 | self.setCoeffs(m) 168 | return np.remainder(m, 2).tolist() 169 | 170 | def scale(self, c, from_q, to_q): 171 | """ 172 | Switch c from modulus from_q to modulus to_q, where p_t = from_q // to_q (modSwitch and keySwitch pass from_q = to_q * p_t). 173 | rem = c mod p_t, centered into [-p_t/2, p_t/2], so c - rem is a multiple of p_t. 174 | Odd entries of rem are moved by -/+ p_t; p_t is odd (primes from findPrimes), so they become even. 175 | Then c - rem = p_t * quotient with c - rem = c (mod 2), hence quotient = c (mod 2). 176 | quotient differs from c / p_t by at most 1. 177 | result = quotient mod to_q 178 | """ 179 | p_t = from_q // to_q 180 | _c = np.asarray(c) % p_t 181 | for i, _c_i in enumerate(_c): 182 | self.setCoeffs(_c_i, p_t) # in place: _c_i is a row view of _c 183 | for j, coeff in enumerate(_c_i): 184 | if coeff % 2 == 1: 185 | if coeff > 0: 186 | _c[i][j] -= p_t 187 | else: 188 | _c[i][j] += p_t 189 | c_dagger = np.asarray(c) - _c 190 | result = c_dagger // p_t 191 | return np.remainder(result, to_q,).tolist() 192 | 193 | def modSwitch(self, c, level): 194 | """ 195 | Scale modulus down by dropping prime_set[level]; also updates self.modulus and self.cur_level. c' is closest to c/p. 196 | And new c' must satisfy c' = c mod 2. 197 | """ 198 | assert level < self.L - 1, "cannot reduce noise" 199 | # new modulus 200 | to_modulus = self.modulus // self.prime_set[level] 201 | # scale(c, q, q') 202 | result = self.scale(c, self.modulus, to_modulus) 203 | # down to new modulus 204 | self.modulus = to_modulus 205 | # increase level 206 | self.cur_level += 1 207 | return np.remainder(result, self.modulus).tolist() 208 | 209 | def keySwitch(self, c, switch_key): 210 | """ 211 | Key switching: (c0, c1, c2) -> (c0', c1') using switch_key, computed mod Q = modulus * P and then scaled back to modulus.
212 | """ 213 | modulus = self.modulus * self.special_prime # Q = q * P 214 | prime_set = list(self.prime_set[self.cur_level:]) 215 | prime_set.append(self.special_prime) 216 | # CRT-FFT representation 217 | crt_c0 = CRTPoly(c[0], prime_set) 218 | crt_c1 = CRTPoly(c[1], prime_set) 219 | crt_c2 = CRTPoly(c[2], prime_set) 220 | crt_b = CRTPoly(switch_key[0], prime_set) 221 | crt_a = CRTPoly(switch_key[1], prime_set) 222 | # c'0 = P * c0 + b * c2, c'1 = P * c1 + a * c2 223 | crt_result0 = crt_c0 * self.special_prime + crt_b * crt_c2 224 | crt_result1 = crt_c1 * self.special_prime + crt_a * crt_c2 225 | # polynomial representation 226 | result0 = crt_result0.toPoly() 227 | result1 = crt_result1.toPoly() 228 | # set coefficients in [-Q/2, Q/2] 229 | self.setCoeffs(result0, modulus) 230 | self.setCoeffs(result1, modulus) 231 | 232 | result = [] 233 | result.append(result0) 234 | result.append(result1) 235 | # scale(c', Q, q) 236 | result = self.scale(result, modulus, self.modulus) 237 | return result 238 | -------------------------------------------------------------------------------- /NTT.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # # 3 | # File Name: NTT.py # 4 | # # 5 | # Author: Jyun-Neng Ji (jyunnengji@gmail.com) # 6 | # # 7 | # Creation Date: 2018/02/15 # 8 | # # 9 | # Last Modified: 2018/03/25 # 10 | # # 11 | # Description: Implement Cooley-Tukey FFT in finite field. # 12 | # # 13 | ############################################################ 14 | from numTh import * 15 | from sympy.core.numbers import mod_inverse 16 | from sympy.ntheory import sqrt_mod 17 | 18 | 19 | class NTT: 20 | def __init__(self, poly, M, N, ideal=True, ntt=False, w=None, phi=None): 21 | """ 22 | Initialize the parameters for NTT 23 | and transform poly into the frequency domain, unless ntt is True (then poly is already in the frequency domain).
24 | Parameters: mod(modulus), N(NTT points), w(primitive Nth root of unity), phi(square root of w), 25 | ideal(ideal ring), fft_poly(poly in freq domain). 26 | """ 27 | if ntt: # poly is in frequency domain 28 | self.initial_w_ntt(poly, M, N, ideal, w, phi) 29 | else: # poly is not in frequency domain 30 | self.initial_wo_ntt(poly, M, N, ideal, w, phi) 31 | 32 | def initial_wo_ntt(self, poly, M, N, ideal, w, phi): 33 | self.mod = M # modulus 34 | self.N = N # N points NTT 35 | if w is None and phi is None: # generate root of unity and square root of unity 36 | self.w = findPrimitiveNthRoot(M, N) # primitive Nth root of unity 37 | self.phi = sqrt_mod(self.w, M) # phi: phi^2 = w (mod M) 38 | else: 39 | self.w = w 40 | self.phi = phi 41 | self.ideal = ideal 42 | # ideal=True: arithmetic is mod x^N + 1 (negacyclic), so coefficient i is first weighted by phi^i 43 | if ideal: 44 | poly_bar = self.mulPhi(poly) 45 | else: 46 | poly_bar = poly 47 | self.fft_poly = self.ntt(poly_bar) 48 | 49 | def initial_w_ntt(self, poly, M, N, ideal, w, phi): 50 | self.mod = M # modulus 51 | self.N = N # N points NTT 52 | self.w = w # primitive Nth root of unity 53 | self.phi = phi # phi: phi^2 = w (mod M) 54 | self.ideal = ideal 55 | self.fft_poly = poly 56 | 57 | def __name__(self): 58 | return "NTT" 59 | 60 | def __str__(self): 61 | poly = " ".join(str(coeff) for coeff in self.fft_poly) 62 | return "NTT points [" + poly + "] modulus " + str(self.mod) 63 | 64 | def __mul__(self, other): 65 | """ 66 | Multiply in frequency domain (pointwise), or scale by an int constant.
67 | """ 68 | assert type(other).__name__ == 'NTT' or type(other).__name__ == 'int', 'type error' 69 | # int operand: scale every point; NTT operand: pointwise product mod self.mod 70 | 71 | if type(other).__name__ == 'int': 72 | mul_result = self.mulConstant(other) 73 | 74 | else: 75 | assert self.N == other.N, "points different" 76 | assert self.mod == other.mod, "modulus different" 77 | assert self.ideal == other.ideal 78 | 79 | mul_result = [] 80 | for i, point in enumerate(self.fft_poly): 81 | mul_result.append((point * other.fft_poly[i]) % self.mod) 82 | 83 | return NTT(mul_result, self.mod, self.N, self.ideal, True, self.w, self.phi) 84 | 85 | def __add__(self, other): 86 | """ 87 | Addition in frequency domain 88 | """ 89 | assert self.N == other.N, 'points different' 90 | assert self.mod == other.mod, 'modulus different' 91 | assert self.ideal == other.ideal 92 | 93 | add_result = [] 94 | for i, point in enumerate(self.fft_poly): 95 | add_result.append((point + other.fft_poly[i]) % self.mod) 96 | 97 | return NTT(add_result, self.mod, self.N, self.ideal, True, self.w, self.phi) 98 | 99 | def __sub__(self, other): 100 | """ 101 | Subtraction in frequency domain 102 | """ 103 | assert self.N == other.N, 'points different' 104 | assert self.mod == other.mod, 'modulus different' 105 | assert self.ideal == other.ideal 106 | 107 | sub_result = [] 108 | for i, point in enumerate(self.fft_poly): 109 | sub_result.append((point - other.fft_poly[i]) % self.mod) 110 | 111 | return NTT(sub_result, self.mod, self.N, self.ideal, True, self.w, self.phi) 112 | 113 | def bitReverse(self, num, len): # `len` is the bit width (it shadows the builtin here) 114 | """ 115 | Reverse bits of a number. 116 | 117 | Examples 118 | ======== 119 | 120 | >>> NTT.bitReverse(3, 3) # 3(011) 121 | 6 # 6(110) 122 | """ 123 | rev_num = 0 124 | 125 | for i in range(0, len): 126 | 127 | if (num >> i) & 1: 128 | rev_num |= 1 << (len - 1 - i) 129 | 130 | return rev_num 131 | 132 | def orderReverse(self, poly, N_bit): 133 | """ 134 | Permute the coefficients into bit-reversed order, the input order of the DIT FFT.
135 | 136 | Examples 137 | ======== 138 | 139 | >>> NTT.orderReverse([1, 2, 3, 4], 2) 140 | [1, 3, 2, 4] 141 | """ 142 | _poly = list(poly) 143 | for i, coeff in enumerate(_poly): 144 | 145 | rev_i = self.bitReverse(i, N_bit) 146 | 147 | if rev_i > i: 148 | coeff ^= _poly[rev_i] 149 | _poly[rev_i] ^= coeff 150 | coeff ^= _poly[rev_i] 151 | _poly[i] = coeff 152 | 153 | return _poly 154 | 155 | def ntt(self, poly, w=None): 156 | """ 157 | Compute FFT in finite feild. 158 | Use Cooley-Tukey DIT FFT algorithm. 159 | input: polynomial, primitive nth root of unity 160 | output FFT(poly) 161 | complexity: O(N/2 log N) 162 | """ 163 | 164 | if w is None: 165 | w = self.w 166 | 167 | N_bit = self.N.bit_length() - 1 168 | rev_poly = self.orderReverse(poly, N_bit) 169 | 170 | for i in range(0, N_bit): 171 | 172 | points1, points2 = [], [] 173 | 174 | for j in range(0, int(self.N / 2)): 175 | shift_bits = N_bit - 1 - i 176 | P = (j >> shift_bits) << shift_bits 177 | w_P = pow(w, P, self.mod) 178 | odd = rev_poly[2 * j + 1] * w_P 179 | even = rev_poly[2 * j] 180 | points1.append((even + odd) % self.mod) 181 | points2.append((even - odd) % self.mod) 182 | points = points1 + points2 183 | 184 | if i != N_bit: 185 | rev_poly = points 186 | 187 | return points 188 | 189 | def intt(self): 190 | """ 191 | Compute IFFT in finite feild. 192 | The algorithm is the same as NTT, but it need to change w to w^(-1) 193 | and multiply N^(-1) for each coefficient of polynomial. 
194 | input: FFT(poly), primitive nth root of unity 195 | output: polynomial 196 | complexity: O(N/2 log N) 197 | """ 198 | 199 | inv_w = mod_inverse(self.w, self.mod) 200 | inv_N = mod_inverse(self.N, self.mod) 201 | 202 | poly = self.ntt(self.fft_poly, inv_w) 203 | 204 | for i in range(0, self.N): 205 | poly[i] = poly[i] * inv_N % self.mod 206 | 207 | # ideal=True: undo the phi^i weighting applied before the forward transform 208 | if self.ideal: 209 | inv_phi = mod_inverse(self.phi, self.mod) 210 | poly = self.mulPhi(poly, inv_phi) 211 | 212 | return poly 213 | 214 | def mulPhi(self, poly, phi=None): 215 | """ 216 | The coefficients in polynomial multiply phi^i mod modulus 217 | a_i_bar = a_i * phi^i (mod m) 218 | 219 | It's used before the forward NTT (and with phi^(-1) after the inverse NTT) when ideal is True. 220 | """ 221 | 222 | if phi is None: 223 | phi = self.phi 224 | 225 | poly_bar = list(poly) 226 | 227 | for i, coeff in enumerate(poly): 228 | poly_bar[i] = (poly[i] * pow(phi, i, self.mod)) % self.mod 229 | 230 | return poly_bar 231 | 232 | def mulConstant(self, constant): # scale every frequency-domain point by constant mod self.mod; returns a plain list 233 | mul_result = [] 234 | for coeff in self.fft_poly: 235 | result = coeff * constant % self.mod 236 | mul_result.append(result) 237 | return mul_result 238 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PyFHE 2 | PyFHE is an implementation of the [Brakerski-Gentry-Vaikuntanathan][1] (BGV) scheme, along with some of the optimizations described by [Gentry-Halevi-Smart][2]. This project uses [HElib][3] as a reference but is much simpler: it implements only basic homomorphic operations such as multiplication and addition, and it does not use ciphertext packing. 3 | 4 | This project is written in Python 3.6.5 and uses Python packages [sympy][4] and [numpy][5]. 5 | 6 | ## Demo 7 | Run `Test_Ctxt.py`, which runs a 4-level BGV instance over degree-4096 polynomials (encryption, two multiplications, an addition and decryption).
8 | ``` 9 | > python Test_Ctxt.py 10 | ``` 11 | ## TODO 12 | Use multiprocessing. 13 | 14 | [1]: http://eprint.iacr.org/2011/277 "BGV12" 15 | [2]: http://eprint.iacr.org/2012/099 "GHS12" 16 | [3]: https://github.com/shaih/HElib "HELIB" 17 | [4]: http://www.sympy.org/en/index.html "SYMPY" 18 | [5]: http://www.numpy.org "NUMPY" 19 | -------------------------------------------------------------------------------- /Test_CRTPoly.py: -------------------------------------------------------------------------------- 1 | from CRTPoly import CRTPoly 2 | from numTh import * 3 | import numpy as np 4 | 5 | poly = [12, 33, 14, 53] 6 | poly2 = [8, 7, 6, 7] 7 | primes = [7, 11, 13] 8 | crt_poly = CRTPoly(poly, primes, fft=False) 9 | crt_poly2 = CRTPoly(poly2, primes, fft=False) 10 | print(crt_poly) 11 | mul_const_result = crt_poly2 * 300 12 | add_result = crt_poly2 + crt_poly 13 | sub_result = crt_poly - crt_poly2 14 | print(add_result.toPoly()) 15 | print(sub_result.toPoly()) 16 | print(mul_const_result.toPoly()) 17 | 18 | -------------------------------------------------------------------------------- /Test_Ctxt.py: -------------------------------------------------------------------------------- 1 | from Ctxt import Ctxt 2 | from FHE import FHE 3 | from numTh import * 4 | import numpy as np 5 | import cProfile as cp 6 | import time 7 | 8 | def setup(d, stdev, bits_per_level, P_bits, L): 9 | primes, total_bits = findPrimes(bits_per_level, d, L) 10 | special_primes, total_bits = findPrimes(P_bits, d, 1) 11 | P = special_primes[0] 12 | f = FHE(d, stdev, primes, P, L) 13 | return f, primes, P 14 | 15 | def keyGen(f, h): 16 | print('-----Secret Key Generation-----') 17 | sk = f.secretKeyGen(h) 18 | print('-----Public Key Generation-----') 19 | pk = f.publicKeyGen(sk) 20 | print('-----Switch Key Generation-----') 21 | swk = f.switchKeyGen(sk) 22 | return sk, pk, swk 23 | 24 | def set_timer_profile(): 25 | pr = cp.Profile() 26 | pr.enable() 27 | start = time.time() 28 | return pr, start 
29 | 30 | def end_timer_profile(pr, start, filename): # stop profiling, dump stats to filename, return elapsed seconds 31 | cost = time.time() - start 32 | pr.disable() 33 | pr.dump_stats(filename) 34 | return cost 35 | 36 | def hex_rep(m): # bit list -> hex string, m[0] is the most significant bit 37 | binary_str = ''.join(str(bit) for bit in m) 38 | return hex(int(binary_str, 2)) 39 | 40 | key_profile = './profile/keygeneration.prof' 41 | enc_profile = './profile/encrytpion.prof' 42 | dec_profile = './profile/decryption.prof' 43 | l0_mul_profile = './profile/l0_mul.prof' 44 | l1_mul_profile = './profile/l1_mul.prof' 45 | add_profile = './profile/add.prof' 46 | f, primes, P = setup(4096, 3.2, 22, 49, 4) 47 | 48 | # print('-----plaintext-----') 49 | msg1 = np.random.randint(0, 2, 4096).tolist() 50 | print(hex_rep(msg1)) 51 | pr, start = set_timer_profile() 52 | 53 | sk, pk, swk = keyGen(f, 64) 54 | cost = end_timer_profile(pr, start, key_profile) 55 | print('Done {0: .3f}s'.format(cost)) 56 | print('-----Encryption-----') 57 | c1 = Ctxt(4096, 3.2, primes, P, 4) 58 | 59 | pr, start = set_timer_profile() 60 | # message encryption 61 | c1.enc(msg1, pk) 62 | 63 | cost = end_timer_profile(pr, start, enc_profile) 64 | 65 | print('Done {0: .3f}s'.format(cost)) 66 | print('-----Homomorphic Operation-----') 67 | pr, start = set_timer_profile() 68 | 69 | c_mul = c1 * c1 70 | c_mul.relinearize(swk[0]) 71 | 72 | cost = end_timer_profile(pr, start, l0_mul_profile) 73 | print('First multiplication done {0: .3f}s'.format(cost)) 74 | pr, start = set_timer_profile() 75 | 76 | c_mul *= c_mul # Ctxt has no __imul__, so this rebinds c_mul to the Ctxt returned by __mul__ 77 | c_mul.relinearize(swk[1]) 78 | 79 | cost = end_timer_profile(pr, start, l1_mul_profile) 80 | print('Second multiplication done {0: .3f}s'.format(cost)) 81 | 82 | pr, start = set_timer_profile() 83 | 84 | c_add = c1 + c1 85 | 86 | cost = end_timer_profile(pr, start, add_profile) 87 | print('Addition done {0: .3f}s'.format(cost)) 88 | 89 | print('-----Decryption-----') 90 | pr, start = set_timer_profile() 91 | 92 | m1 = c1.dec(sk) 93 | 94 | cost = end_timer_profile(pr, start, dec_profile) 95 | print('Done {0:
.3f}s'.format(cost)) 96 | mul_result = c_mul.dec(sk) 97 | add_result = c_add.dec(sk) 98 | print(msg1 == m1) # True if the fresh ciphertext decrypts back to msg1 99 | #print('msg1 ', m1) 100 | #print('msg1^4 = ', mul_result) 101 | #print('msg1 + msg1 = ', add_result) 102 | -------------------------------------------------------------------------------- /Test_FHE.py: -------------------------------------------------------------------------------- 1 | from FHE import FHE 2 | from CRTPoly import CRTPoly 3 | from numTh import findPrimes 4 | import numpy as np 5 | 6 | 7 | def multiply(c1, c2, primes): # tensor product of two 2-component ciphertexts -> 3 components (not relinearized) 8 | result = [] 9 | fft_c10 = CRTPoly(c1[0], primes) 10 | fft_c11 = CRTPoly(c1[1], primes) 11 | fft_c20 = CRTPoly(c2[0], primes) 12 | fft_c21 = CRTPoly(c2[1], primes) 13 | fft_result0 = fft_c10 * fft_c20 14 | fft_result1 = fft_c10 * fft_c21 + fft_c11 * fft_c20 15 | fft_result2 = fft_c11 * fft_c21 16 | result.append(fft_result0.toPoly()) 17 | result.append(fft_result1.toPoly()) 18 | result.append(fft_result2.toPoly()) 19 | return result 20 | 21 | 22 | def polyMul(p1, p2, primes): # expected plaintext product mod 2, computed with the same CRTPoly/NTT arithmetic 23 | fft_p1 = CRTPoly(p1, primes) 24 | fft_p2 = CRTPoly(p2, primes) 25 | modulus = 1 26 | for prime in primes: 27 | modulus *= prime 28 | fft_result = fft_p1 * fft_p2 29 | result = fft_result.toPoly() 30 | for i, coeff in enumerate(result): 31 | if coeff > modulus // 2: 32 | result[i] -= modulus 33 | return np.remainder(result, 2).tolist() 34 | 35 | 36 | poly_degree = 4096 37 | stdev = 3.2 38 | L = 4 39 | #primes = [549755860993, 549755873281, 549755876353] 40 | primes, bits = findPrimes(22, 4096, 4) 41 | a, bits = findPrimes(10, 4096, 1) # special prime P 42 | P = a[0] 43 | # primes = [521, 569, 577] 44 | modulus = 1 45 | for prime in primes: 46 | modulus *= prime 47 | f = FHE(poly_degree, stdev, primes, P, L) 48 | sk = f.secretKeyGen(64) 49 | # sk = [[1, 0, 0, 0], [0, 1, -1, 0]] 50 | pk = f.publicKeyGen(sk) 51 | # pk = [[-24187115, -62847359, 2213875, 53855074], [-13973837, -16187706, -70042772, 76821192]] 52 | switch_keys = f.switchKeyGen(sk) 53 | m =
np.random.randint(0, 2, 4096).tolist() 54 | m1 = np.random.randint(0, 2, 4096).tolist() 55 | print('plaintext') 56 | # print(m) 57 | # print(m1) 58 | print('Encryption') 59 | c = f.homoEnc(m, pk) 60 | print('homo Multiply') 61 | mul_result = multiply(c, c, primes) 62 | mul_result = f.keySwitch(mul_result, switch_keys[0]) 63 | mul_result = f.modSwitch(mul_result, 0) # also moves f itself to level 1 (f.modulus and f.cur_level change) 64 | dec_mul_result = f.homoDec(mul_result, sk) 65 | print('Decrypt mul result') 66 | #print(polyMul(m, m, primes)) 67 | # print(dec_mul_result) 68 | print(dec_mul_result == polyMul(m, m, primes)) 69 | dec_mm = f.homoDec(c, sk) 70 | print('Decrypt m') 71 | # print(dec_mm) 72 | c1 = f.homoEnc(m1, pk) 73 | """ 74 | print('Modulus Switching') 75 | c = f.modSwitch(c, 0) 76 | c = f.modSwitch(c, 1) 77 | """ 78 | dec_mm = f.homoDec(c, sk) 79 | print('Decrypt m') 80 | print(dec_mm == m) 81 | new_c = ((np.asarray(c1) + np.asarray(c)) % f.modulus).tolist() # NOTE(review): f.modulus is the reduced level-1 modulus here (see line 63) 82 | print('Decryption') 83 | #dec_m = f.homoDec(new_c,sk) 84 | 85 | 86 | # if m == dec_m: 87 | # print('success') 88 | # else: 89 | # print('fail') 90 | #m = [] 91 | # for bit in dec_m: 92 | # m.append(int(bit)) 93 | # print(m) 94 | -------------------------------------------------------------------------------- /Test_NTT.py: -------------------------------------------------------------------------------- 1 | from NTT import NTT 2 | from numTh import * 3 | import numpy as np 4 | 5 | poly1 = [1,0,1,1] 6 | poly2 = [1,0,1,1] 7 | modulus = 17 8 | N = 4 9 | fft_poly1 = NTT(poly1, modulus, N) 10 | fft_poly2 = NTT(poly2, modulus, N) 11 | print(type(fft_poly1).__name__) 12 | mult_result = fft_poly1 * fft_poly2 13 | add_result = fft_poly1 + fft_poly2 14 | mult_const_result = fft_poly1 * 2 15 | sub_result = fft_poly1 - fft_poly2 16 | print('poly1=', poly1) 17 | print('poly2=', poly2) 18 | print('multiply:', mult_result.intt()) 19 | print('multiply 2:', mult_const_result.intt()) 20 | print('addition:', add_result.intt()) 21 | print('substraction:', sub_result.intt()) 22 |
23 | -------------------------------------------------------------------------------- /numTh.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # # 3 | # File Name: numTh.py # 4 | # # 5 | # Author: Jyun-Neng Ji (jyunnengji@gmail.com) # 6 | # # 7 | # Creation Date: 2018/02/15 # 8 | # # 9 | # Last Modified: 2018/04/12 # 10 | # # 11 | # Description: Number theory library. # 12 | # # 13 | ############################################################ 14 | import math 15 | import random 16 | import numpy as np 17 | from sympy import isprime, nextprime 18 | from sympy.ntheory.residue_ntheory import nthroot_mod 19 | 20 | 21 | def findPrimes(prime_bit, N, num): 22 | """ 23 | Return the first num primes above 2^(prime_bit - 1) that are congruent to 1 modulo 2N 24 | (so Z_p contains 2N-th roots of unity), together with the sum of their bit lengths. 25 | 26 | Examples 27 | ======== 28 | 29 | >>> from numTh import findPrimes 30 | >>> findPrimes(12, 4, 5) 31 | ([2081, 2089, 2113, 2129, 2137], 60) 32 | """ 33 | primes = [] 34 | total_bits = 0 35 | prime = pow(2, prime_bit - 1) 36 | while len(primes) != num: 37 | prime = nextprime(prime) 38 | if prime % (2 * N) == 1: 39 | primes.append(prime) 40 | total_bits += prime.bit_length() 41 | 42 | return primes, total_bits 43 | 44 | 45 | def findPrimitiveNthRoot(M, N): 46 | """ 47 | Generate the smallest primitive Nth root of unity (except 1).
48 | Find w s.t w^N = 1 mod M and there are not other numbers k (k < N) 49 | s.t w^k = 1 mod M 50 | 51 | """ 52 | roots = nthroot_mod(1, N, M, True)[1:] # find Nth root of unity 53 | for root in roots: # find primitive Nth root of unity 54 | is_primitive = True 55 | for k in range(1, N): 56 | if pow(root, k, M) == 1: 57 | is_primitive = False 58 | if is_primitive: 59 | return root 60 | return None 61 | 62 | 63 | def isPrimitiveNthRoot(M, N, beta): 64 | """ 65 | verify B^N = 1 (mod M) 66 | """ 67 | return pow(beta, N, M) == 1 # modular(M).modExponent(beta, N) == 1 68 | 69 | 70 | def uniform_sample(upper, num): 71 | """ 72 | Sample num values uniformly between [0,upper). 73 | """ 74 | sample = [] 75 | for i in range(num): 76 | value = random.randint(0, upper - 1) 77 | sample.append(value) 78 | return sample 79 | 80 | 81 | def gauss_sample(num, stdev): 82 | """ 83 | Sample num values from gaussian distribution 84 | mean = 0, standard deviation = stdev 85 | """ 86 | sample = np.random.normal(0, stdev, num) 87 | sample = sample.round().astype(int) 88 | return sample 89 | 90 | 91 | def hamming_sample(num, hwt): 92 | """ 93 | Sample a vector uniformly at random from -1, 0, +1, 94 | subject to the condition that it has exactly hwt nonzero entries. 95 | """ 96 | i = 0 97 | sample = [0] * num 98 | while i < hwt: 99 | degree = random.randint(0, num - 1) 100 | if sample[degree] == 0: 101 | coeff = random.randint(0, 1) 102 | if coeff == 0: 103 | coeff = -1 104 | sample[degree] = coeff 105 | i += 1 106 | return sample 107 | 108 | 109 | def small_sample(num): 110 | """ 111 | Sample vectors with entires -1, 0, +1. 112 | Each element is 0 with probabilty 0.5 and +-1 with probabilty 0.25. 
113 | """ 114 | sample = [0] * num 115 | for i in range(num): 116 | u = random.randint(0, 3) 117 | if u == 3: 118 | sample[i] = -1 119 | if u == 2: 120 | sample[i] = 1 121 | return sample 122 | 123 | # Not used 124 | 125 | 126 | class modular: 127 | def __init__(self, M): 128 | self.mod = M 129 | self.M_bit = M.bit_length() 130 | self.u = (1 << (2 * self.M_bit)) // M 131 | 132 | def modReduce(self, x): 133 | """ 134 | Barrett modular reduction algorithm. 135 | Compute x mod M. M is initialized by modular class. 136 | 137 | Examples 138 | ======== 139 | 140 | >>> from numTh import numTh 141 | >>> modular(11).modReduce(12) 142 | 1 143 | """ 144 | 145 | assert 0 <= x < pow(self.mod, 2), 'out of range.' 146 | q = (x * self.u) >> (2 * self.M_bit) 147 | r = x - q * self.mod 148 | while r >= self.mod: 149 | r -= self.mod 150 | return r 151 | 152 | def modReducem(self, x, M): 153 | """ 154 | Barrett modular reduction algorithm. 155 | Compute x mod M. M can be redefined. 156 | 157 | Examples 158 | ======== 159 | 160 | >>> from numTh import numTh 161 | >>> modular(5).modReducem(12, 11) 162 | 1 163 | """ 164 | tmp_mod, tmp_M_bit, tmp_u = self.mod, self.M_bit, self.u 165 | self.mod = M 166 | self.M_bit = M.bit_length() 167 | self.u = (1 << (2 * self.M_bit)) // M 168 | r = self.modReduce(x) 169 | # return initial modular, bit size of modular and precompute u 170 | self.mod, self.M_bit, self.u = tmp_mod, tmp_M_bit, tmp_u 171 | return r 172 | 173 | def modInv(self, x): 174 | """ 175 | Calculate modular inverse. 
176 | 177 | Examples 178 | ======== 179 | 180 | >>> from numTh import modular 181 | >>> modular(5).modInv(3) 182 | 2 183 | """ 184 | t, new_t, r, new_r = 0, 1, self.mod, x 185 | 186 | while new_r != 0: 187 | q = r // new_r 188 | r, new_r = new_r, (r % new_r) 189 | t, new_t = new_t, (t - q * new_t) 190 | assert r <= 1, 'x is not invertible' 191 | return t if t > 0 else t + self.mod 192 | 193 | # Slower than pow(base, power, modulus) 194 | def modExponent(self, base, power): 195 | """ 196 | Modular exponentiation algorithm. 197 | It's a fast method to compute a^b mod p 198 | 199 | Examples 200 | ======== 201 | 202 | >>> from numTh import modular 203 | >>> modular(2013265921).modExponent(1003203377, 2048) 204 | 1 205 | """ 206 | result = 1 207 | power = int(power) 208 | base = base % self.mod 209 | while power > 0: 210 | if power & 1: 211 | # self.modReduce(result * base) 212 | result = result * base % self.mod 213 | base = base * base % self.mod # self.modReduce(base * base) 214 | power = power >> 1 215 | return result 216 | --------------------------------------------------------------------------------