├── LICENSE.md
├── README.md
├── auxiliary_functions.py
├── bin_capacity_estimator.py
├── client_offline.py
├── client_online.py
├── cuckoo_hash.py
├── oprf.py
├── parameters.py
├── requirements.txt
├── server_offline.py
├── server_online.py
├── set_gen.py
└── simple_hash.py

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Bitdefender Machine Learning

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
We implemented in Python a **Private Set Intersection (PSI)** protocol, a functionality that allows two parties to *privately join their sets* in order to compute their *common elements*. In our setup, these parties are:

* a *server* holding a large database
* a *client* who would like to *privately* query the database.

## How it works
Our implementation is based on this [paper](https://eprint.iacr.org/2017/299.pdf) and its [follow-up](https://eprint.iacr.org/2018/787.pdf). The protocol uses **homomorphic encryption**, a cryptographic primitive that allows *computations on encrypted data* in such a way that only *the secret key holder has access to the decryption of the result of these computations*. For implementing PSI, we used the FV (Fan-Vercauteren) homomorphic encryption scheme from the [TenSEAL](https://github.com/OpenMined/TenSEAL) library. You can also check out a concurrent [SEAL](https://github.com/microsoft/SEAL)-based [C++ implementation](https://github.com/microsoft/APSI) of the protocol, recently published by Microsoft.

**Disclaimer:** Our implementation is not meant for production use. Use it at your own risk.

### Main idea of the protocol

Suppose the client wants to check whether his query *x* belongs to the server's database, and suppose this database has as entries the integers *y_1, ..., y_n*. The server can then associate to its database the polynomial *P(X) = (X - y_1) * ... * (X - y_n)*. If *x* belongs to the database, then *P* vanishes at *x*. This is the **main idea of the protocol**: the server should compute this evaluation of *P* at *x* in a secure way and send it to the client.

More precisely, the client sends his query encrypted with a homomorphic encryption scheme, *Enc(x)*. The server then evaluates *P* at the given encryption: due to the homomorphic properties, the result turns out to be *Enc(P(x))*. The client decrypts this result and checks whether it is equal to 0; in case of equality, the query belongs to the database. A toy example follows.
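The vanishing-polynomial idea can be seen on plain integers, before any encryption enters the picture (a minimal sketch; the database values below are made up for the example):

```python
# Toy illustration of the vanishing polynomial -- no encryption involved.
database = [3, 7, 11]  # hypothetical server entries

def P(x):
    # P(X) = (X - y_1) * ... * (X - y_n) vanishes exactly at the database entries
    result = 1
    for y in database:
        result *= x - y
    return result

print(P(7))  # 0       -> 7 belongs to the database
print(P(5))  # nonzero -> 5 does not
```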
The protocol is split into two parts: an offline (preprocessing) phase and an online phase. The online phase starts when the client engages in the OPRF protocol with the server, in order to encode its items under the server's secret key.

### The preprocessing phase
In this phase, both the server and the client preprocess their datasets, before the PSI protocol itself is performed.

* **Oblivious PRF**: Both the server and the client engage in a Diffie-Hellman-like protocol in order to apply an Oblivious PRF to their datasets (a sketch of this exchange follows the list).
    * The server embeds his database entries as points on an elliptic curve, multiplies them by a secret key, ```oprf_server_key```, and keeps ```sigma_max``` bits out of the first coordinate of these points.
    * The client also embeds his entries as points on an elliptic curve, multiplies them by a secret key, ```oprf_client_key```, and sends them to the server.
    * The server multiplies the client's points by ```oprf_server_key``` and sends them back to the client. Now the client's points are ```oprf_server_key``` * ```oprf_client_key``` * item.
    * The client multiplies the received points by the inverse of ```oprf_client_key``` and keeps ```sigma_max``` bits out of the first coordinate of these points.
    * After this step, both the server and the client have new datasets, each consisting of ```sigma_max```-bit integers.
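The algebraic fact this relies on — applying the two secret keys commutes, so the client can strip its own mask afterwards — can be checked directly with ```fastecdsa```, the curve library used in ```oprf.py``` (a minimal sketch with toy keys):

```python
from fastecdsa.curve import P192
from fastecdsa.point import Point

G = Point(P192.gx, P192.gy, curve=P192)  # generator, as in oprf.py
order = P192.q

server_key = 123456789  # toy secrets for the example
client_key = 987654321

item = 42            # an item is embedded (here: naively) as the point item * G
embedded = item * G

masked = client_key * embedded                    # client masks its point
keyed = server_key * masked                       # server applies its key
unmasked = pow(client_key, -1, order) * keyed     # client removes its mask

# The client ends up with server_key * (item * G) without ever learning server_key.
assert unmasked == server_key * embedded
```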
* **Hashing**: We used three Murmur hash functions for mapping the items of the client and of the server into ```number_of_bins``` bins (the sketch below shows the candidate bins of an item).
    * The client performs *Cuckoo hashing*; each of his bins holds 1 element.
    * The server performs *simple hashing*; each of his bins holds ```bin_capacity``` elements.

The PSI protocol can therefore be performed per bin. (A padding step might be required to fill the bins.)
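Each item has one candidate bin per hash function, computed as in ```cuckoo_hash.py``` and ```simple_hash.py``` (a self-contained sketch using the repository's seeds; the item is a made-up value):

```python
import mmh3

output_bits = 13
mask = 2 ** output_bits - 1
hash_seeds = [123456789, 10111213141516, 17181920212223]  # as in parameters.py

def location(seed, item):
    # Murmur_hash(item_left) xor item_right, where item = item_left || item_right
    item_left = item >> output_bits
    item_right = item & mask
    hash_item_left = mmh3.hash(str(item_left), seed, signed=False) >> (32 - output_bits)
    return hash_item_left ^ item_right

item = 2 ** 40 + 12345  # a toy item
print([location(seed, item) for seed in hash_seeds])  # its three candidate bins
```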
* **Partitioning**: This helps the server evaluate polynomials of lower degree on encrypted data.
    * The server partitions each bin into ```alpha``` minibins, each having ```bin_capacity```/```alpha``` elements.

Hence, performing the PSI protocol on a bin splits into performing ```alpha``` PSI protocols, one per minibin.

* **Computing coefficients of the polynomials**: This is applied to each minibin (a worked example follows).
    * The server computes the coefficients of the monic polynomial that vanishes at all the elements of the minibin.

Hence, each minibin is represented by ```bin_capacity```/```alpha``` + 1 coefficients.
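The coefficients come from ```coeffs_from_roots``` in ```auxiliary_functions.py```; a quick check with toy minibin contents:

```python
import numpy as np

def coeffs_from_roots(roots, modulus):
    # expand (X - r_1) * ... * (X - r_k) modulo modulus, as in auxiliary_functions.py
    coefficients = np.array(1, dtype=np.int64)
    for r in roots:
        coefficients = np.convolve(coefficients, [1, -r]) % modulus
    return coefficients

t = 536903681        # plain_modulus from parameters.py
roots = [5, 17, 23]  # toy minibin contents
coeffs = coeffs_from_roots(roots, t)  # highest degree first; the leading coefficient is 1
print(coeffs)        # 4 coefficients for a minibin with 3 elements

def eval_poly(coeffs, x, modulus):  # Horner's rule
    result = 0
    for c in coeffs:
        result = (result * x + int(c)) % modulus
    return result

assert all(eval_poly(coeffs, r, t) == 0 for r in roots)
assert eval_poly(coeffs, 4, t) != 0  # a non-member does not vanish
```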
### The actual **PSI** protocol

In this phase, both the server and the client perform the actual PSI protocol. The encryption scheme used, FV, encrypts messages that are polynomials of degree less than ```poly_modulus_degree``` (a power of 2), with integer coefficients modulo ```plain_modulus```. This modulus is chosen to be a prime congruent to 1 modulo 2 * ```poly_modulus_degree```, which allows identifying each such polynomial with a vector of ```poly_modulus_degree``` integer entries modulo ```plain_modulus```. Thanks to this correspondence, [TenSEAL](https://github.com/OpenMined/TenSEAL/blob/master/tutorials%2FTutorial%200%20-%20Getting%20Started.ipynb) can encrypt **vectors of integers** directly, by first performing the encoding above and then the actual encryption; in the same way, decrypting in TenSEAL yields the **vector of integers** encoded by the underlying plaintext polynomial.

* **Batching** (a minimal encryption round trip follows the list):
    * The client batches his bins (each holding 1 integer entry) into ```number_of_bins```/```poly_modulus_degree``` vectors.
    * The client encodes each such batch as a plaintext.
    * The client encrypts these plaintexts and sends them to the server.
    * The server batches his minibins into minibatches.

Due to our choice of parameters, only 1 plaintext is obtained and therefore only 1 ciphertext is sent: *Enc(x)*. Hence, the PSI protocol is performed simultaneously on each batch of bins.
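The encode-then-encrypt round trip looks as follows in TenSEAL (a minimal sketch using the BFV parameters from ```parameters.py``` and a toy batch):

```python
import tenseal as ts

# BFV context with the parameters from parameters.py
context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=2 ** 13, plain_modulus=536903681)

batch = [2, 3, 5, 7]                 # one integer per bin (a toy batch)
enc = ts.bfv_vector(context, batch)  # encoding and encryption in one step
print(enc.decrypt())                 # [2, 3, 5, 7]
```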
Since each minibin is represented by a polynomial of degree *D* = ```bin_capacity```/```alpha```, evaluating such a polynomial can be performed as the scalar product between the vector of its coefficients and all the (encrypted) powers of *x* with exponent at most *D*. The next step, *windowing*, helps the client send *sufficiently many powers* so that the server can recover all the remaining powers with a small computational effort.

* **Windowing**:
    * Besides *Enc(x)*, the client sends *Enc(x ** 2), Enc(x ** 4), ..., Enc(x ** {2 ** {log D}})*.

This scenario corresponds to the windowing parameter ```ell = 1```.
* **Recover all powers**:
    * The server recovers every *Enc(x ** i)*, for *i* less than or equal to *D*, from the given powers, by writing *i* in its binary decomposition.

This recovery can be checked end to end on plain integers, as in the sketch below.
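A self-contained sketch with the repository's actual windowing parameter ```ell = 2``` (on the encrypted side, ciphertext multiplications replace the modular multiplications below; the query value is made up):

```python
from math import log2

t = 536903681  # plain_modulus
x = 1234567    # a toy query item
ell = 2        # the repository's windowing parameter
base = 2 ** ell
D = 33         # minibin_capacity = int(bin_capacity / alpha)
logB_ell = int(log2(D) / ell) + 1

# window[i][j] holds x ** ((i + 1) * base ** j) mod t -- the matrix built by
# windowing() in auxiliary_functions.py, whose entries the client encrypts.
window = [[None for j in range(logB_ell)] for i in range(base - 1)]
for j in range(logB_ell):
    for i in range(base - 1):
        if (i + 1) * base ** j - 1 < D:
            window[i][j] = pow(x, (i + 1) * base ** j, t)

def int2base(n, b):
    return [n] if n < b else [n % b] + int2base(n // b, b)

def reconstruct_power(exponent):
    # multiply the right window entries, as power_reconstruct() does on ciphertexts
    result = 1
    for j, digit in enumerate(int2base(exponent, base)):
        if digit >= 1:
            result = (result * window[digit - 1][j]) % t
    return result

assert all(reconstruct_power(k) == pow(x, k, t) for k in range(1, D + 1))
```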
* **Doing the scalar products**: The server evaluates the polynomial of each minibin by computing the scalar product between the vector of its coefficients and the previous powers. Thanks to TenSEAL, this is done as follows (the sketch below shows these ciphertext-level operations in isolation):
    * For each minibatch, the server sums each encrypted power *Enc(x ** i)* multiplied by the *(D+1-i)*-th column of coefficients from the minibatch.
    * The server gets ```alpha``` * ```number_of_bins```/```poly_modulus_degree``` encrypted results and sends them to the client.
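The dot product in ```server_online.py``` only needs two TenSEAL operations: adding ciphertexts, and multiplying/adding a ciphertext with a plain vector, slot by slot. A minimal sketch with toy values — the plain coefficients below are the expansion of *(x - 2)(x - 3) = x² - 5x + 6*, reduced modulo ```plain_modulus```:

```python
import tenseal as ts

t = 536903681
context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=2 ** 13, plain_modulus=t)

xs = [2, 3, 5, 7]                                      # one query item per slot
enc_x = ts.bfv_vector(context, xs)                     # Enc(x), sent by the client
enc_x2 = ts.bfv_vector(context, [v * v for v in xs])   # Enc(x ** 2), also sent (windowing)

# evaluate x^2 - 5x + 6 in every slot; -5 is represented as t - 5 modulo t
result = enc_x2 + enc_x * [t - 5, t - 5, t - 5, t - 5] + [6, 6, 6, 6]
print(result.decrypt())  # [0, 0, 6, 20] -> the slots holding 2 or 3 decrypt to 0
```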
* **Getting the verdict**:
    * The client decrypts the results he gets from the server. Thanks to TenSEAL, he recovers a vector of integers (the encoding of the underlying plaintext polynomial).
    * The client checks this vector to see where he obtains 0. If there is an index of this vector where he gets 0, then the (Cuckoo hashing) item corresponding to this index belongs to a minibin of the corresponding server's bin.
    * This index helps him recover the common element.

## How to run
Check ```requirements.txt``` before running the files. You can generate the datasets of the client and the server by running ```set_gen.py```. Then run ```server_offline.py``` and ```client_offline.py``` to preprocess them. Now go to the online phase of the protocol by running ```server_online.py``` and then ```client_online.py```. Have fun! :smile:
--------------------------------------------------------------------------------
/auxiliary_functions.py:
--------------------------------------------------------------------------------
from math import log2
import numpy as np
from parameters import ell, plain_modulus, bin_capacity, alpha

base = 2 ** ell
minibin_capacity = int(bin_capacity / alpha)  # minibin_capacity = B / alpha
logB_ell = int(log2(minibin_capacity) / ell) + 1  # <= 2 ** HE.depth = 16
t = plain_modulus

def int2base(n, b):
    '''
    :param n: an integer
    :param b: a base
    :return: an array of coefficients from the base decomposition of an integer n, with coeff[i] being the coefficient of b ** i
    '''
    if n < b:
        return [n]
    else:
        return [n % b] + int2base(n // b, b)

# We need len(powers_vec) <= 2 ** HE.depth
def low_depth_multiplication(vector):
    '''
    :param vector: a vector of integers
    :return: an integer representing the product of all the integers from vector
    '''
    L = len(vector)
    if L == 1:
        return vector[0]
    if L == 2:
        return vector[0] * vector[1]
    else:
        # Multiply pairwise and recurse on a vector of half the length;
        # this keeps the multiplicative depth logarithmic in L.
        vec = []
        for i in range(L // 2):
            vec.append(vector[2 * i] * vector[2 * i + 1])
        if L % 2 == 1:
            vec.append(vector[L - 1])
        return low_depth_multiplication(vec)

def power_reconstruct(window, exponent):
    '''
    :param window: a matrix of powers of y; in the protocol it is the matrix with entries window[i][j] = y ** ((i + 1) * base ** j)
    :param exponent: an integer exponent <= minibin_capacity
    :return: y ** exponent
    '''
    e_base_coef = int2base(exponent, base)
    necessary_powers = []  # len(necessary_powers) <= 2 ** HE.depth
    j = 0
    for x in e_base_coef:
        if x >= 1:
            necessary_powers.append(window[x - 1][j])
        j = j + 1
    return low_depth_multiplication(necessary_powers)


def windowing(y, bound, modulus):
    '''
    :param y: an integer
    :param bound: an integer
    :param modulus: a modulus integer
    :return: a matrix associated to y, where the (i, j) entry is y ** ((i + 1) * base ** j) mod modulus, as long as the exponent of y is smaller than the bound
    '''
    windowed_y = [[None for j in range(logB_ell)] for i in range(base - 1)]
    for j in range(logB_ell):
        for i in range(base - 1):
            if (i + 1) * base ** j - 1 < bound:
                windowed_y[i][j] = pow(y, (i + 1) * base ** j, modulus)
    return windowed_y


def coeffs_from_roots(roots, modulus):
    '''
    :param roots: an array of integers
    :param modulus: an integer
    :return: coefficients of a polynomial whose roots are roots modulo modulus
    '''
    coefficients = np.array(1, dtype=np.int64)
    for r in roots:
        coefficients = np.convolve(coefficients, [1, -r]) % modulus
    return coefficients

--------------------------------------------------------------------------------
/bin_capacity_estimator.py:
--------------------------------------------------------------------------------
from math import log2
from math import comb

no_of_hashes = 3
m = 2 ** 13  # number of bins
server_size = 2 ** 20
d = no_of_hashes * server_size  # number of balls thrown into the m bins
security_bits = 30  # lambda

md_1 = m ** (d - 1)

s = 0  # running sum of the terms already subtracted from S (kept for reference)
S = m ** d
i = 0
power_of_m_1 = (m - 1) ** d
TV = True

# The number of items landing in a fixed bin follows Bin(d, 1/m), so after the
# i-th iteration S / m ** d = Pr[a fixed bin receives more than i items].
# The condition below is equivalent to m * Pr[a fixed bin overflows] <= 2 ** (-security_bits),
# i.e. by a union bound over the m bins, no bin overflows except with probability 2 ** (-lambda).
while TV:
    print(i)
    current_term = comb(d, i) * power_of_m_1  # comb(d, i) * (m - 1) ** (d - i)
    s = s + current_term
    S = S - current_term
    if int(log2(md_1) - log2(S)) >= security_bits:
        TV = False
    i = i + 1
    power_of_m_1 = power_of_m_1 // (m - 1)

print('--------------------')
print('bin_capacity = {}'.format(i - 1))
--------------------------------------------------------------------------------
/client_offline.py:
--------------------------------------------------------------------------------
import pickle
from oprf import client_prf_offline, order_of_generator, G
from time import time

# client's PRF secret key (a value from range(order_of_generator))
oprf_client_key = 12345678910111213141516171819222222222222
t0 = time()

# key * generator of elliptic curve
client_point_precomputed = (oprf_client_key % order_of_generator) * G

client_set = []
f = open('client_set', 'r')
lines = f.readlines()
for item in lines:
    client_set.append(int(item[:-1]))
f.close()

# OPRF layer: encode the client's set as elliptic curve points.
encoded_client_set = [client_prf_offline(item, client_point_precomputed) for item in client_set]

g = open('client_preprocessed', 'wb')
pickle.dump(encoded_client_set, g)
g.close()
t1 = time()
print('Client OFFLINE time: {:.2f}s'.format(t1 - t0))

--------------------------------------------------------------------------------
/client_online.py:
--------------------------------------------------------------------------------
import tenseal as ts
from time import time
import socket
import pickle
from math import log2
from parameters import sigma_max, output_bits, plain_modulus, poly_modulus_degree, number_of_hashes, bin_capacity, alpha, ell, hash_seeds
from cuckoo_hash import reconstruct_item, Cuckoo
from auxiliary_functions import windowing
from oprf import order_of_generator, client_prf_online_parallel

oprf_client_key = 12345678910111213141516171819222222222222

log_no_hashes = int(log2(number_of_hashes)) + 1
base = 2 ** ell
minibin_capacity = int(bin_capacity / alpha)
logB_ell = int(log2(minibin_capacity) / ell) + 1  # <= 2 ** HE.depth
dummy_msg_client = 2 ** (sigma_max - output_bits + log_no_hashes)

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('localhost', 4470))

# Setting the public and private contexts for the BFV homomorphic encryption scheme
private_context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=poly_modulus_degree, plain_modulus=plain_modulus)
public_context = ts.context_from(private_context.serialize())
public_context.make_context_public()

# We prepare the partially OPRF-processed database to be sent to the server
pickle_off = open("client_preprocessed", "rb")
encoded_client_set = pickle.load(pickle_off)
encoded_client_set_serialized = pickle.dumps(encoded_client_set, protocol=None)

L = len(encoded_client_set_serialized)
sL = str(L) + ' ' * (10 - len(str(L)))
client_to_server_communication_oprf = L  # in bytes
# The length of the message is sent first
client.sendall((sL).encode())
client.sendall(encoded_client_set_serialized)

L = client.recv(10).decode().strip()
L = int(L, 10)

PRFed_encoded_client_set_serialized = b""
while len(PRFed_encoded_client_set_serialized) < L:
    data = client.recv(4096)
    if not data:
        break
    PRFed_encoded_client_set_serialized += data
PRFed_encoded_client_set = pickle.loads(PRFed_encoded_client_set_serialized)
t0 = time()
server_to_client_communication_oprf = len(PRFed_encoded_client_set_serialized)

# We finalize the OPRF processing by applying the inverse of the secret key, oprf_client_key
key_inverse = pow(oprf_client_key, -1, order_of_generator)
PRFed_client_set = client_prf_online_parallel(key_inverse, PRFed_encoded_client_set)
print(' * OPRF protocol done!')

# Each PRFed item from the client set is mapped to a Cuckoo hash table
CH = Cuckoo(hash_seeds)
for item in PRFed_client_set:
    CH.insert(item)

# We pad the Cuckoo vector with dummy messages
for i in range(CH.number_of_bins):
    if CH.data_structure[i] is None:
        CH.data_structure[i] = dummy_msg_client

# We apply the windowing procedure for each item from the Cuckoo structure
windowed_items = []
for item in CH.data_structure:
    windowed_items.append(windowing(item, minibin_capacity, plain_modulus))

plain_query = [None for k in range(len(windowed_items))]
enc_query = [[None for j in range(logB_ell)] for i in range(1, base)]

# We create the batched query to be sent to the server.
# By our choice of parameters, number of bins = poly modulus degree (m/N = 1), so we get (base - 1) * logB_ell ciphertexts
for j in range(logB_ell):
    for i in range(base - 1):
        if (i + 1) * base ** j - 1 < minibin_capacity:
            for k in range(len(windowed_items)):
                plain_query[k] = windowed_items[k][i][j]
            enc_query[i][j] = ts.bfv_vector(private_context, plain_query)

enc_query_serialized = [[None for j in range(logB_ell)] for i in range(1, base)]
for j in range(logB_ell):
    for i in range(base - 1):
        if (i + 1) * base ** j - 1 < minibin_capacity:
            enc_query_serialized[i][j] = enc_query[i][j].serialize()

context_serialized = public_context.serialize()
message_to_be_sent = [context_serialized, enc_query_serialized]
message_to_be_sent_serialized = pickle.dumps(message_to_be_sent, protocol=None)
t1 = time()
L = len(message_to_be_sent_serialized)
sL = str(L) + ' ' * (10 - len(str(L)))
client_to_server_communication_query = L
# the length of the message is sent first
client.sendall((sL).encode())
print(" * Sending the context and ciphertext to the server....")
# Now we send the message to the server
client.sendall(message_to_be_sent_serialized)

print(" * Waiting for the server's answer...")

# The answer obtained from the server:
L = client.recv(10).decode().strip()
L = int(L, 10)
answer = b""
while len(answer) < L:
    data = client.recv(4096)
    if not data:
        break
    answer += data
t2 = time()
server_to_client_query_response = len(answer)  # in bytes
# Here is the vector of decryptions of the answer
ciphertexts = pickle.loads(answer)
decryptions = []
for ct in ciphertexts:
    decryptions.append(ts.bfv_vector_from(private_context, ct).decrypt())

recover_CH_structure = []
for matrix in windowed_items:
    recover_CH_structure.append(matrix[0][0])

count = [0] * alpha

g = open('client_set', 'r')
client_set_entries = g.readlines()
g.close()
client_intersection = []
for j in range(alpha):
    for i in range(poly_modulus_degree):
        if decryptions[j][i] == 0:
            count[j] = count[j] + 1

            # The index i is the location of the element in the intersection.
            # Here we recover this element from the Cuckoo hash structure.
            PRFed_common_element = reconstruct_item(recover_CH_structure[i], i, hash_seeds[recover_CH_structure[i] % (2 ** log_no_hashes)])
            index = PRFed_client_set.index(PRFed_common_element)
            client_intersection.append(int(client_set_entries[index][:-1]))

h = open('intersection', 'r')
real_intersection = [int(line[:-1]) for line in h]
h.close()
t3 = time()
print('\n Intersection recovered correctly: {}'.format(set(client_intersection) == set(real_intersection)))
print("Disconnecting...\n")
print(' Client ONLINE computation time {:.2f}s'.format(t1 - t0 + t3 - t2))
print(' Communication size:')
print('    ~ Client --> Server:  {:.2f} MB'.format((client_to_server_communication_oprf + client_to_server_communication_query) / 2 ** 20))
print('    ~ Server --> Client:  {:.2f} MB'.format((server_to_client_communication_oprf + server_to_client_query_response) / 2 ** 20))
client.close()
--------------------------------------------------------------------------------
/cuckoo_hash.py:
--------------------------------------------------------------------------------
from random import randint
import math
import mmh3

# parameters
from parameters import output_bits, number_of_hashes
mask_of_power_of_2 = 2 ** output_bits - 1
log_no_hashes = int(math.log(number_of_hashes) / math.log(2)) + 1


# The hash family used for Cuckoo hashing relies on the Murmur hash family (mmh3)

def location(seed, item):
    '''
    :param seed: a seed of a Murmur hash function
    :param item: an integer
    :return: Murmur_hash(item_left) xor item_right, where item = item_left || item_right
    '''
    item_left = item >> output_bits
    item_right = item & mask_of_power_of_2
    hash_item_left = mmh3.hash(str(item_left), seed, signed=False) >> (32 - output_bits)
    return hash_item_left ^ item_right

def left_and_index(item, index):
    '''
    :param item: an integer
    :param index: a log_no_hashes bits integer
    :return: an integer represented as item_left || index
    '''
    return ((item >> output_bits) << log_no_hashes) + index

def extract_index(item_left_and_index):
    '''
    :param item_left_and_index: an integer represented as item_left || index
    :return: the index extracted
    '''
    return item_left_and_index & (2 ** log_no_hashes - 1)

def reconstruct_item(item_left_and_index, current_location, seed):
    '''
    :param item_left_and_index: an integer represented as item_left || index
    :param current_location: the corresponding location, i.e. Murmur_hash(item_left) xor item_right
    :param seed: the seed of the Murmur hash function
    :return: the integer item
    '''
    item_left = item_left_and_index >> log_no_hashes
    hashed_item_left = mmh3.hash(str(item_left), seed, signed=False) >> (32 - output_bits)
    item_right = hashed_item_left ^ current_location
    return (item_left << output_bits) + item_right

def rand_point(bound, i):
    '''
    :param bound: an integer
    :param i: an integer less than bound
    :return: a uniform integer from [0, bound - 1], distinct from i
    '''
    value = randint(0, bound - 1)
    while value == i:
        value = randint(0, bound - 1)
    return value

class Cuckoo():

    def __init__(self, hash_seed):
        self.number_of_bins = 2 ** output_bits
        self.recursion_depth = int(8 * math.log(self.number_of_bins) / math.log(2))
        self.data_structure = [None for j in range(self.number_of_bins)]
        self.insert_index = randint(0, number_of_hashes - 1)
        self.depth = 0
        self.FAIL = 0

        self.hash_seed = hash_seed

    def insert(self, item):  # item is an integer
        current_location = location(self.hash_seed[self.insert_index], item)
        current_item = self.data_structure[current_location]
        self.data_structure[current_location] = left_and_index(item, self.insert_index)

        if current_item is None:
            self.insert_index = randint(0, number_of_hashes - 1)
            self.depth = 0
        else:
            # evict the previous occupant and reinsert it with a different hash function
            unwanted_index = extract_index(current_item)
            self.insert_index = rand_point(number_of_hashes, unwanted_index)
            if self.depth < self.recursion_depth:
                self.depth += 1
                jumping_item = reconstruct_item(current_item, current_location, self.hash_seed[unwanted_index])
                self.insert(jumping_item)
            else:
                self.FAIL = 1

--------------------------------------------------------------------------------
/oprf.py:
--------------------------------------------------------------------------------
from fastecdsa.curve import P192
from fastecdsa.point import Point
from math import log2
from multiprocessing import Pool
from parameters import sigma_max

mask = 2 ** sigma_max - 1

number_of_processes = 4

# Curve parameters
curve_used = P192
prime_of_curve_equation = curve_used.p
order_of_generator = curve_used.q
log_p = int(log2(prime_of_curve_equation)) + 1
G = Point(curve_used.gx, curve_used.gy, curve=curve_used)  # generator of curve_used

def server_prf_offline(vector_of_items_and_point):  # used as a subroutine for server_prf_offline_parallel
    vector_of_items = vector_of_items_and_point[0]
    point = vector_of_items_and_point[1]
    vector_of_multiples = [item * point for item in vector_of_items]
    return [(Q.x >> (log_p - sigma_max - 10)) & mask for Q in vector_of_multiples]

def server_prf_offline_parallel(vector_of_items, point):
    '''
    :param vector_of_items: a vector of integers
    :param point: a point on the elliptic curve (it will be key * G)
    :return: a vector of sigma_max-bit integers, each taken from the first coordinate of item * point (this is the same as item * key * G)
    '''
    division = int(len(vector_of_items) / number_of_processes)
    inputs = [vector_of_items[i * division: (i + 1) * division] for i in range(number_of_processes)]
    if len(vector_of_items) % number_of_processes != 0:
        inputs.append(vector_of_items[number_of_processes * division: number_of_processes * division + (len(vector_of_items) % number_of_processes)])
    inputs_and_point = [(input_vec, point) for input_vec in inputs]
    outputs = []
    with Pool(number_of_processes) as p:
        outputs = p.map(server_prf_offline, inputs_and_point)
    final_output = []
    for output_vector in outputs:
        final_output = final_output + output_vector
    return final_output

def server_prf_online(keyed_vector_of_points):  # used as a subroutine in server_prf_online_parallel
    key = keyed_vector_of_points[0]
    vector_of_points = keyed_vector_of_points[1]
    vector_of_multiples = [key * PP for PP in vector_of_points]
    return [[Q.x, Q.y] for Q in vector_of_multiples]


def server_prf_online_parallel(key, vector_of_pairs):
    '''
    :param key: an integer
    :param vector_of_pairs: vector of coordinates of some points P on the elliptic curve
    :return: vector of coordinates of the points key * P on the elliptic curve
    '''
    vector_of_points = [Point(P[0], P[1], curve=curve_used) for P in vector_of_pairs]
    division = int(len(vector_of_points) / number_of_processes)
    inputs = [vector_of_points[i * division: (i + 1) * division] for i in range(number_of_processes)]
    if len(vector_of_points) % number_of_processes != 0:
        inputs.append(vector_of_points[number_of_processes * division: number_of_processes * division + (len(vector_of_points) % number_of_processes)])
    keyed_inputs = [(key, _) for _ in inputs]
    outputs = []
    with Pool(number_of_processes) as p:
        outputs = p.map(server_prf_online, keyed_inputs)
    final_output = []
    for output_vector in outputs:
        final_output = final_output + output_vector
    return final_output

def client_prf_offline(item, point):
    '''
    :param item: an integer
    :param point: a point on the elliptic curve (ex. in the protocol point = key * G)
    :return: coordinates of item * point (ex. in the protocol it computes key * item * G)
    '''
    P = item * point
    x_item = P.x
    y_item = P.y
    return [x_item, y_item]

def client_prf_online(keyed_vector_of_pairs):
    key_inverse = keyed_vector_of_pairs[0]
    vector_of_pairs = keyed_vector_of_pairs[1]
    vector_of_points = [Point(pair[0], pair[1], curve=curve_used) for pair in vector_of_pairs]
    vector_key_inverse_points = [key_inverse * PP for PP in vector_of_points]
    return [(Q.x >> (log_p - sigma_max - 10)) & mask for Q in vector_key_inverse_points]

def client_prf_online_parallel(key_inverse, vector_of_pairs):
    division = int(len(vector_of_pairs) / number_of_processes)
    inputs = [vector_of_pairs[i * division: (i + 1) * division] for i in range(number_of_processes)]
    if len(vector_of_pairs) % number_of_processes != 0:
        inputs.append(vector_of_pairs[number_of_processes * division: number_of_processes * division + (len(vector_of_pairs) % number_of_processes)])
    keyed_inputs = [(key_inverse, _) for _ in inputs]
    outputs = []
    with Pool(number_of_processes) as p:
        outputs = p.map(client_prf_online, keyed_inputs)
    final_output = []
    for output_vector in outputs:
        final_output = final_output + output_vector
    return final_output
--------------------------------------------------------------------------------
/parameters.py:
--------------------------------------------------------------------------------
from math import log2

# sizes of the databases of the server and the client
# the size of the intersection should be less than the size of the client's database
server_size = 2 ** 20
client_size = 4000
intersection_size = 3500

# seeds used by both the Server and the Client for the Murmur hash functions
hash_seeds = [123456789, 10111213141516, 17181920212223]

# output_bits = number of bits of the output of the hash functions
# number of bins for simple/Cuckoo hashing = 2 ** output_bits
output_bits = 13

# encryption parameters of the BFV scheme: the plain modulus and the polynomial modulus degree
plain_modulus = 536903681
poly_modulus_degree = 2 ** 13

# the number of hashes we use for simple/Cuckoo hashing
number_of_hashes = 3

# bit length of the database items
sigma_max = int(log2(plain_modulus)) + output_bits - (int(log2(number_of_hashes)) + 1)

# B = [68, 176, 536, 1832, 6727] for log(server_size) = [16, 18, 20, 22, 24]
bin_capacity = 536

# partitioning parameter
alpha = 16

# windowing parameter
ell = 2

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fastecdsa==2.1.5
mmh3==3.0.0
numpy==1.21.0
tenseal==0.3.4

--------------------------------------------------------------------------------
/server_offline.py:
--------------------------------------------------------------------------------
from parameters import sigma_max, number_of_hashes, output_bits, bin_capacity, alpha, hash_seeds, plain_modulus
from simple_hash import Simple_hash
from auxiliary_functions import coeffs_from_roots
from math import log2
import pickle
from oprf import server_prf_offline_parallel, order_of_generator, G
from time import time

# server's PRF secret key
oprf_server_key = 1234567891011121314151617181920

# key * generator of elliptic curve
server_point_precomputed = (oprf_server_key % order_of_generator) * G

server_set = []
f = open('server_set', 'r')
lines = f.readlines()
for item in lines:
    server_set.append(int(item[:-1]))
f.close()

t0 = time()
# The PRF function is applied on the set of the server, using parallel computation
PRFed_server_set = server_prf_offline_parallel(server_set, server_point_precomputed)
PRFed_server_set = set(PRFed_server_set)
t1 = time()

log_no_hashes = int(log2(number_of_hashes)) + 1
dummy_msg_server = 2 ** (sigma_max - output_bits + log_no_hashes) + 1
server_size = len(server_set)
minibin_capacity = int(bin_capacity / alpha)
number_of_bins = 2 ** output_bits

# The OPRF-processed database entries are simple hashed
SH = Simple_hash(hash_seeds)
for item in PRFed_server_set:
    for i in range(number_of_hashes):
        SH.insert(item, i)

# simple_hashed_data is padded with dummy_msg_server
for i in range(number_of_bins):
    for j in range(bin_capacity):
        if SH.simple_hashed_data[i][j] is None:
            SH.simple_hashed_data[i][j] = dummy_msg_server

# Here we perform the partitioning:
# Namely, we partition each bin into alpha minibins with B / alpha items each.
# We represent each minibin as the coefficients of a polynomial of degree B / alpha that vanishes at all the entries of the minibin.
# Therefore, each minibin is represented by B / alpha + 1 coefficients; notice that the leading coefficient = 1.
t2 = time()

poly_coeffs = []
for i in range(number_of_bins):
    # we create the list of coefficients of all minibins by concatenating the lists of coefficients of each minibin
    coeffs_from_bin = []
    for j in range(alpha):
        roots = [SH.simple_hashed_data[i][minibin_capacity * j + r] for r in range(minibin_capacity)]
        coeffs_from_bin = coeffs_from_bin + coeffs_from_roots(roots, plain_modulus).tolist()
    poly_coeffs.append(coeffs_from_bin)

f = open('server_preprocessed', 'wb')
pickle.dump(poly_coeffs, f)
f.close()
t3 = time()
# print('OPRF preprocessing time {:.2f}s'.format(t1 - t0))
# print('Hashing time {:.2f}s'.format(t2 - t1))
# print('Poly coefficients from roots time {:.2f}s'.format(t3 - t2))
print('Server OFFLINE time {:.2f}s'.format(t3 - t0))

--------------------------------------------------------------------------------
/server_online.py:
--------------------------------------------------------------------------------
import socket
import tenseal as ts
import pickle
import numpy as np
from math import log2

from parameters import number_of_hashes, bin_capacity, alpha, ell
from auxiliary_functions import power_reconstruct
from oprf import server_prf_online_parallel

oprf_server_key = 1234567891011121314151617181920
from time import time

log_no_hashes = int(log2(number_of_hashes)) + 1
base = 2 ** ell
minibin_capacity = int(bin_capacity / alpha)
logB_ell = int(log2(minibin_capacity) / ell) + 1  # <= 2 ** HE.depth

serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
serv.bind(('localhost', 4470))
serv.listen(1)

g = open('server_preprocessed', 'rb')
poly_coeffs = pickle.load(g)

# For the online phase of the server, we need to use the columns of the preprocessed database
transposed_poly_coeffs = np.transpose(poly_coeffs).tolist()

for i in range(1):  # handle a single client connection
    conn, addr = serv.accept()
    L = conn.recv(10).decode().strip()
    L = int(L, 10)
    # OPRF layer: the server receives the encoded set elements as curve points
    encoded_client_set_serialized = b""
    while len(encoded_client_set_serialized) < L:
        data = conn.recv(4096)
        if not data:
            break
        encoded_client_set_serialized += data
    encoded_client_set = pickle.loads(encoded_client_set_serialized)
    t0 = time()
    # The server computes (in parallel) the online part of the OPRF protocol, using its own secret key
    PRFed_encoded_client_set = server_prf_online_parallel(oprf_server_key, encoded_client_set)
    PRFed_encoded_client_set_serialized = pickle.dumps(PRFed_encoded_client_set, protocol=None)
    L = len(PRFed_encoded_client_set_serialized)
    sL = str(L) + ' ' * (10 - len(str(L)))  # pad len to 10 bytes

    conn.sendall((sL).encode())
    conn.sendall(PRFed_encoded_client_set_serialized)
    print(' * OPRF layer done!')
    t1 = time()
    L = conn.recv(10).decode().strip()
    L = int(L, 10)

    # The server receives bytes that represent the public HE context and the query ciphertexts
    final_data = b""
    while len(final_data) < L:
        data = conn.recv(4096)
        if not data:
            break
        final_data += data

    t2 = time()
    # Here we recover the context and the ciphertexts from the received bytes
    received_data = pickle.loads(final_data)
    srv_context = ts.context_from(received_data[0])
    received_enc_query_serialized = received_data[1]
    received_enc_query = [[None for j in range(logB_ell)] for i in range(base - 1)]
    for i in range(base - 1):
        for j in range(logB_ell):
            if (i + 1) * base ** j - 1 < minibin_capacity:
                received_enc_query[i][j] = ts.bfv_vector_from(srv_context, received_enc_query_serialized[i][j])

    # Here we recover all the encrypted powers Enc(y), Enc(y^2), ..., Enc(y^{minibin_capacity}) from the encrypted windowing of y.
    # These are needed to evaluate the polynomials of degree minibin_capacity.
    all_powers = [None for i in range(minibin_capacity)]
    for i in range(base - 1):
        for j in range(logB_ell):
            if (i + 1) * base ** j - 1 < minibin_capacity:
                all_powers[(i + 1) * base ** j - 1] = received_enc_query[i][j]

    for k in range(minibin_capacity):
        if all_powers[k] is None:
            all_powers[k] = power_reconstruct(received_enc_query, k + 1)
    all_powers = all_powers[::-1]

    # The server sends alpha ciphertexts, obtained by performing the dot product between the polynomial coefficients from the preprocessed server database and all the powers Enc(y), ..., Enc(y^{minibin_capacity})
    srv_answer = []
    for i in range(alpha):
        # the leading coefficient of each minibin polynomial is 1, so the row with index (minibin_capacity + 1) * i
        # contains only 1's and Enc(y ** minibin_capacity) = all_powers[0] enters the sum unmultiplied
        dot_product = all_powers[0]
        for j in range(1, minibin_capacity):
            dot_product = dot_product + transposed_poly_coeffs[(minibin_capacity + 1) * i + j] * all_powers[j]
        dot_product = dot_product + transposed_poly_coeffs[(minibin_capacity + 1) * i + minibin_capacity]
        srv_answer.append(dot_product.serialize())

    # The answer to be sent to the client is prepared
    response_to_be_sent = pickle.dumps(srv_answer, protocol=None)
    t3 = time()
    L = len(response_to_be_sent)
    sL = str(L) + ' ' * (10 - len(str(L)))  # pad len to 10 bytes

    conn.sendall((sL).encode())
    conn.sendall(response_to_be_sent)

    # Close the connection
    print("Client disconnected \n")
    print('Server ONLINE computation time {:.2f}s'.format(t1 - t0 + t3 - t2))

    conn.close()
--------------------------------------------------------------------------------
/set_gen.py:
--------------------------------------------------------------------------------
from random import sample
from parameters import server_size, client_size, intersection_size

# set elements can be integers smaller than the order of the generator of the elliptic curve
# (192-bit integers if P192 is used); 'sample' works only for integers of at most 63 bits
disjoint_union = sample(range(2 ** 63 - 1), server_size + client_size)
intersection = disjoint_union[:intersection_size]
server_set = intersection + disjoint_union[intersection_size: server_size]
client_set = intersection + disjoint_union[server_size: server_size - intersection_size + client_size]

f = open('server_set', 'w')
for item in server_set:
    f.write(str(item) + '\n')
f.close()

g = open('client_set', 'w')
for item in client_set:
    g.write(str(item) + '\n')
g.close()

h = open('intersection', 'w')
for item in intersection:
    h.write(str(item) + '\n')
h.close()

--------------------------------------------------------------------------------
/simple_hash.py:
--------------------------------------------------------------------------------
from random import randint
import math
import mmh3

# parameters
from parameters import output_bits, number_of_hashes, bin_capacity
log_no_hashes = int(math.log(number_of_hashes) / math.log(2)) + 1
mask_of_power_of_2 = 2 ** output_bits - 1


def left_and_index(item, index):
    '''
    :param item: an integer
    :param index: a log_no_hashes bits integer
    :return: an integer represented as item_left || index
    '''
    return ((item >> output_bits) << log_no_hashes) + index

# The hash family used for simple hashing relies on the Murmur hash family (mmh3)

def location(seed, item):
    '''
    :param seed: a seed of a Murmur hash function
    :param item: an integer
    :return: Murmur_hash(item_left) xor item_right, where item = item_left || item_right
    '''
    item_left = item >> output_bits
    item_right = item & mask_of_power_of_2
    hash_item_left = mmh3.hash(str(item_left), seed, signed=False) >> (32 - output_bits)
    return hash_item_left ^ item_right

class Simple_hash():

    def __init__(self, hash_seed):
        self.no_bins = 2 ** output_bits
        self.simple_hashed_data = [[None for j in range(bin_capacity)] for i in range(self.no_bins)]
        self.occurrences = [0 for i in range(self.no_bins)]
        self.FAIL = 0
        self.hash_seed = hash_seed
        self.bin_capacity = bin_capacity

    # insert item using hash i on the position given by location
    def insert(self, item, i):
        loc = location(self.hash_seed[i], item)
        if self.occurrences[loc] < self.bin_capacity:
            self.simple_hashed_data[loc][self.occurrences[loc]] = left_and_index(item, i)
            self.occurrences[loc] += 1
        else:
            self.FAIL = 1
            print('Simple hashing aborted')
--------------------------------------------------------------------------------