├── .gitignore ├── LICENSE ├── README.md ├── bin └── TODO ├── data ├── __init__.py └── iris_data.py ├── examples ├── __init__.py ├── perceptron │ ├── __init__.py │ ├── alg.py │ ├── circuit.py │ └── eval_circuit.py └── svm │ ├── alg.py │ ├── circuit.py │ └── eval_circuit.py ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── circuits │ ├── __init__.py │ ├── dealer.py │ ├── evaluator.py │ ├── gate.py │ ├── oracle.py │ └── share.py └── util │ ├── mod.py │ └── primality_test.py └── tests ├── __init__.py ├── examples └── __init__.py ├── src ├── __init__.py └── circuits │ ├── __init__.py │ ├── test_evaluator.py │ └── test_share.py └── util ├── __init__.py └── test_mod.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # virtual environment 60 | .venv/ 61 | 62 | # visual studio code configuration 63 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPC-learning 2 | MPC-learning is a Python library for performing multi-party computation on machine learning applications. This library implements the 3-party computation protocol of https://eprint.iacr.org/2016/768.pdf . For now, a "dealer" is required to distribute shares of inputs, and the protocol can only be run locally (does not support networking yet). 3 | 4 | ## Installation 5 | This is a quick guide to getting this repo up and running for development. 6 | 7 | 0. Clone the library 8 | 9 | ```bash 10 | $ git clone https://github.com/trailofbits/mpc-learning 11 | ``` 12 | 13 | 1. Download virtualenv. 14 | 15 | 2. Create and source your virtual environment. 16 | 17 | ```bash 18 | $ virtualenv -p python3 .venv 19 | $ source .venv/bin/activate 20 | ``` 21 | 22 | 3. Install the library: 23 | 24 | a. if you want to use and edit the library: 25 | ```bash 26 | $ python setup.py develop 27 | ``` 28 | 29 | b. otherwise to just use the library: 30 | ```bash 31 | $ python setup.py install 32 | ``` 33 | 34 | ## Usage 35 | 36 | If everything installed correctly the following examples should work: 37 | 38 | 1. raw perceptron algorithm: 39 | 40 | ```bash 41 | $ python examples/perceptron/alg.py 42 | ``` 43 | 44 | 2. mpc perceptron (should be same result as raw algorithm, but will take longer): 45 | 46 | ```bash 47 | $ python examples/perceptron/eval_circuit.py 48 | ``` 49 | 50 | 3. raw svm algorithm: 51 | 52 | ```bash 53 | $ python examples/svm/alg.py 54 | ``` 55 | 56 | 4. mpc svm (should be same result as raw algorithm, but will take longer): 57 | 58 | ```bash 59 | $ python examples/svm/eval_circuit.py 60 | ``` 61 | 62 | If you would like to run this library on a different algorithm, you will have to synthesize the corresponding circuit for one iteration of the algorithm. The circuits must be in the correct format. For reference, checkout the perceptron and svm circuits: examples/*/circuit.py 63 | 64 | ## Contributing 65 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 66 | 67 | Please make sure to update tests as appropriate. 68 | 69 | ## License 70 | [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) -------------------------------------------------------------------------------- /bin/TODO: -------------------------------------------------------------------------------- 1 | TODO -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/data/__init__.py -------------------------------------------------------------------------------- /data/iris_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def get_iris_data(): 6 | """ 7 | Function for fetching iris data from archive. The function obtains 8 | the data and puts it in the format needed for examples/perceptron 9 | 10 | Returns 11 | ------- 12 | data: iterable 13 | Iris data in format for perceptron algorithm (iterable of pairs x,y) 14 | """ 15 | 16 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data' 17 | 18 | df = pd.read_csv(url, header=None) 19 | 20 | x_vals = df.iloc[:, [0, 2]].values 21 | 22 | y_vals = df.iloc[:, 4].values 23 | y_vals = np.where(y_vals == 'Iris-setosa', -1, 1) 24 | 25 | size = len(x_vals) 26 | 27 | data = [] 28 | 29 | # randomize ordering of data 30 | np.random.seed(1) 31 | random_indices = np.random.permutation(size) 32 | 33 | for i in range(size): 34 | cur_index = random_indices[i] 35 | data.append((x_vals[cur_index], y_vals[cur_index])) 36 | 37 | return data -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/examples/__init__.py -------------------------------------------------------------------------------- /examples/perceptron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/examples/perceptron/__init__.py -------------------------------------------------------------------------------- /examples/perceptron/alg.py: -------------------------------------------------------------------------------- 1 | # this file implements the raw perceptron algorithm for testing purposes 2 | # ideally the output of this algorithm will match that of the circuit 3 | 4 | import numpy as np 5 | import time 6 | 7 | 8 | def perceptron(data, num_iterations, modulus, initial_w=0, initial_b=0, fp_precision=16): 9 | """ 10 | Takes data (assumed to be an iterable of pairs (x,y)) and runs the 11 | perceptron algorithm for the number of iterations specified. We 12 | also assume that x is a numpy array-like object. 13 | 14 | The goal of the perceptron algorithm is to find (optimal) w,b such that 15 | y(dot(w,x) + b) > 0 is true for as many data points as possible 16 | 17 | Parameters 18 | ---------- 19 | data: iteratable 20 | Data to be input into algorithm (assumed iterable pairs) 21 | num_iterations: int 22 | Number of iterations that algorithm will run for 23 | modulus: int 24 | Value representing the modulus of field used 25 | (optional) initial_w=0: int 26 | Initial value of w, parameter of perceptron algorithm 27 | (optional) initial_b=0: int 28 | Initial value of b, parameter of perceptron algorithm 29 | (optional) fp_precision=16: int 30 | Fixed point number precision 31 | 32 | Returns 33 | ------- 34 | w: float 35 | w value achieved after num_iterations of perceptron 36 | b: int 37 | b value achieved after num_iterations of perceptron 38 | """ 39 | 40 | # need to make dimenions of w the same as x 41 | if initial_w == 0: 42 | first_x = data[0][0] 43 | initial_w = np.zeros(len(first_x)) 44 | 45 | w = initial_w 46 | b = initial_b 47 | 48 | # use fixed point numbers 49 | # input data should be scaled up by 10^fp_precision 50 | # also scale down by 10^fp_precision after every mult 51 | scale = 10**fp_precision 52 | 53 | start_time = time.time() 54 | 55 | for i in range(num_iterations): 56 | 57 | np_x = np.array(data[i][0]) 58 | y = data[i][1] 59 | 60 | # we use fixed point, so multiply by precision and round to integer 61 | for a in range(len(np_x)): 62 | np_x[a] = int( np_x[a] * scale) 63 | y = int( y * scale) 64 | 65 | # if point misclassified, update w and b, else do nothing 66 | xw_dot = np.dot(np_x,w) / scale 67 | if (y * (xw_dot + b)) / scale <= 0: 68 | w += (y * np_x) / scale 69 | b += y 70 | 71 | 72 | print("iteration: " + str(i)) 73 | print(w) 74 | print(b) 75 | 76 | 77 | w = w / scale 78 | b = b / scale 79 | 80 | elapsed_time = time.time() - start_time 81 | print("elapsed time: " + str(elapsed_time)) 82 | 83 | return (w, b) 84 | 85 | if __name__ == "__main__": 86 | 87 | import data.iris_data as iris 88 | 89 | data = iris.get_iris_data() 90 | 91 | num_iter = len(data) 92 | 93 | print(perceptron(data,num_iter,2**128)) -------------------------------------------------------------------------------- /examples/perceptron/circuit.py: -------------------------------------------------------------------------------- 1 | from src.circuits.gate import Gate 2 | from queue import Queue 3 | 4 | # this file contains the circuit for one iteration of the perceptron algorithm 5 | # the circuit is represented via a python dictionary 6 | # each wire is given a unique ID, and each gate is given a unique ID and label 7 | # they unique ID for the gate will serve as its key 8 | # the IDs will be assigned as g1 for gate one, g2 ... 9 | # the values of the dictionary correspond to the gate type and in/out wires 10 | 11 | # here is a summary of the perceptron algorithm 12 | 13 | # initialize w,b (for us, we take w,b to be random integer in {-1,1}) 14 | # for each data point (x,y): 15 | # if y * ( dotproduct(w,x) + b ) <= 0: 16 | # w = w + yx 17 | # b = b + y 18 | # else: 19 | # w = w 20 | # b = b 21 | 22 | # this circuit will consist of the code block below the for loop 23 | # there are four input values for each iteration: x, y, w, and b 24 | # they are labeled as follows: 25 | # x -> input0 26 | # y -> input1 27 | # w -> input2 28 | # b -> input3 29 | # there are two output values: w and b 30 | # they are labeled as follows: 31 | # w -> output0 32 | # b -> output1 33 | # the additional intermediate wires will be given the lable zi for all i 34 | 35 | # we use the following gate labels: ADD, MULT, SMULT, COMP, DOT, and NOT 36 | # ADD is the addition gate 37 | # MULT is the multiplication gate 38 | # SMULT is the scalar multiplication gate 39 | # COMP is the comparison gate, which computes the boolean (input <= 0) 40 | # DOT is the dot product gate, which computes the dot product of two inputs 41 | # NOT is the not gate, which computes 1 - input (input 0 or 1 here) 42 | 43 | x = "input0" 44 | y = "input1" 45 | wi = "input2" 46 | bi = "input3" 47 | 48 | wo = "output0" 49 | bo = "output1" 50 | 51 | circuit = {} 52 | 53 | # specify input, intermediate, and output wires 54 | circuit["input"] = [x,y,wi,bi] 55 | circuit["output"] = [wo,bo] 56 | wires = [] 57 | for i in range(12): 58 | wires.append("z"+str(i)) 59 | circuit["wires"] = wires 60 | 61 | # specify order of gates to be evaluated 62 | gate_order = [] 63 | for i in range(14): 64 | gate_order.append("g"+str(i)) 65 | 66 | Q = Queue() 67 | 68 | x = Gate("in0","INPUT",[],Q) 69 | y = Gate("in1","INPUT",[],Q) 70 | wi = Gate("in2","INPUT",[],Q) 71 | bi = Gate("in3","INPUT",[],Q) 72 | 73 | # gate for dot product of x and w 74 | g0 = Gate("g0","DOT",[x.get_id(),wi.get_id()],Q) 75 | 76 | # gate for adding b to dot product of x and w 77 | g1 = Gate("g1","ADD",[bi.get_id(),g0.get_id()],Q) 78 | 79 | # gate for multiplying y with b + dot(x,w) 80 | g2 = Gate("g2","MULT",[y.get_id(),g1.get_id()],Q) 81 | 82 | # gate for computing y(b + dot(x,w)) <= 0 83 | g3 = Gate("g3","COMP",[g2.get_id()],Q) 84 | 85 | # gate for computing y*x for conditional assignment to w 86 | g4 = Gate("g4","SMULT",[y.get_id(),x.get_id()],Q) 87 | 88 | # gate for computing w + x*y for conditional assignment to w 89 | g5 = Gate("g5","ADD",[wi.get_id(),g4.get_id()],Q) 90 | 91 | # gate for computing b + y for conditional assignment to b 92 | g6 = Gate("g6","ADD",[bi.get_id(),y.get_id()],Q) 93 | 94 | # gate for computing not of if statement 95 | g7 = Gate("g7","NOT",[g3.get_id()],Q) 96 | 97 | # gate for computing if conditional assignment to w 98 | g8 = Gate("g8","SMULT",[g3.get_id(),g5.get_id()],Q) 99 | 100 | # gate for computing else conditional assignment to w 101 | g9 = Gate("g9","SMULT",[g7.get_id(),wi.get_id()],Q) 102 | 103 | # gate for computing output for w 104 | g10 = Gate("g10","ADD",[g8.get_id(),g9.get_id()],Q) 105 | 106 | # gate for computing if conditional assignment to b 107 | g11 = Gate("g11","MULT",[g3.get_id(),g6.get_id()],Q) 108 | 109 | # gate for computing else conditional assignment to b 110 | g12 = Gate("g12","MULT",[bi.get_id(),g7.get_id()],Q) 111 | 112 | # gate for computing output for b 113 | g13 = Gate("g13","ADD",[g11.get_id(),g12.get_id()],Q) 114 | 115 | # output values 116 | wo = Gate("out0","OUTPUT",[g10.get_id()],Q) 117 | bo = Gate("out1","OUTPUT",[g13.get_id()],Q) 118 | 119 | circuit = {} 120 | 121 | circuit[x.get_id()] = [g0,g4] 122 | circuit[y.get_id()] = [g2,g4,g6] 123 | circuit[wi.get_id()] = [g0,g5,g9] 124 | circuit[bi.get_id()] = [g1,g6,g12] 125 | circuit[g0.get_id()] = [g1] 126 | circuit[g1.get_id()] = [g2] 127 | circuit[g2.get_id()] = [g3] 128 | circuit[g3.get_id()] = [g7,g8,g11] 129 | circuit[g4.get_id()] = [g5] 130 | circuit[g5.get_id()] = [g8] 131 | circuit[g6.get_id()] = [g11] 132 | circuit[g7.get_id()] = [g9,g12] 133 | circuit[g8.get_id()] = [g10] 134 | circuit[g9.get_id()] = [g10] 135 | circuit[g10.get_id()] = [wo] 136 | circuit[g11.get_id()] = [g13] 137 | circuit[g12.get_id()] = [g13] 138 | circuit[g13.get_id()] = [bo] 139 | 140 | in_gates = [x,y,wi,bi] 141 | out_gates = [wo,bo] -------------------------------------------------------------------------------- /examples/perceptron/eval_circuit.py: -------------------------------------------------------------------------------- 1 | from src.circuits.evaluator import BasicEvaluator 2 | from src.circuits.evaluator import SecureEvaluator 3 | from src.circuits.dealer import Dealer 4 | from src.circuits.oracle import Oracle 5 | #import circuit 6 | from examples.perceptron import circuit as circ 7 | import numpy as np 8 | from threading import Thread 9 | import copy 10 | import time 11 | 12 | def secure_eval_circuit(data,num_iterations,modulus,initial_w=0,initial_b=0,fp_precision=16): 13 | """ 14 | Function that evaluates the perceptron circuit using three SecureEvaluator 15 | objects. The current protocol also requires a Dealer and an Oracle. 16 | 17 | Parameters 18 | ---------- 19 | data: iterable 20 | Data to be input into the perceptron algorithm (assumed iterable pairs) 21 | num_iterations: int 22 | Number of iterations that algorithm will run for 23 | modulus: int 24 | Value representing the modulus of field used 25 | (optional) initial_w=0: int 26 | Initial value of w, parameter of perceptron algorithm 27 | (optional) initial_b=0: int 28 | Initial value of b, parameter of perceptron algorithm 29 | (optional) fp_precision=16: int 30 | Fixed point number precision 31 | 32 | Returns 33 | ------- 34 | w: float 35 | w value achieved after num_iterations of perceptron 36 | b: int 37 | b value achieved after num_iterations of perceptron 38 | """ 39 | 40 | # account for fixed point precision 41 | scale = 10**fp_precision 42 | 43 | circ1 = copy.deepcopy(circ.circuit) 44 | circ2 = copy.deepcopy(circ.circuit) 45 | circ3 = copy.deepcopy(circ.circuit) 46 | 47 | # initialize evaluators 48 | evaluator1 = SecureEvaluator(circ1,circ.in_gates,circ.out_gates,1,modulus) 49 | evaluator2 = SecureEvaluator(circ2,circ.in_gates,circ.out_gates,2,modulus) 50 | evaluator3 = SecureEvaluator(circ3,circ.in_gates,circ.out_gates,3,modulus) 51 | 52 | parties = [evaluator1,evaluator2,evaluator3] 53 | party_dict = {1: evaluator1, 2: evaluator2, 3: evaluator3} 54 | 55 | evaluator1.add_parties(party_dict) 56 | evaluator2.add_parties(party_dict) 57 | evaluator3.add_parties(party_dict) 58 | 59 | # initialize dealer 60 | dealer = Dealer(parties,modulus,fp_precision=fp_precision) 61 | 62 | start_time = time.time() 63 | 64 | # split x_data and y_data into 3 lists, one for each party 65 | # this simulates each party having private input data 66 | data_len = len(data) 67 | data1x = [] 68 | data2x = [] 69 | data3x = [] 70 | data1y = [] 71 | data2y = [] 72 | data3y = [] 73 | 74 | split = int(data_len/3) 75 | 76 | for i in range(split): 77 | data1x.append(data[i][0]) 78 | data1y.append(data[i][1]) 79 | data2x.append(data[split + i][0]) 80 | data2y.append(data[split + i][1]) 81 | data3x.append(data[2*split + 1][0]) 82 | data3y.append(data[2*split + 1][1]) 83 | 84 | # use dealer to create shares of all inputs 85 | dealer.distribute_shares(data1x) 86 | dealer.distribute_shares(data2x) 87 | dealer.distribute_shares(data3x) 88 | 89 | dealer.distribute_shares(data1y) 90 | dealer.distribute_shares(data2y) 91 | dealer.distribute_shares(data3y) 92 | 93 | # use dealer to create random values for interactive operations 94 | num_randomness = 3000 * num_iterations 95 | dealer.generate_randomness(num_randomness) 96 | dealer.generate_truncate_randomness(5*num_iterations) 97 | 98 | # need to make dimenions of w the same as x 99 | if initial_w == 0: 100 | first_x = data[0][0] 101 | initial_w = np.zeros(len(first_x)) 102 | initial_w = [initial_w,[]] 103 | 104 | dealer.distribute_shares(initial_w) 105 | dealer.distribute_shares(initial_b) 106 | 107 | results = {} 108 | 109 | # for each iteration of perceptron algorithm, have each SecureEvaluator 110 | # compute the circuit, each on their own thread, so they can interact 111 | res = {} 112 | for i in range(num_iterations): 113 | 114 | if i % 10 == 0: 115 | print("iteration: " + str(i)) 116 | 117 | t1 = Thread(target=run_eval,args=(evaluator1,i,data_len,results,1,fp_precision,res)) 118 | t2 = Thread(target=run_eval,args=(evaluator2,i,data_len,results,2,fp_precision,res)) 119 | t3 = Thread(target=run_eval,args=(evaluator3,i,data_len,results,3,fp_precision,res)) 120 | 121 | t1.start() 122 | t2.start() 123 | t3.start() 124 | 125 | t1.join() 126 | t2.join() 127 | t3.join() 128 | 129 | print("iter 0: " + str(unshare(res["0_1"][0],res["0_2"][0])) + ", " + str(unshare(res["0_1"][1],res["0_2"][1]))) 130 | print("iter 1: " + str(unshare(res["1_1"][0],res["1_2"][0])) + ", " + str(unshare(res["1_1"][1],res["1_2"][1]))) 131 | print("iter 2: " + str(unshare(res["2_1"][0],res["2_2"][0])) + ", " + str(unshare(res["2_1"][1],res["2_2"][1]))) 132 | print("iter 3: " + str(unshare(res["3_1"][0],res["3_2"][0])) + ", " + str(unshare(res["3_1"][1],res["3_2"][1]))) 133 | print("iter 4: " + str(unshare(res["4_1"][0],res["4_2"][0])) + ", " + str(unshare(res["4_1"][1],res["4_2"][1]))) 134 | print("iter 5: " + str(unshare(res["5_1"][0],res["5_2"][0])) + ", " + str(unshare(res["5_1"][1],res["5_2"][1]))) 135 | print("iter 6: " + str(unshare(res["6_1"][0],res["6_2"][0])) + ", " + str(unshare(res["6_1"][1],res["6_2"][1]))) 136 | print("iter 7: " + str(unshare(res["7_1"][0],res["7_2"][0])) + ", " + str(unshare(res["7_1"][1],res["7_2"][1]))) 137 | 138 | 139 | 140 | 141 | # extract final outputs, scale them down 142 | (w,b) = get_w_b(results) 143 | #return (w / scale, b / scale) 144 | 145 | elapsed_time = time.time() - start_time 146 | print("elapsed time: " + str(elapsed_time)) 147 | return (w,b) 148 | 149 | def unshare(share1,share2): 150 | """ 151 | Method for converting shares into their hidden value 152 | 153 | Parameters 154 | ---------- 155 | share1: int or iterable 156 | Shares of value 157 | share2: int or iterable 158 | Shares of same value as share1 159 | 160 | Returns 161 | ------- 162 | res: 163 | value hidden by share1 and share2 164 | """ 165 | 166 | if type(share1) == list: 167 | res = [] 168 | for i in range(len(share1)): 169 | res.append(share1[i].unshare(share2[i])) 170 | print(share1[i].unshare(share2[i])) 171 | 172 | else: 173 | res = share1.unshare(share2) 174 | 175 | return res 176 | 177 | def get_w_b(w_b_shares): 178 | """ 179 | Method for computing (w,b) from their shares 180 | 181 | Parameters 182 | ---------- 183 | w_b_shares: dictionary 184 | Dictionary of shares for values of (w,b) 185 | 186 | Returns 187 | ------- 188 | w: float 189 | w value achieved after num_iterations of perceptron 190 | b: int 191 | b value achieved after num_iterations of perceptron 192 | """ 193 | 194 | w1 = w_b_shares[1]['w'] 195 | b1 = w_b_shares[1]['b'] 196 | w2 = w_b_shares[2]['w'] 197 | b2 = w_b_shares[2]['b'] 198 | w3 = w_b_shares[3]['w'] 199 | b3 = w_b_shares[3]['b'] 200 | 201 | w = [w1[0].unshare(w2[0]), w1[1].unshare(w2[1])] 202 | b = b1.unshare(b2) 203 | 204 | 205 | return (w,b) 206 | 207 | 208 | def run_eval(evaluator,iter_num,data_length,results_dict,party_index,fp_precision=16,wd={}): 209 | """ 210 | Method to be run by each SecureEvaluator within their Thread (this will be 211 | called with secure_eval_circuit). 212 | 213 | Parameters 214 | ---------- 215 | evaluator: SecureEvaluator object 216 | SecureEvaluator that will compute an iteration of perceptron algorithm 217 | iter_num: int 218 | Iteration number of perceptron algorithm 219 | data_length: int 220 | Integer representing length of input data 221 | results_dict: dictionary 222 | Dictionary for each thread to insert ouput values 223 | party_index: int 224 | Integer representing evaluator party index 225 | (optional) fp_precision=16: int 226 | Fixed point number precision 227 | """ 228 | 229 | scale = 10**fp_precision 230 | 231 | # input will map wire name to index in list of shares 232 | cur_input = {} 233 | cur_input["in0"] = iter_num 234 | cur_input["in1"] = data_length + iter_num 235 | 236 | # only load initial b and w 237 | #if iter_num == 0: 238 | # cur_input["in2"] = -2 239 | # cur_input["in3"] = -1 240 | 241 | cur_input["in2"] = -2 242 | cur_input["in3"] = -1 243 | 244 | evaluator.load_secure_inputs(cur_input) 245 | evaluator.run() 246 | 247 | [w,b] = evaluator.get_outputs() 248 | evaluator.reset_circuit() 249 | 250 | if iter_num < 10: 251 | wd[str(iter_num) + "_" + str(party_index)] = [w,b] 252 | 253 | #cur_in = {} 254 | #cur_in["in2"] = w 255 | #cur_in["in3"] = b 256 | #evaluator.load_inputs(cur_in) 257 | 258 | evaluator.receive_shares([w,b]) 259 | 260 | results_dict[party_index] = {"w": w, "b": b} 261 | 262 | if __name__ == "__main__": 263 | MOD = 10001112223334445556667778889991 264 | 265 | import data.iris_data as iris 266 | 267 | data = iris.get_iris_data() 268 | 269 | num_iter = len(data) 270 | 271 | #print(eval_circuit(data,num_iter)) 272 | 273 | print(secure_eval_circuit(data,num_iter,MOD,fp_precision=10)) -------------------------------------------------------------------------------- /examples/svm/alg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import time 4 | import math 5 | 6 | MOD = MOD = 622288097498926496141095869268883999563096063592498055290461 7 | MOD_BIT_SIZE = len(bin(MOD)[2:]) 8 | 9 | def svm(data, num_iterations, initial_w=0, initial_b=0, hyper_param=1, fp_precision=16): 10 | 11 | svm_data = alter_data(data) 12 | #svm_data = data 13 | 14 | # need to make dimenions of w the same as x 15 | if initial_w == 0: 16 | first_x = svm_data[0][0] 17 | initial_w = np.zeros(len(first_x)) 18 | 19 | w = initial_w 20 | b = initial_b 21 | 22 | # use fixed point numbers 23 | # input data should be scaled up by 10^fp_precision 24 | # also scale down by 10^fp_precision after every mult 25 | scale = 10**fp_precision 26 | #scale = 1.0 27 | 28 | start_time = time.time() 29 | 30 | for i in range(num_iterations): 31 | 32 | learning_rate = int(round((1.0 / (1.0 + i))*10**7) / 10**7 * scale) 33 | #learning_rate = int(0.5 * scale) 34 | #learning_rate = int(1*scale) 35 | #learning_rate = int(scale / 3) 36 | 37 | np_x = np.array(svm_data[i][0]) 38 | y = svm_data[i][1] 39 | 40 | # we use fixed point, so multiply by precision and round to integer 41 | for a in range(len(np_x)): 42 | np_x[a] = int( np_x[a] * scale) 43 | y = int( y * scale) 44 | 45 | # if y * (w dot x) <= 1: 46 | # w <- (1 - learning rate) * w + (learning rate) * hyper_param * y * x 47 | # else: 48 | # w <- (1 - learning rate) * w 49 | xw_dot = int(np.dot(np_x,w) / scale) 50 | #xw_dot = np.dot(np_x,w) 51 | if (y / scale * xw_dot) <= (1*scale): 52 | #print(w) 53 | w = ((scale - learning_rate) / scale * w) + ((learning_rate / scale * hyper_param * y) / scale * np_x) 54 | #print(w) 55 | else: 56 | w = (1*scale - learning_rate) / scale * w 57 | 58 | mod_bit_size = MOD_BIT_SIZE 59 | #trunc_val = 2**int((mod_bit_size - 1) / 3) 60 | #trunc_val = 2**20 61 | trunc_val = 10**7 62 | 63 | #w = np.array([int(round(el / 10**7)*10**7) for el in w]) 64 | w = np.array([int(math.floor(el / trunc_val)*trunc_val) for el in w]) 65 | #for a in range(len(w)): 66 | # w[a] = int(w[a]) 67 | # print(w[a]) 68 | # print(int(w[a])) 69 | 70 | #if (i < 150): 71 | # print("iter " + str(i) + ": " + str(w)) 72 | #print(w) 73 | 74 | #if 98 < i < 102: 75 | # print("ITERATION: " + str(i)) 76 | # print("data: " + str((np_x,y))) 77 | 78 | return_w = w[:-1] 79 | return_b = w[-1] 80 | #print(w) 81 | #print(return_w) 82 | elapsed_time = time.time() - start_time 83 | print("elapsed time: " + str(elapsed_time)) 84 | 85 | return (return_w / scale,return_b / scale) 86 | #return w / scale 87 | 88 | def alter_data(data): 89 | 90 | new_data = [] 91 | 92 | for i in range(len(data)): 93 | old_x = data[i][0] 94 | old_y = data[i][1] 95 | 96 | new_x = np.append(old_x,1) 97 | 98 | new_data.append([np.array(new_x),old_y]) 99 | 100 | #for line in new_data: 101 | #print(line) 102 | return new_data 103 | 104 | def plot_data_line(data,line): 105 | 106 | x1_vals = [] 107 | y1_vals = [] 108 | x2_vals = [] 109 | y2_vals = [] 110 | 111 | for i in range(len(data)): 112 | y = data[i][1] 113 | x = data[i][0] 114 | 115 | if y == -1: 116 | x1_vals.append(x[0]) 117 | y1_vals.append(x[1]) 118 | else: 119 | x2_vals.append(x[0]) 120 | y2_vals.append(x[1]) 121 | 122 | plt.scatter(x1_vals,y1_vals,color='r') 123 | plt.scatter(x2_vals,y2_vals,color='b') 124 | plt.show() 125 | 126 | 127 | if __name__ == "__main__": 128 | 129 | import data.iris_data as iris 130 | data = iris.get_iris_data() 131 | 132 | num_iter = len(data) 133 | print(svm(data,num_iter,fp_precision=10)) 134 | #plot_data_line(data,"") -------------------------------------------------------------------------------- /examples/svm/circuit.py: -------------------------------------------------------------------------------- 1 | from src.circuits.gate import Gate 2 | from queue import Queue 3 | 4 | # this file contains the circuit for one iteration of a support vector machine 5 | # here we use sub gradient descent 6 | # the circuit is represented via a python dictionary 7 | # each gate is given a unique ID and label 8 | # the unique ID for the gate will serve as its key 9 | # the IDs will be assigned as g1 for gate one, g2 ... 10 | # the values of the dictionary correspond to which gates the output of the key 11 | # will input into 12 | # for example circuit["g1"] = [g2,g3] means that the output of gate "g1" 13 | # will serve as input for both gates "g2" and "g3" 14 | 15 | # here is a summary of the algorithm 16 | 17 | # first, we augment the data such that we set all x values to be x' = [x,1] 18 | # this will then solve for a w', where w' = [w,b] 19 | # after augmenting the data, we do the following 20 | 21 | # initialize w,b (for us, we take w,b to be random integer in {-1,1}) 22 | # for each data point (x',y): 23 | # if y * ( dotproduct(w',x') ) <= 1: 24 | # w' = (1 - gamma) * w' + gamma * C * y * x' 25 | # else: 26 | # w' = (1 - gamma) * w' 27 | # 28 | # here gamma is the "learning rate", and C is a hyper-parameter 29 | # we chose the following values: 30 | # gamma = 1 / (1 + [iteration_number]) 31 | # C = 1 32 | 33 | # this circuit will consist of the code block below the for loop 34 | # there are 5 input values for each iteration: x', y, w', (1 - gamma), gamma*C 35 | # they are labeled as follows: 36 | # x' -> input0 37 | # y -> input1 38 | # w' -> input2 39 | # (1 - gamma) -> input3 40 | # gamma * C -> input4 41 | 42 | # there is one output value: w' 43 | # labeled as follows: 44 | # w' -> output0 45 | # the additional intermediate wires will be given the lable gi for all i 46 | 47 | # we use the following gate labels: ADD, MULT, SMULT, COMP, DOT, NOT, CMULT, CADD 48 | # ADD is the addition gate 49 | # MULT is the multiplication gate 50 | # SMULT is the scalar multiplication gate 51 | # COMP is the comparison gate, which computes the boolean (input <= 0) 52 | # DOT is the dot product gate, which computes the dot product of two inputs 53 | # NOT is the not gate, which computes 1 - input (input 0 or 1 here) 54 | # CMULT is multiplying by a constant (i.e. multiply share of x by public c val) 55 | # CADD is addition by constant (i.e. adding a public c val to a share of x) 56 | 57 | x = Gate("in0","INPUT",[]) 58 | y = Gate("in1","INPUT",[]) 59 | wi = Gate("in2","INPUT",[]) 60 | gamma1 = Gate("in3","INPUT",[]) 61 | gammaC = Gate("in4","INPUT",[]) 62 | minus1 = Gate("in5","INPUT",[]) 63 | 64 | # gate for dot product of x' and w' 65 | g0 = Gate("g0","DOT",[x.get_id(),wi.get_id()]) 66 | 67 | # gate for multiplying y with dot product of x' and w' 68 | g1 = Gate("g1","MULT",[y.get_id(),g0.get_id()]) 69 | 70 | # gate for subtracting 1 from y*dot(x',w') 71 | g2 = Gate("g2","CADD",[g1.get_id(),minus1.get_id()],const_input=minus1.get_id()) 72 | 73 | # gate for comparing y*dot(x',w') - 1 <= 0 74 | g3 = Gate("g3","COMP",[g2.get_id()]) 75 | 76 | # gate for multipling y with gamma*C 77 | g4 = Gate("g4","CMULT",[y.get_id(),gammaC.get_id()],const_input=gammaC.get_id()) 78 | 79 | # gate for multiplying y*gamma*C with x' 80 | g5 = Gate("g5","SMULT",[g4.get_id(),x.get_id()]) 81 | 82 | # gate for multiplying w' with (1 - gamma) 83 | g6 = Gate("g6","CMULT",[wi.get_id(),gamma1.get_id()],const_input=gamma1.get_id()) 84 | 85 | # gate for adding (1-gamma)w' with y*gamma*C*x' 86 | g7 = Gate("g7","ADD",[g5.get_id(),g6.get_id()]) 87 | 88 | # gate for multiplying result of comparison with g7 output 89 | g8 = Gate("g8","SMULT",[g3.get_id(),g7.get_id()]) 90 | 91 | # gate for computing NOT of comparison (for multiplexing) 92 | g9 = Gate("g9","NOT",[g3.get_id()]) 93 | 94 | # gate for multiplying g9 (not of comparison) with (1-gamma)w' 95 | g10 = Gate("g10","SMULT",[g9.get_id(),g6.get_id()]) 96 | 97 | # gate for adding two multiplexes 98 | g11 = Gate("g11","ADD",[g8.get_id(),g10.get_id()]) 99 | 100 | # gate for rounding results 101 | g12 = Gate("g12","ROUND",[g11.get_id()]) 102 | 103 | # gate for outputing w' 104 | wo = Gate("out0","OUTPUT",[g12.get_id()]) 105 | 106 | circuit = {} 107 | 108 | circuit[x.get_id()] = [g0,g5] 109 | circuit[y.get_id()] = [g1,g4] 110 | circuit[wi.get_id()] = [g0,g6] 111 | circuit[gamma1.get_id()] = [g6] 112 | circuit[gammaC.get_id()] = [g4] 113 | circuit[minus1.get_id()] = [g2] 114 | circuit[g0.get_id()] = [g1] 115 | circuit[g1.get_id()] = [g2] 116 | circuit[g2.get_id()] = [g3] 117 | circuit[g3.get_id()] = [g8,g9] 118 | circuit[g4.get_id()] = [g5] 119 | circuit[g5.get_id()] = [g7] 120 | circuit[g6.get_id()] = [g7,g10] 121 | circuit[g7.get_id()] = [g8] 122 | circuit[g8.get_id()] = [g11] 123 | circuit[g9.get_id()] = [g10] 124 | circuit[g10.get_id()] = [g11] 125 | circuit[g11.get_id()] = [g12] 126 | circuit[g12.get_id()] = [wo] 127 | 128 | in_gates = [x,y,wi,gamma1,gammaC,minus1] 129 | out_gates = [wo] -------------------------------------------------------------------------------- /examples/svm/eval_circuit.py: -------------------------------------------------------------------------------- 1 | from src.circuits.evaluator import BasicEvaluator 2 | from src.circuits.evaluator import SecureEvaluator 3 | from src.circuits.dealer import Dealer 4 | from src.circuits.oracle import Oracle 5 | from examples.svm import circuit as circ 6 | import numpy as np 7 | from threading import Thread 8 | import copy 9 | from examples.svm.alg import alter_data 10 | import asyncio 11 | import time 12 | 13 | 14 | 15 | def secure_eval_circuit(data,num_iterations,modulus,initial_w=0,initial_b=0,fp_precision=16): 16 | """ 17 | Function that evaluates the perceptron circuit using three SecureEvaluator 18 | objects. The current protocol also requires a Dealer and an Oracle. 19 | 20 | Parameters 21 | ---------- 22 | data: iterable 23 | Data to be input into the perceptron algorithm (assumed iterable pairs) 24 | num_iterations: int 25 | Number of iterations that algorithm will run for 26 | modulus: int 27 | Value representing the modulus of field used 28 | (optional) initial_w=0: int 29 | Initial value of w, parameter of perceptron algorithm 30 | (optional) initial_b=0: int 31 | Initial value of b, parameter of perceptron algorithm 32 | (optional) fp_precision=16: int 33 | Fixed point number precision 34 | 35 | Returns 36 | ------- 37 | w: float 38 | w value achieved after num_iterations of perceptron 39 | b: int 40 | b value achieved after num_iterations of perceptron 41 | """ 42 | 43 | # alter data to work for svm 44 | data = alter_data(data) 45 | 46 | # account for fixed point precision 47 | scale = 10**fp_precision 48 | 49 | circ1 = copy.deepcopy(circ.circuit) 50 | circ2 = copy.deepcopy(circ.circuit) 51 | circ3 = copy.deepcopy(circ.circuit) 52 | 53 | # initialize evaluators 54 | evaluator1 = SecureEvaluator(circ1,circ.in_gates,circ.out_gates,1,modulus,fp_precision=fp_precision) 55 | evaluator2 = SecureEvaluator(circ2,circ.in_gates,circ.out_gates,2,modulus,fp_precision=fp_precision) 56 | evaluator3 = SecureEvaluator(circ3,circ.in_gates,circ.out_gates,3,modulus,fp_precision=fp_precision) 57 | 58 | #evaluator1 = SecureEvaluator(circ1,circ.in_gates,circ.out_gates,1,oracle,modulus) 59 | #evaluator2 = SecureEvaluator(circ2,circ.in_gates,circ.out_gates,2,oracle,modulus) 60 | #evaluator3 = SecureEvaluator(circ3,circ.in_gates,circ.out_gates,3,oracle,modulus) 61 | 62 | 63 | parties = [evaluator1,evaluator2,evaluator3] 64 | party_dict = {1: evaluator1, 2: evaluator2, 3: evaluator3} 65 | 66 | evaluator1.add_parties(party_dict) 67 | evaluator2.add_parties(party_dict) 68 | evaluator3.add_parties(party_dict) 69 | 70 | # initialize dealer 71 | dealer = Dealer(parties,modulus,fp_precision=fp_precision) 72 | 73 | start_time = time.time() 74 | 75 | # split x_data and y_data into 3 lists, one for each party 76 | # this simulates each party having private input data 77 | data_len = len(data) 78 | data1x = [] 79 | data2x = [] 80 | data3x = [] 81 | data1y = [] 82 | data2y = [] 83 | data3y = [] 84 | 85 | split = int(data_len/3) 86 | 87 | for i in range(split): 88 | data1x.append(data[i][0]) 89 | data1y.append(data[i][1]) 90 | data2x.append(data[split + i][0]) 91 | data2y.append(data[split + i][1]) 92 | data3x.append(data[2*split + i][0]) 93 | data3y.append(data[2*split + i][1]) 94 | 95 | # use dealer to create shares of all inputs 96 | dealer.distribute_shares(data1x) 97 | dealer.distribute_shares(data2x) 98 | dealer.distribute_shares(data3x) 99 | 100 | dealer.distribute_shares(data1y) 101 | dealer.distribute_shares(data2y) 102 | dealer.distribute_shares(data3y) 103 | 104 | # use dealer to create random values for interactive operations 105 | num_randomness = 10000 * num_iterations 106 | dealer.generate_randomness(num_randomness) 107 | dealer.generate_truncate_randomness(5*num_iterations) 108 | 109 | # need to make dimenions of w the same as x 110 | if initial_w == 0: 111 | first_x = data[0][0] 112 | initial_w = np.zeros(len(first_x)) 113 | initial_w = [initial_w,[]] 114 | 115 | dealer.distribute_shares(initial_w) 116 | #dealer.distribute_shares(initial_b) 117 | 118 | results = {} 119 | 120 | # for each iteration of perceptron algorithm, have each SecureEvaluator 121 | # compute the circuit, each on their own thread, so they can interact 122 | res = {} 123 | for i in range(num_iterations): 124 | #for i in range(1): 125 | 126 | #print("iteration: " + str(i)) 127 | 128 | t1 = Thread(target=run_eval,args=(evaluator1,i,data_len,results,1,modulus,fp_precision,res)) 129 | t2 = Thread(target=run_eval,args=(evaluator2,i,data_len,results,2,modulus,fp_precision,res)) 130 | t3 = Thread(target=run_eval,args=(evaluator3,i,data_len,results,3,modulus,fp_precision,res)) 131 | 132 | t1.start() 133 | t2.start() 134 | t3.start() 135 | 136 | t1.join() 137 | t2.join() 138 | t3.join() 139 | 140 | #for a in range(150): 141 | #for a in range(5): 142 | # print("iter " + str(a) + ": " + str(unshare(res[str(a)+"_1"][0],res[str(a)+"_2"][0]))) 143 | 144 | # extract final outputs, scale them down 145 | (w,b) = get_w_b(results) 146 | #return (w / scale, b / scale) 147 | wout = [] 148 | for el in w: 149 | wout.append(el / scale) 150 | bout = b / scale 151 | elapsed_time = time.time() - start_time 152 | print("elapsed time: " + str(elapsed_time)) 153 | return (np.array(wout),bout) 154 | 155 | def unshare(share1,share2): 156 | """ 157 | Method for converting shares into their hidden value 158 | 159 | Parameters 160 | ---------- 161 | share1: int or iterable 162 | Shares of value 163 | share2: int or iterable 164 | Shares of same value as share1 165 | 166 | Returns 167 | ------- 168 | res: 169 | value hidden by share1 and share2 170 | """ 171 | 172 | if type(share1) == list: 173 | res = [] 174 | for i in range(len(share1)): 175 | res.append(share1[i].unshare(share2[i])) 176 | 177 | else: 178 | res = share1.unshare(share2) 179 | 180 | return res 181 | 182 | def get_w_b(w_b_shares): 183 | """ 184 | Method for computing (w,b) from their shares 185 | 186 | Parameters 187 | ---------- 188 | w_b_shares: dictionary 189 | Dictionary of shares for values of (w,b) 190 | 191 | Returns 192 | ------- 193 | w: float 194 | w value achieved after num_iterations of perceptron 195 | b: int 196 | b value achieved after num_iterations of perceptron 197 | """ 198 | 199 | w1 = w_b_shares[1]['w'] 200 | w2 = w_b_shares[2]['w'] 201 | w3 = w_b_shares[3]['w'] 202 | 203 | w = [w1[0].unshare(w2[0]), w1[1].unshare(w2[1])] 204 | b = w1[2].unshare(w2[2]) 205 | 206 | return (w,b) 207 | 208 | 209 | def run_eval(evaluator,iter_num,data_length,results_dict,party_index,mod,fp_precision=16,wd={}): 210 | """ 211 | Method to be run by each SecureEvaluator within their Thread (this will be 212 | called with secure_eval_circuit). 213 | 214 | Parameters 215 | ---------- 216 | evaluator: SecureEvaluator object 217 | SecureEvaluator that will compute an iteration of perceptron algorithm 218 | iter_num: int 219 | Iteration number of perceptron algorithm 220 | data_length: int 221 | Integer representing length of input data 222 | results_dict: dictionary 223 | Dictionary for each thread to insert ouput values 224 | party_index: int 225 | Integer representing evaluator party index 226 | (optional) fp_precision=16: int 227 | Fixed point number precision 228 | """ 229 | 230 | scale = 10**fp_precision 231 | 232 | # input will map wire name to index in list of shares 233 | cur_input = {} 234 | cur_input["in0"] = iter_num 235 | cur_input["in1"] = data_length + iter_num 236 | 237 | # only load initial b and w 238 | #if iter_num == 0: 239 | # cur_input["in2"] = -2 240 | # cur_input["in3"] = -1 241 | 242 | cur_input["in2"] = -1 243 | 244 | evaluator.load_secure_inputs(cur_input) 245 | 246 | # need to load in constant values 247 | # -1: for computing <=1, subtract 1 and comp <= 0 248 | # gam1: need to have 1 - gamma for computing w 249 | # gamC: need gamma * C also for computing w 250 | neg1 = int(-1*scale) 251 | pre_gamma = 1.0 / (1.0 + iter_num) 252 | gamma = round(pre_gamma * 10**7) 253 | gamma = int(gamma * scale) 254 | gamma = int(gamma / 10**7) 255 | gam1 = int(1*scale - gamma) 256 | gamC = int(gamma * 1) 257 | 258 | load_in_constants = {} 259 | load_in_constants["in3"] = gam1 260 | load_in_constants["in4"] = gamC 261 | load_in_constants["in5"] = neg1 262 | 263 | evaluator.load_inputs(load_in_constants) 264 | 265 | evaluator.run() 266 | 267 | [w] = evaluator.get_outputs() 268 | evaluator.reset_circuit() 269 | 270 | wd[str(iter_num) + "_" + str(party_index)] = [w] 271 | 272 | #cur_in = {} 273 | #cur_in["in2"] = w 274 | #cur_in["in3"] = b 275 | #evaluator.load_inputs(cur_in) 276 | 277 | evaluator.receive_shares([w]) 278 | 279 | results_dict[party_index] = {"w": w} 280 | 281 | def mod_inverse(val, mod): 282 | g, x, y = egcd(val, mod) 283 | if g != 1: 284 | raise Exception('modular inverse does not exist') 285 | else: 286 | return x % mod 287 | 288 | def egcd(a,b): 289 | if a == 0: 290 | return (b,0,1) 291 | else: 292 | g, y, x = egcd(b %a, a) 293 | return (g, x - (b //a) * y, y) 294 | 295 | 296 | if __name__ == "__main__": 297 | 298 | MOD = 10001112223334445556667778889991 299 | 300 | # 199 bits 301 | #MOD = 622288097498926496141095869268883999563096063592498055290461 302 | 303 | #MOD = 24684249032065892333066123534168930441269525239006410135714283699648991959894332868446109170827166448301044689 304 | 305 | import data.iris_data as iris 306 | 307 | data = iris.get_iris_data() 308 | 309 | num_iter = len(data) 310 | #num_iter = 10 311 | 312 | print(secure_eval_circuit(data,num_iter,MOD,fp_precision=10)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astroid==2.2.5 2 | atomicwrites==1.3.0 3 | attrs==19.1.0 4 | autopep8==1.4.4 5 | cycler==0.10.0 6 | importlib-metadata==0.18 7 | isort==4.3.21 8 | kiwisolver==1.1.0 9 | lazy-object-proxy==1.4.1 10 | matplotlib==3.1.1 11 | mccabe==0.6.1 12 | mock==3.0.5 13 | more-itertools==7.1.0 14 | -e git+https://github.com/trailofbits/mpc-learning@0b0cb3ffce17e48a8e844f8fae0b1e7a426fffe2#egg=MPC_learning 15 | numpy==1.16.4 16 | packaging==19.0 17 | pandas==0.24.2 18 | pep8==1.7.1 19 | pluggy==0.12.0 20 | py==1.8.0 21 | pycodestyle==2.5.0 22 | pylint==2.3.1 23 | pyparsing==2.4.0 24 | pytest==4.6.3 25 | pytest-mock==1.10.4 26 | python-dateutil==2.8.0 27 | pytz==2019.1 28 | six==1.12.0 29 | typed-ast==1.4.0 30 | wcwidth==0.1.7 31 | wrapt==1.11.2 32 | zipp==0.5.1 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('requirements.txt') as f: 4 | requirements = f.read().splitlines() 5 | 6 | setup( 7 | name="MPC-learning", 8 | version="1.0", 9 | install_requires=requirements, 10 | packages=['src','examples','data'], 11 | long_description=open("README.md").read(), 12 | platforms=['any'] 13 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # this file contains the raw Perceptron algorithm to use for testing 2 | # ideally the output of this algorithm will match that of the circuit 3 | 4 | def perceptron(x_data,y_data,num_iterations): 5 | """ 6 | 7 | 8 | """ -------------------------------------------------------------------------------- /src/circuits/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/src/circuits/__init__.py -------------------------------------------------------------------------------- /src/circuits/dealer.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | import numpy as np 3 | from src.circuits.share import Share 4 | import math 5 | 6 | class Dealer(): 7 | """ 8 | Dealer class that is responsible for creating input shares for the 9 | SecureEvaluators. 10 | 11 | Methods 12 | ------- 13 | __init__(self, parties, modulus, fp_precision=16) 14 | Dealer object constructor 15 | Initliaze parties, mod, modulus, and fp_precision 16 | 17 | Parameters 18 | ---------- 19 | parties: iterable 20 | Iterable of Evaluators 21 | modulus: integer 22 | Modulus representing input domain 23 | (optional) fp_precision=16: int 24 | Fixed point number precision 25 | 26 | generate_randomness(self, number_of_values) 27 | Method that generates number_of_values random values for parties 28 | According to current protocol, each party receives a value x_i 29 | such that x_1 + x_2 + x_3 = 0 30 | 31 | Parameters 32 | ---------- 33 | number_of_values: int 34 | Number of random shares to be generated 35 | 36 | distribute_shares(self, inputs, random=False, verbose=False) 37 | Method that generates shares of inputs and sends the shares 38 | to the three parties 39 | 40 | Parameters 41 | ---------- 42 | inputs: iterable 43 | Iterable of input values to be converted to shares 44 | (optional) random=False: boolean 45 | Boolean indicating whether we are creating shares of inputs 46 | or shares for generate_randomness 47 | (optional) verbose=False: boolean 48 | Boolean to turn on verbose mode for debugging 49 | 50 | _make_shares(self, input_value) 51 | Method for generating shares from an input value (called within 52 | distribute_shares) 53 | 54 | Parameters 55 | ---------- 56 | input_value: int or iterable 57 | Value to be made into shares and sent to parties 58 | 59 | _send_shares(self, shares, receiver) 60 | Method for sending shares to parties (called within distribute_shares) 61 | 62 | Parameters 63 | ---------- 64 | shares: iterable 65 | Iterable of values/shares to be sent to parties 66 | receiver: Evaluator object 67 | Evaluator (party) that will receive shares 68 | 69 | _send_randomness(self, shares, receiver) 70 | Method for sending randomness to parties (called within 71 | generate_randomness) 72 | 73 | Parameters 74 | ---------- 75 | shares: iterable 76 | Iterable of values/shares to be sent to parties 77 | receiver: Evaluator object 78 | Evaluator (party) that will receive shares 79 | 80 | """ 81 | def __init__(self,parties,modulus,fp_precision=16): 82 | self.parties = parties 83 | self.mod = modulus 84 | self.scale = 10**fp_precision 85 | self.fpp = fp_precision 86 | #self.modulus = modulus / self.scale 87 | self.mod_bit_size = len(bin(self.mod)[2:]) 88 | 89 | def generate_randomness(self,number_of_values): 90 | inputs = [] 91 | for i in range(number_of_values): 92 | inputs.append(0) 93 | 94 | self.distribute_shares(inputs,random=True) 95 | 96 | def generate_truncate_randomness(self, number_of_truncs): 97 | for i in range(number_of_truncs): 98 | self._truncate_randomness() 99 | 100 | def _truncate_randomness(self): 101 | # need to generate a lot of random values for building blocks 102 | # of the truncate protocol 103 | 104 | # need to generate two random numbers, r2 and r1 105 | # and need to make shares of them 106 | # also need to generate shares of all the bits of r1 107 | 108 | #mod2_bit_size = int((self.mod_bit_size - 1) / 3) 109 | mod2_bit_size = 20 110 | 111 | r2 = randint(0,math.floor(2**(mod2_bit_size))) 112 | r1 = randint(0,math.floor(2**(mod2_bit_size))) 113 | 114 | r1_bits = [] 115 | for bit in bin(r1)[2:]: 116 | r1_bits.append(int(bit)) 117 | 118 | # prepend list with 0's to match modulus bit size 119 | r1_bits = [0]*(self.mod_bit_size - 1 - len(r1_bits)) + r1_bits 120 | 121 | # pass r2, r1, r1_bits as list to make shares to return list of shares 122 | random_vals = [r2, r1] + r1_bits 123 | 124 | r2_r1_shares = self._make_shares(random_vals,random=False) 125 | 126 | (sh1,sh2,sh3) = r2_r1_shares 127 | shrs = [sh1,sh2,sh3] 128 | for items in shrs: 129 | items[0] = items[0].switch_precision(0) 130 | items[1] = items[1].switch_precision(0) 131 | 132 | r2_r1_shares = shrs 133 | 134 | # also need to generate shares of 2 random bits for mod2 subprotocol 135 | # we actually need to create two shares of the second random bit 136 | b2 = randint(0,1) 137 | b1 = randint(0,1) 138 | 139 | b2_b1_shares = self._make_shares([b2,b1,b1],random=False) 140 | 141 | # lastly we need to generate shares of 2*mod_bit_size integers 142 | s_vals = [] 143 | r_vals = [] 144 | for i in range(self.mod_bit_size - 1): 145 | s_vals.append(randint(1,math.floor(self.mod / self.scale))) 146 | r_vals.append(randint(1,math.floor(self.mod / self.scale))) 147 | 148 | s_shares = self._make_shares(s_vals,random=False) 149 | r_shares = self._make_shares(r_vals,random=False) 150 | 151 | for i,party in enumerate(self.parties): 152 | shares = {} 153 | shares["mod2m"] = r2_r1_shares[i] 154 | shares["mod2"] = b2_b1_shares[i] 155 | shares["premul"] = {'s': s_shares[i], 'r': r_shares[i]} 156 | self._send_truncate_randomness(shares,party) 157 | 158 | 159 | def distribute_shares(self,inputs, random=False, verbose=False): 160 | # generate and send shares to each party 161 | # here we assume there are exactly three parties 162 | shares_for_1 = [] 163 | shares_for_2 = [] 164 | shares_for_3 = [] 165 | 166 | if (type(inputs) == int) or (type(inputs) == float): 167 | (sh1,sh2,sh3) = self._make_shares(inputs, random=random) 168 | shares_for_1.append(sh1) 169 | shares_for_2.append(sh2) 170 | shares_for_3.append(sh3) 171 | 172 | else: 173 | for val in inputs: 174 | if verbose: 175 | print("val: " + str(val)) 176 | if val == []: 177 | continue 178 | (sh1,sh2,sh3) = self._make_shares(val, random=random) 179 | shares_for_1.append(sh1) 180 | shares_for_2.append(sh2) 181 | shares_for_3.append(sh3) 182 | 183 | shares = [shares_for_1,shares_for_2,shares_for_3] 184 | 185 | for i,party in enumerate(self.parties): 186 | if random: 187 | self._send_randomness(shares[i],party) 188 | else: 189 | self._send_shares(shares[i],party) 190 | 191 | def _make_shares(self,input_value, random=False): 192 | it = type(input_value) 193 | if (it == int) or (it == np.int64) or (it == np.float64) or (it == float): 194 | 195 | val = int(input_value*self.scale) % self.mod 196 | 197 | # first generate three random values a, b, c s.t. a + b + c = 0 198 | #a = int(randint(0,self.mod-1) ) 199 | #b = int(randint(0,self.mod-1) ) 200 | #c = (- (a + b)) % self.mod 201 | 202 | a = int(randint(0,math.floor(self.mod / self.scale) - 1) * self.scale) 203 | b = int(randint(0,math.floor(self.mod / self.scale) - 1) * self.scale) 204 | c = (- (a + b)) % self.mod 205 | 206 | if random: 207 | share1 = a 208 | share2 = b 209 | share3 = c 210 | else: 211 | #share1 = (a,c-val) 212 | #share2 = (b,a-val) 213 | #share3 = (c,b-val) 214 | share1 = Share(a,c-val,mod=self.mod,fp_prec=self.fpp) 215 | share2 = Share(b,a-val,mod=self.mod,fp_prec=self.fpp) 216 | share3 = Share(c,b-val,mod=self.mod,fp_prec=self.fpp) 217 | 218 | else: 219 | share1 = [] 220 | share2 = [] 221 | share3 = [] 222 | 223 | for val in input_value: 224 | mod_val = int(round(val*self.scale)) % self.mod 225 | 226 | # first generate three random values a, b, c s.t. a + b + c = 0 227 | #a = int(randint(0,self.mod-1)) 228 | #b = int(randint(0,self.mod-1)) 229 | #c = (- (a + b)) % self.mod 230 | 231 | a = int(randint(0,math.floor(self.mod / self.scale) - 1) * self.scale) 232 | b = int(randint(0,math.floor(self.mod / self.scale) - 1) * self.scale) 233 | c = (- (a + b)) % self.mod 234 | 235 | if random: 236 | share1.append(a) 237 | share2.append(b) 238 | share3.append(c) 239 | else: 240 | share1.append(Share(a,c-mod_val,mod=self.mod,fp_prec=self.fpp)) 241 | share2.append(Share(b,a-mod_val,mod=self.mod,fp_prec=self.fpp)) 242 | share3.append(Share(c,b-mod_val,mod=self.mod,fp_prec=self.fpp)) 243 | 244 | return (share1,share2,share3) 245 | 246 | def _send_shares(self,shares,receiver): 247 | receiver.receive_shares(shares) 248 | 249 | def _send_randomness(self,shares,receiver): 250 | receiver.receive_randomness(shares) 251 | 252 | def _send_truncate_randomness(self,shares,receiver): 253 | receiver.receive_truncate_randomness(shares) -------------------------------------------------------------------------------- /src/circuits/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.circuits.share import Share 3 | from src.util.mod import mod_inverse 4 | from queue import Queue 5 | import time 6 | import asyncio 7 | from threading import Event 8 | import math 9 | from src.util.primality_test import miller_rabin as is_prime 10 | 11 | class Evaluator: 12 | """ 13 | Generic evaluator class that serves as an interface for other evaluators. 14 | This interface is designed to be objects that take a function (represent 15 | as a circuit) and computes the ouput based on input wires. 16 | 17 | Methods 18 | ------- 19 | __init__(self, circuit, gate_order, fp_precision=16) 20 | Evaluator object constructor 21 | Initializes the circuit and gate order and scale (10^fp_precision) 22 | Initializes dictionary for wire_values 23 | Initializes list of input and output wires 24 | 25 | Parameters 26 | ---------- 27 | circuit: dictionary 28 | Circuit to be used to compute function, assumed to be in correct 29 | format- i.e. contains the following (key,value) pairs: 30 | ("input",) 31 | ("output",) 32 | ("wires",) 33 | (,dict({"type": , "input": , 34 | "output": })) (for each gate in circuit) 35 | gate_order: iterable (ordered) 36 | Iterable representing the order in which circuit gates should be 37 | evaluated. Each gate in gate_order should be a key in the circuit 38 | (optional) fp_precision=16: int 39 | Fixed point number precision 40 | 41 | load_inputs(self, inputs) 42 | Initialzes inputs to be used for computing the circuit 43 | 44 | Parameters 45 | ---------- 46 | inputs: dictionary 47 | Inputs to be used to compute circuit, assumed to be in correct 48 | format- i.e. contains following (key,value) pairs: 49 | (,) for each in 50 | circuit 51 | 52 | run(self, verbose=False) 53 | Computes the circuit using the loaded inputs 54 | 55 | Parameters 56 | ---------- 57 | (optional) verbose: boolean 58 | Optional argument to display intermediate wire values for debugging 59 | 60 | _eval_gate(self, gate, verbose=False) 61 | Evaluate an individual gate of the circuit 62 | 63 | Parameters 64 | ---------- 65 | gate: dictionary 66 | Gate to be computed, assumed to be in correct format- i.e. contains 67 | the following (key,value) pairs: 68 | ("type",) where can be "ADD", "MULT", 69 | "SMULT", "DOT", "NOT", or "COMP" 70 | ("input",) 71 | ("output",[ss of rv 1, ss of rv 2, ss of bits of rv 2] 425 | (where the individual bits should be individual elements of list) 426 | "mod2"->[ss of rb 1, ss of rb 2] 427 | "premul"-> {'s': [ss of rv_1s, ... , ss of rv_ns], 428 | 'r': [ss of rv_1r, ... , ss of rv_nr ]} 429 | (here n is bit size of modulus) 430 | 431 | receive_shares(self,shares) 432 | Method for getting shares of input values 433 | Current iteration receives shares from dealer 434 | 435 | Parameters 436 | ---------- 437 | shares: iterable 438 | Iterable containing shares of input values. According to current 439 | protocol, each share is a pair of values. 440 | 441 | load_secure_inputs(self,inputs) 442 | Method for loading input shares into Evaluator object. This makes 443 | secure evaluation compatible with Evaluator interface. 444 | 445 | Parameters 446 | ---------- 447 | inputs: dictionary 448 | Dictionary mapping input wires to their index in the list of 449 | input shares (self.shares). This index is used to obtain 450 | the desired share value and load it into input wire. 451 | 452 | _add(self, wire_in, wire_out) 453 | Method specifying how to compute ADD gate 454 | 455 | Parameters 456 | ---------- 457 | wire_in: iterable 458 | Iterable of input wire names 459 | wire_out: iterable 460 | Iterable of output wire names 461 | 462 | _mult(self, wire_in, wire_out) 463 | Method specifying how to compute MULT gate 464 | 465 | Parameters 466 | ---------- 467 | wire_in: iterable 468 | Iterable of input wire names 469 | wire_out: iterable 470 | Iterable of output wire names 471 | 472 | _smult(self, wire_in, wire_out) 473 | Method specifying how to compute SMULT gate 474 | 475 | Parameters 476 | ---------- 477 | wire_in: iterable 478 | Iterable of input wire names 479 | wire_out: iterable 480 | Iterable of output wire names 481 | 482 | _dot(self, wire_in, wire_out) 483 | Method specifying how to compute DOT gate 484 | 485 | Parameters 486 | ---------- 487 | wire_in: iterable 488 | Iterable of input wire names 489 | wire_out: iterable 490 | Iterable of output wire names 491 | 492 | _not(self, wire_in, wire_out) 493 | Method specifying how to compute NOT gate 494 | 495 | Parameters 496 | ---------- 497 | wire_in: iterable 498 | Iterable of input wire names 499 | wire_out: iterable 500 | Iterable of output wire names 501 | 502 | _comp(self, wire_in, wire_out) 503 | Method specifying how to compute COMP gate 504 | 505 | Parameters 506 | ---------- 507 | wire_in: iterable 508 | Iterable of input wire names 509 | wire_out: iterable 510 | Iterable of output wire names 511 | 512 | _send_share(self, value, party_index, random_index) 513 | Method for sending value to other Evaluator parties 514 | 515 | Parameters 516 | ---------- 517 | value: int 518 | Value to be send to other party 519 | party_index: int 520 | party_index of Evaluator to receive value 521 | random_index: int 522 | Index representing which random interactive value to use 523 | 524 | _receive_party_share(self, share, random_index) 525 | Method for receiving value from another Evaluator party 526 | 527 | Parameters 528 | ---------- 529 | share: int 530 | Value to be received from other party 531 | random_index: 532 | Index representing which random interactive value to use 533 | 534 | get_outputs(self) 535 | Getter for outputs of circuit 536 | 537 | Returns 538 | ------- 539 | outs: list 540 | List of output values from the computation of the circuit 541 | 542 | get_wire_dict(self) 543 | Getter for dictionary of wire values 544 | 545 | Returns 546 | ------- 547 | wire_dict: dictionary 548 | Dictionary of wire values (key: wire name, value: wire value) 549 | """ 550 | 551 | def __init__(self,circuit,input_gates,output_gates,party_index,mod,fp_precision=16): 552 | #Evaluator.__init__(self,circuit,[],fp_precision=fp_precision) 553 | self.circuit = circuit 554 | self.fpp = fp_precision 555 | self.scale = 10**fp_precision 556 | self.party_index = party_index 557 | self.input_gates = {} 558 | for ing in input_gates: 559 | self.input_gates[ing.get_id()] = ing 560 | self.output_gates = output_gates 561 | self.input_shares = [] 562 | self.q = Queue() 563 | self.outputs = {} 564 | for outg in self.output_gates: 565 | self.outputs[outg.get_id()] = "" 566 | 567 | if is_prime(mod): 568 | self.mod = mod 569 | else: 570 | raise Exception("Modulus: {} is not prime. Modulus must be prime.".format(mod)) 571 | 572 | #self.mod = mod 573 | self.mod_bit_size = len(bin(self.mod)[2:]) 574 | 575 | self.interaction_listener = None 576 | 577 | self.truncate_randomness = [] 578 | self.trunc_index = 0 579 | 580 | def add_parties(self,parties): 581 | self.parties = {} 582 | self.parties[self.party_index] = self 583 | for pindex in parties: 584 | if pindex == self.party_index: 585 | continue 586 | if pindex in self.parties: 587 | raise Exception("Party number: {} already exists".format(pindex)) 588 | else: 589 | self.parties[pindex] = parties[pindex] 590 | 591 | def receive_randomness(self,random_values): 592 | self.randomness = random_values 593 | self.random_index = 0 594 | self.interaction = {} 595 | for i in range(len(self.randomness)): 596 | self.interaction[i] = "wait" 597 | 598 | def receive_truncate_randomness(self, trunc_share_dict): 599 | self.truncate_randomness.append(trunc_share_dict) 600 | 601 | def receive_shares(self,shares): 602 | for share in shares: 603 | self.input_shares.append(share) 604 | 605 | def initialize_state(self, inputs): 606 | self.load_inputs(inputs) 607 | 608 | def load_inputs(self, inputs): 609 | for ing in inputs: 610 | self.input_gates[ing].add_input("",inputs[ing]) 611 | if self.input_gates[ing].is_ready(): 612 | self.q.put(self.input_gates[ing]) 613 | 614 | def load_secure_inputs(self,inputs): 615 | for ing_key in inputs: 616 | inputs[ing_key] = self.input_shares[inputs[ing_key]] 617 | 618 | self.load_inputs(inputs) 619 | 620 | def run(self, verbose=False): 621 | i = 0 622 | 623 | gate = self.q.get() 624 | while gate != "FIN": 625 | self._eval_gate(gate, verbose=verbose) 626 | i += 1 627 | gate = self.q.get() 628 | 629 | def reset_circuit(self): 630 | self._clear_gates() 631 | 632 | def _clear_gates(self): 633 | # need to remove inputs from input gates 634 | for in_id in self.input_gates: 635 | self.input_gates[in_id].reset() 636 | 637 | # also reset all other gates 638 | for gid in self.circuit: 639 | for gate in self.circuit[gid]: 640 | gate.reset() 641 | 642 | self.outputs = {} 643 | for outg in self.output_gates: 644 | self.outputs[outg.get_id()] = "" 645 | 646 | def _eval_gate(self,gate,verbose=False): 647 | gate_type = gate.get_type() 648 | 649 | if verbose: 650 | print("gate type: " + str(gate_type)) 651 | print("gate id: " + str(gate.get_id())) 652 | print("party index: " + str(self.party_index)) 653 | ins = gate.get_inputs() 654 | ext = [] 655 | for element in ins: 656 | if type(element) == int: 657 | ext.append(element) 658 | elif type(element) == list: 659 | el_list = [] 660 | for el in element: 661 | x_val = el.get_x() 662 | a_val = el.get_a() 663 | el_list.append((x_val,a_val)) 664 | ext.append(el_list) 665 | else: 666 | x_val = element.get_x() 667 | a_val = element.get_a() 668 | ext.append((x_val,a_val)) 669 | print("gate inputs: " + str(ext)) 670 | 671 | if gate_type == "ADD": 672 | self._add(gate) 673 | elif gate_type == "MULT": 674 | self._mult(gate) 675 | elif gate_type == "SMULT": 676 | self._smult(gate) 677 | elif gate_type == "DOT": 678 | self._dot(gate) 679 | elif gate_type == "NOT": 680 | self._not(gate) 681 | elif gate_type == "COMP": 682 | self._comp(gate) 683 | elif gate_type == "ROUND": 684 | self._round(gate) 685 | elif gate_type == "CMULT": 686 | self._cmult(gate) 687 | elif gate_type == "CADD": 688 | self._cadd(gate) 689 | elif gate_type == "INPUT": 690 | self._input(gate) 691 | elif gate_type == "OUTPUT": 692 | self._output(gate) 693 | if self._is_run_finished(): 694 | self.q.put("FIN") 695 | else: 696 | raise(Exception('{} is not a valid gate type'.format(gate_type))) 697 | 698 | 699 | def get_truncate_randomness(self, index, rand_type): 700 | if index >= len(self.truncate_randomness): 701 | raise(Exception('Randomness generated for truncation exhausted. Generate more randomness.')) 702 | return self.truncate_randomness[index][rand_type] 703 | 704 | def _truncate(self, value, k, m, pow2_switch=False): 705 | 706 | if pow2_switch: 707 | m_val = 2**m 708 | else: 709 | m_val = m 710 | 711 | a_prime = self._mod2m(value, k, m, pow2_switch=pow2_switch) 712 | 713 | self.trunc_index += 1 714 | d = value + a_prime.const_mult(-1,scaled=False) 715 | d = d.const_mult(mod_inverse(m_val,self.mod),scaled=False) 716 | 717 | return d 718 | 719 | def _mod2m(self, value, k, m, pow2_switch=False): 720 | 721 | if pow2_switch: 722 | m_val = 2**m 723 | else: 724 | m_val = m 725 | 726 | r2_r1_shares = self.get_truncate_randomness(self.trunc_index,"mod2m") 727 | r2 = r2_r1_shares[0] 728 | r1 = r2_r1_shares[1] 729 | r1_bits = r2_r1_shares[2:] 730 | 731 | pre_c = value.const_add(self.mod) 732 | pre_c += r2.const_mult(m_val,scaled=False) 733 | pre_c += r1 734 | c = self._reveal(pre_c) 735 | 736 | c_prime = int(c % m_val) 737 | 738 | u = self._bit_lt(c_prime, r1_bits) 739 | 740 | a_prime = r1.const_mult(-1,scaled=False).const_add(c_prime) 741 | a_prime += u.const_mult(m_val) 742 | 743 | return a_prime 744 | 745 | def _bit_lt(self, a, b_bits): 746 | d_vals = [] 747 | a_bits = [] 748 | for bit in bin(a)[2:]: 749 | a_bits.append(int(bit)*self.scale) 750 | a_bits = [0]*(len(b_bits) - len(a_bits)) + a_bits 751 | 752 | for i in range(len(a_bits)): 753 | d_val = b_bits[i].const_add(a_bits[i]) 754 | d_val += b_bits[i].const_mult(-2*a_bits[i]) 755 | d_vals.append(d_val.const_add(1*self.scale)) 756 | 757 | p_vals = self._premul(d_vals) 758 | p_vals.reverse() 759 | 760 | s_vals = [] 761 | for i in range(len(p_vals)-1): 762 | s_val = p_vals[i] + p_vals[i+1].const_mult(-1,scaled=False) 763 | s_vals.append(s_val) 764 | s_vals.append(p_vals[-1].const_add(-1,scaled=False)) 765 | 766 | a_bits.reverse() 767 | 768 | s = Share(0,0,mod=self.mod,fp_prec=self.fpp) 769 | slen = len(s_vals) 770 | for i in range(slen): 771 | s += s_vals[i].const_mult(self.scale - a_bits[i]) 772 | 773 | ret_val = self._mod2(s,len(b_bits)) 774 | 775 | return ret_val 776 | 777 | def _mod2(self, value, k): 778 | value = value.switch_precision(0) 779 | bits = self.get_truncate_randomness(self.trunc_index,"mod2") 780 | for i,bit in enumerate(bits): 781 | bits[i] = bit.switch_precision(0) 782 | 783 | c_pre = value 784 | c_pre += bits[0].const_mult(2) + bits[2] 785 | c = self._reveal(c_pre) 786 | c0 = int(bin(math.floor(c))[-1]) 787 | a = bits[2].const_add(c0) 788 | a += bits[2].const_mult(-2*c0) 789 | return a.switch_precision(self.fpp) 790 | 791 | def _premul(self, a_vals): 792 | premul_rand = self.get_truncate_randomness(self.trunc_index,"premul") 793 | r_vals = premul_rand['r'] 794 | s_vals = premul_rand['s'] 795 | u_vals = [] 796 | 797 | mod_scale = mod_inverse(self.scale,self.mod) 798 | 799 | for i in range(len(r_vals)): 800 | r_val = self._reveal(r_vals[i]) 801 | s_val = self._reveal(s_vals[i]) 802 | u_val = (r_val * s_val * mod_scale) % self.mod 803 | u_vals.append(u_val) 804 | 805 | u_inv_vals = [] 806 | for i in range(len(u_vals)): 807 | u_val = u_vals[i] * mod_scale % self.mod 808 | u_inv_vals.append(mod_inverse(u_val,self.mod) * self.scale) 809 | 810 | v_vals = [] 811 | for i in range(len(r_vals)-1): 812 | v_vals.append(self._multiply(r_vals[i+1],s_vals[i])) 813 | 814 | w_vals = [] 815 | w_vals.append(r_vals[0]) 816 | for i in range(len(v_vals)): 817 | w_val = v_vals[i].const_mult(u_inv_vals[i]) 818 | w_vals.append(w_val) 819 | 820 | z_vals = [] 821 | for i in range(len(s_vals)): 822 | z_val = s_vals[i].const_mult(u_inv_vals[i]) 823 | z_vals.append(z_val) 824 | 825 | m_vals = [] 826 | for i in range(len(w_vals)): 827 | m_val = self._reveal(w_vals[i]) * self._reveal(a_vals[i]) * mod_scale 828 | m_vals.append(m_val % self.mod) 829 | 830 | p_vals = [] 831 | p_vals.append(a_vals[0]) 832 | for i in range(1,len(z_vals)): 833 | m_prod = 1 * self.scale 834 | for j in range(i+1): 835 | m_prod *= m_vals[j] 836 | m_prod *= mod_scale 837 | p_vals.append(z_vals[i].const_mult(m_prod)) 838 | 839 | return p_vals 840 | 841 | def _reveal(self, value): 842 | 843 | other_party_value = self._interact(value) 844 | if self.party_index == 1: 845 | return value.unshare(other_party_value,indices=[1,3],neg_representation=False) 846 | elif self.party_index == 2: 847 | return value.unshare(other_party_value,indices=[2,1],neg_representation=False) 848 | elif self.party_index == 3: 849 | return value.unshare(other_party_value,indices=[3,2],neg_representation=False) 850 | 851 | def _add(self, gate): 852 | gid = gate.get_id() 853 | 854 | [x,y] = gate.get_inputs() 855 | 856 | gate_output = self.circuit[gid] 857 | 858 | if type(x) == list: 859 | z_vals = [] 860 | 861 | for i in range(len(x)): 862 | z_vals.append(x[i] + y[i]) 863 | 864 | for gout in gate_output: 865 | gout.add_input(gid, z_vals) 866 | if gout.is_ready(): 867 | self.q.put(gout) 868 | 869 | else: 870 | for gout in gate_output: 871 | gout.add_input(gid, x + y) 872 | if gout.is_ready(): 873 | self.q.put(gout) 874 | 875 | def _interact(self, r_val): 876 | # if self.interaction[rand_index] doesnt equal "wait" 877 | # this means that another party already send us a value 878 | # so we do not need an async event listener, so it will 879 | # be set to None 880 | if self.interaction[self.random_index] == "wait": 881 | self.interaction_listener = Event() 882 | else: 883 | self.interaction_listener = None 884 | 885 | # each party sends value to other party 886 | if self.party_index == 1: 887 | self._send_share(r_val,2,self.random_index) 888 | elif self.party_index == 2: 889 | self._send_share(r_val,3,self.random_index) 890 | elif self.party_index == 3: 891 | self._send_share(r_val,1,self.random_index) 892 | 893 | # if we don't have event listener, we don't have to wait 894 | # because we already received share from party 895 | if self.interaction_listener != None: 896 | self.interaction_listener.wait() 897 | new_r = self.interaction[self.random_index] 898 | else: 899 | # pull new value from self.interaction list 900 | new_r = self.interaction[self.random_index] 901 | self.random_index += 1 902 | self.interaction_listener = None 903 | return new_r 904 | 905 | def _multiply(self, share1, share2): 906 | 907 | if self.random_index >= len(self.randomness): 908 | raise Exception('Randomness for multiplicaiton exhausted. Please generate more randomness.') 909 | rand_value = self.randomness[self.random_index] 910 | r = share1.pre_mult(share2, rand_value) 911 | new_r = self._interact(r) 912 | 913 | return Share(new_r - r, -2* new_r - r, mod=self.mod, fp_prec=self.fpp) 914 | 915 | def _mult(self, gate): 916 | gid = gate.get_id() 917 | 918 | [x,y] = gate.get_inputs() 919 | gate_output = self.circuit[gid] 920 | 921 | z_val = self._multiply(x,y) 922 | 923 | for gout in gate_output: 924 | gout.add_input(gid,z_val) 925 | if gout.is_ready(): 926 | self.q.put(gout) 927 | 928 | def _smult(self, gate): 929 | gid = gate.get_id() 930 | 931 | [x_val, yvec] = gate.get_inputs() 932 | gate_output = self.circuit[gid] 933 | 934 | z_vals = [] 935 | 936 | for i in range(len(yvec)): 937 | y_val = yvec[i] 938 | z_val = self._multiply(x_val,y_val) 939 | 940 | z_vals.append(z_val) 941 | 942 | for gout in gate_output: 943 | gout.add_input(gid, z_vals) 944 | if gout.is_ready(): 945 | self.q.put(gout) 946 | 947 | def _dot(self, gate): 948 | gid = gate.get_id() 949 | 950 | [xvec, yvec] = gate.get_inputs() 951 | gate_output = self.circuit[gid] 952 | 953 | z_val = Share(0,0,mod=self.mod,fp_prec=self.fpp) 954 | 955 | for i in range(len(xvec)): 956 | x_val = xvec[i] 957 | y_val = yvec[i] 958 | 959 | z_val += self._multiply(x_val,y_val) 960 | 961 | self.random_index += 1 962 | 963 | for gout in gate_output: 964 | gout.add_input(gid, z_val) 965 | if gout.is_ready(): 966 | self.q.put(gout) 967 | 968 | def _not(self, gate): 969 | gid = gate.get_id() 970 | 971 | [x] = gate.get_inputs() 972 | gate_output = self.circuit[gid] 973 | 974 | for gout in gate_output: 975 | gout.add_input(gid,x.not_op()) 976 | if gout.is_ready(): 977 | self.q.put(gout) 978 | 979 | def _comp(self, gate): 980 | gid = gate.get_id() 981 | 982 | [xa] = gate.get_inputs() 983 | gate_output = self.circuit[gid] 984 | 985 | half_mod = self.mod / 2 986 | 987 | # need to do truncate in pieces in order to work 988 | # take square root and perform truncate twice 989 | sq_half_mod = math.floor(half_mod**(1/2)) 990 | 991 | s_val1 = self._truncate(xa, sq_half_mod, sq_half_mod) 992 | s_val2 = self._truncate(s_val1, sq_half_mod, sq_half_mod) 993 | 994 | # need to invert truncation to get comparison value 995 | out_val = s_val2.const_mult(-1,scaled=False) 996 | 997 | # need to scale value back up to fixed-point precision 998 | out_val = out_val.const_mult(self.scale, scaled=False) 999 | 1000 | for gout in gate_output: 1001 | gout.add_input(gid, out_val) 1002 | if gout.is_ready(): 1003 | self.q.put(gout) 1004 | 1005 | self.random_index += 1 1006 | 1007 | def _round(self, gate): 1008 | gid = gate.get_id() 1009 | 1010 | [xa] = gate.get_inputs() 1011 | gate_output = self.circuit[gid] 1012 | 1013 | k = 10**7 1014 | m = 10**7 1015 | 1016 | if type(xa) is list: 1017 | out_val = [] 1018 | for share in xa: 1019 | cur = self._truncate(share, k, m) 1020 | cur = cur.const_mult(m,scaled=False) 1021 | 1022 | out_val.append(cur) 1023 | else: 1024 | out_val = self._truncate(xa, k, m) 1025 | out_val = out_val.const_mult(m,scaled=False) 1026 | 1027 | for gout in gate_output: 1028 | gout.add_input(gid, out_val) 1029 | if gout.is_ready(): 1030 | self.q.put(gout) 1031 | 1032 | self.random_index += 1 1033 | 1034 | def _cmult(self, gate): 1035 | gid = gate.get_id() 1036 | [x] = gate.get_inputs() 1037 | [const] = gate.get_const_inputs() 1038 | gate_output = self.circuit[gid] 1039 | 1040 | if type(x) == list: 1041 | out_val = [] 1042 | 1043 | for i in range(len(x)): 1044 | out_val.append(x[i].const_mult(const)) 1045 | 1046 | else: 1047 | out_val = x.const_mult(const) 1048 | 1049 | for gout in gate_output: 1050 | gout.add_input(gid, out_val) 1051 | if gout.is_ready(): 1052 | self.q.put(gout) 1053 | 1054 | def _cadd(self, gate): 1055 | gid = gate.get_id() 1056 | [x] = gate.get_inputs() 1057 | [const] = gate.get_const_inputs() 1058 | gate_output = self.circuit[gid] 1059 | 1060 | if type(x) == list: 1061 | out_val = [] 1062 | for i in range(len(x)): 1063 | out_val.append(x[i].const_add(const)) 1064 | else: 1065 | out_val = x.const_add(const) 1066 | 1067 | for gout in gate_output: 1068 | gout.add_input(gid, out_val) 1069 | if gout.is_ready(): 1070 | self.q.put(gout) 1071 | 1072 | def _input(self, gate): 1073 | gid = gate.get_id() 1074 | [x] = gate.get_inputs() 1075 | gate_output = self.circuit[gid] 1076 | for gout in gate_output: 1077 | gout.add_input(gid, x) 1078 | if gout.is_ready(): 1079 | self.q.put(gout) 1080 | 1081 | def _output(self, gate): 1082 | gid = gate.get_id() 1083 | [x] = gate.get_inputs() 1084 | self.outputs[gid] = x 1085 | 1086 | def _is_run_finished(self): 1087 | finished = True 1088 | for out in self.outputs: 1089 | if self.outputs[out] == "": 1090 | finished = False 1091 | return finished 1092 | 1093 | def _send_share(self, value, party_index, random_index): 1094 | receiver = self.parties[party_index] 1095 | receiver._receive_party_share(value,random_index) 1096 | 1097 | def _receive_party_share(self, share, random_index): 1098 | if self.interaction_listener != None: 1099 | self.interaction_listener.set() 1100 | self.interaction[random_index] = share 1101 | 1102 | def get_outputs(self): 1103 | outs = [] 1104 | for out in self.outputs: 1105 | outs.append(self.outputs[out]) 1106 | return outs 1107 | 1108 | def get_wire_dict(self): 1109 | return self.wire_dict 1110 | -------------------------------------------------------------------------------- /src/circuits/gate.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | 3 | class Gate(): 4 | def __init__(self, id_num, gate_type, input_ids, gate_queue=None, ready=False, const_input=None): 5 | self.id_num = id_num 6 | self.gate_type = gate_type 7 | self.input_ids = input_ids 8 | self.inputs = {} 9 | for in_id in input_ids: 10 | self.inputs[in_id] = "" 11 | if self.gate_type == "INPUT": 12 | self.inputs = {} 13 | self.inputs[""] = "" 14 | #self.gate_queue = gate_queue 15 | self.ready = ready 16 | self.complete = False 17 | self.output = "" 18 | self.const_input = const_input 19 | 20 | def reset(self): 21 | for in_id in self.inputs: 22 | self.inputs[in_id] = "" 23 | if self.gate_type == "INPUT": 24 | self.inputs[""] = "" 25 | self.ready = False 26 | 27 | def add_input(self, in_id, in_value): 28 | #if in_id == "": 29 | # if type(in_value) == list: 30 | # print("in_val: " + str(in_value[0].get_x()) + ", " + str(in_value[0].get_a())) 31 | self.inputs[in_id] = in_value 32 | #if self._is_ready(): 33 | # print("we ready") 34 | # try: 35 | # self.gate_queue.put(self) 36 | # except: 37 | # print("no gate queue for gate: " + self.id_num) 38 | 39 | def get_inputs(self): 40 | input_vals = [] 41 | for key in self.inputs: 42 | if key != self.const_input: 43 | input_vals.append(self.inputs[key]) 44 | return input_vals 45 | 46 | def get_input_ids(self): 47 | return self.input_ids 48 | 49 | def get_type(self): 50 | return self.gate_type 51 | 52 | def get_id(self): 53 | return self.id_num 54 | 55 | def get_const_inputs(self): 56 | input_vals = [] 57 | for key in self.inputs: 58 | if key == self.const_input: 59 | input_vals.append(self.inputs[key]) 60 | return input_vals 61 | 62 | def is_ready(self): 63 | ready = True 64 | for in_id in self.inputs: 65 | if self.inputs[in_id] == "": 66 | ready = False 67 | self.ready = ready 68 | return self.ready 69 | 70 | def is_complete(self): 71 | return self.complete 72 | 73 | def set_queue(self,q): 74 | self.gate_queue = q 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /src/circuits/oracle.py: -------------------------------------------------------------------------------- 1 | import random 2 | from src.circuits.dealer import Dealer 3 | 4 | class Oracle(Dealer): 5 | """ 6 | Oracle class that inherits from the Dealer class (to create shares). 7 | The Oracle is responsible for receiving shares from the parties, performing 8 | an operation (either MULT, SMULT, DOT, or COMP), creating new shares for 9 | the results and sending them to the parties. 10 | 11 | Methods 12 | ------- 13 | __init__(self, modulus, fp_precision=16) 14 | Constructor for Dealer object. 15 | Calls the Dealer constructor and initializes shares and outputs. 16 | 17 | Parameters 18 | ---------- 19 | modulus: int 20 | Modulus representing input domain 21 | (optional) fp_precision=16: int 22 | Fixed point number precision 23 | 24 | send_op(self, values, pindex, rindex, op) 25 | Method for performing an operation for the parties 26 | 27 | Parameters 28 | ---------- 29 | values: int or iterable 30 | Inputs to be used for performing operation 31 | pindex: int 32 | Index of party sending the value (must be 1, 2, or 3) 33 | rindex: int 34 | Index to keep track of which values to use 35 | op: string 36 | String indicating which operation to perform 37 | Must be "MULT", "DOT", "COMP", or "SMULT" 38 | 39 | receive_op(self, pindex, rindex) 40 | Method for distributing shares of result of operation to parties 41 | 42 | Parameters 43 | ---------- 44 | pindex: int 45 | Index of party receiving the value 46 | rindex: int 47 | Index to keep track of which values to use 48 | 49 | Returns 50 | ------- 51 | "wait" 52 | Value returned if all parties have not yet contributed shares 53 | self.outputs[rindex][pindex] 54 | Share of result of operation to be sent to party 55 | 56 | _mult(self, rindex) 57 | Method specifying how to compute multiplication 58 | 59 | Parameters 60 | ---------- 61 | rindex: int 62 | Index to keep track of which values to use 63 | 64 | _dot(self, rindex) 65 | Method specifying how to compute dot product 66 | 67 | Parameters 68 | ---------- 69 | rindex: int 70 | Index to keep track of which values to use 71 | 72 | _comp(self, rindex) 73 | Method specifying how to compute comparison (input <= 0) 74 | 75 | Parameters 76 | ---------- 77 | rindex: int 78 | Index to keep track of which values to use 79 | 80 | _smult(self, rindex) 81 | Method specifying how to compute scalar multiplication 82 | 83 | Parameters 84 | ---------- 85 | rindex: int 86 | Index to keep track of which values to use 87 | 88 | """ 89 | def __init__(self,modulus,fp_precision=16): 90 | Dealer.__init__(self,[],modulus,fp_precision) 91 | self.shares = {} 92 | self.outputs = {} 93 | 94 | def send_op(self,values,pindex,rindex,op): 95 | if rindex not in self.shares: 96 | self.shares[rindex] = {pindex: values} 97 | else: 98 | self.shares[rindex][pindex] = values 99 | 100 | if op == "MULT": 101 | self._mult(rindex) 102 | elif op == "DOT": 103 | self._dot(rindex) 104 | elif op == "COMP": 105 | self._comp(rindex) 106 | elif op == "ROUND": 107 | self._round(rindex) 108 | elif op == "SMULT": 109 | self._smult(rindex) 110 | elif op == "REVEAL": 111 | self._reveal(rindex) 112 | 113 | def receive_op(self,pindex,rindex): 114 | if rindex not in self.outputs: 115 | return "wait" 116 | if self.outputs[rindex] == "wait": 117 | return "wait" 118 | else: 119 | return self.outputs[rindex][pindex] 120 | 121 | def _mult(self,rindex): 122 | shrs = self.shares[rindex] 123 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 124 | self.outputs[rindex] = "wait" 125 | else: 126 | #x1 = shrs[1][0][0] 127 | #y1 = shrs[1][1][0] 128 | #a2 = shrs[2][0][1] 129 | #b2 = shrs[2][1][1] 130 | 131 | #x_val = x1 - a2 132 | #y_val = y1 - b2 133 | 134 | [x_val1,y_val1] = shrs[1] 135 | [x_val2,y_val2] = shrs[2] 136 | 137 | x_val = x_val1.unshare(x_val2) 138 | y_val = y_val1.unshare(y_val2) 139 | 140 | z = (x_val / self.scale) * (y_val / self.scale) 141 | [sh1,sh2,sh3] = self._make_shares(z) 142 | self.outputs[rindex] = {1: sh1, 2: sh2, 3: sh3} 143 | 144 | def _dot(self,rindex): 145 | shrs = self.shares[rindex] 146 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 147 | self.outputs[rindex] = "wait" 148 | else: 149 | z = 0 150 | 151 | [xvec1,yvec1] = shrs[1] 152 | [xvec2,yvec2] = shrs[2] 153 | 154 | for i in range(len(xvec1)): 155 | #(x1,a1) = xvec1[i] 156 | #(y1,b1) = yvec1[i] 157 | #(x2,a2) = xvec2[i] 158 | #(y2,b2) = yvec2[i] 159 | 160 | #x_val = x1 - a2 161 | #y_val = y1 - b2 162 | 163 | x_val = xvec1[i].unshare(xvec2[i]) 164 | y_val = yvec1[i].unshare(yvec2[i]) 165 | 166 | z += (x_val / self.scale) * y_val 167 | 168 | [sh1,sh2,sh3] = self._make_shares(z / self.scale) 169 | self.outputs[rindex] = {1: sh1, 2: sh2, 3: sh3} 170 | 171 | def _comp(self,rindex): 172 | shrs = self.shares[rindex] 173 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 174 | self.outputs[rindex] = "wait" 175 | else: 176 | z = 0 177 | 178 | #(x1,a1) = shrs[1][0] 179 | #(x2,a2) = shrs[2][0] 180 | 181 | #x_val = x1 - a2 182 | 183 | [x_val1] = shrs[1] 184 | [x_val2] = shrs[2] 185 | 186 | x_val = x_val1.unshare(x_val2) 187 | 188 | z = int(x_val <= 0) 189 | 190 | [sh1,sh2,sh3] = self._make_shares(z) 191 | self.outputs[rindex] = {1: sh1, 2: sh2, 3: sh3} 192 | 193 | def _round(self,rindex): 194 | shrs = self.shares[rindex] 195 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 196 | self.outputs[rindex] = "wait" 197 | else: 198 | z = 0 199 | 200 | #(x1,a1) = shrs[1][0] 201 | #(x2,a2) = shrs[2][0] 202 | 203 | #x_val = x1 - a2 204 | 205 | [x_val1] = shrs[1] 206 | [x_val2] = shrs[2] 207 | 208 | if type(x_val1) == list: 209 | z = [] 210 | for i in range(len(x_val1)): 211 | cur = x_val1[i].unshare(x_val2[i]) 212 | z.append((round(cur / 10**7) * 10**7) / self.scale) 213 | else: 214 | x_val = x_val1.unshare(x_val2) 215 | z = (round(x_val / 10**7) * 10**7) / self.scale 216 | 217 | [sh1,sh2,sh3] = self._make_shares(z) 218 | 219 | self.outputs[rindex] = {1: sh1, 2: sh2, 3: sh3} 220 | 221 | def _smult(self,rindex): 222 | shrs = self.shares[rindex] 223 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 224 | self.outputs[rindex] = "wait" 225 | else: 226 | #z = 0 227 | 228 | #(x1,a1) = shrs[1][0] 229 | #(x2,a2) = shrs[2][0] 230 | 231 | #yvec1 = shrs[1][1] 232 | #yvec2 = shrs[2][1] 233 | 234 | #x_val = x1 - a2 235 | 236 | [x_val1, yvec1] = shrs[1] 237 | [x_val2, yvec2] = shrs[2] 238 | 239 | x_val = x_val1.unshare(x_val2) 240 | 241 | z = [] 242 | 243 | for i in range(len(yvec1)): 244 | #(y1,b1) = yvec1[i] 245 | #(y2,b2) = yvec2[i] 246 | 247 | #y_val = y1 - b2 248 | 249 | y_val = yvec1[i].unshare(yvec2[i]) 250 | 251 | z.append((x_val / self.scale) * (y_val / self.scale)) 252 | 253 | [sh1,sh2,sh3] = self._make_shares(z) 254 | self.outputs[rindex] = {1: sh1, 2: sh2, 3: sh3} 255 | 256 | def _reveal(self, rindex): 257 | shrs = self.shares[rindex] 258 | if (1 not in shrs) or (2 not in shrs) or (3 not in shrs): 259 | self.outputs[rindex] = "wait" 260 | else: 261 | z = 0 262 | 263 | #(x1,a1) = shrs[1][0] 264 | #(x2,a2) = shrs[2][0] 265 | 266 | #x_val = x1 - a2 267 | 268 | [x_val1] = shrs[1] 269 | [x_val2] = shrs[2] 270 | 271 | if type(x_val1) == list: 272 | z_vals = [] 273 | for i in range(len(x_val1)): 274 | z_vals.append(x_val1[i].unshare(x_val2[i])) 275 | 276 | print("REVEALING z: " + str(z_vals)) 277 | 278 | else: 279 | z = x_val1.unshare(x_val2) 280 | print("REVEALING z: " + str(z)) 281 | self.outputs[rindex] = "done" 282 | 283 | -------------------------------------------------------------------------------- /src/circuits/share.py: -------------------------------------------------------------------------------- 1 | from src.util.mod import mod_inverse 2 | 3 | # I picked a large, 32-digit prime number to be used for default modulus: 4 | # 10001112223334445556667778889991 5 | # 1/3 mod 10001112223334445556667778889991 is 3333704074444815185555926296664 6 | # this will be needed for MPC protocol 7 | 8 | #MOD = 10001112223334445556667778889991 9 | #INVERSE_OF_3 = 3333704074444815185555926296664 10 | #MOD_SCALE = 1056332347261636068068068068067 # inverse of 10^16 mod MOD 11 | 12 | # I picked an even large, 60-digit prime for other use cases 13 | MOD = 622288097498926496141095869268883999563096063592498055290461 14 | 15 | #MOD = 24684249032065892333066123534168930441269525239006410135714283699648991959894332868446109170827166448301044689 16 | 17 | class Share(): 18 | 19 | # setup cache for inverse of 3 and scale for a given modulus 20 | # key will be mod, value will be inverse of 3 / scale for that modulus 21 | inv3_cache = {} 22 | # for scale cache, key will be tuple (mod, scale) 23 | scale_cache = {} 24 | 25 | def __init__(self, value1, value2, mod=MOD, inv_3=None, fp_prec=12, mod_scale=None): 26 | self.mod = mod 27 | if inv_3 != None: 28 | self.inv_3 = inv_3 29 | else: 30 | if self.mod not in type(self).inv3_cache: 31 | self.inv_3 = mod_inverse(3, self.mod) 32 | type(self).inv3_cache[self.mod] = self.inv_3 33 | else: 34 | self.inv_3 = type(self).inv3_cache[self.mod] 35 | 36 | self.fp = fp_prec 37 | self.x = value1 % self.mod 38 | self.a = value2 % self.mod 39 | self.scale = 10**self.fp 40 | if mod_scale != None: 41 | self.mod_scale = mod_scale 42 | else: 43 | if (self.mod, self.scale) not in type(self).scale_cache: 44 | self.mod_scale = mod_inverse(self.scale, self.mod) 45 | type(self).scale_cache[(self.mod,self.scale)] = self.mod_scale 46 | else: 47 | self.mod_scale = type(self).scale_cache[(self.mod,self.scale)] 48 | def switch_precision(self, new_precision): 49 | scale = 10**new_precision 50 | new_x = (self.x * self.mod_scale * scale) % self.mod 51 | new_a = (self.a * self.mod_scale * scale) % self.mod 52 | return Share(new_x, new_a, mod=self.mod, fp_prec=new_precision) 53 | 54 | def __eq__(self,other): 55 | if type(other) != type(Share(0,0)): 56 | return False 57 | if self.mod != other.mod: 58 | return False 59 | if self.fp != other.fp: 60 | return False 61 | if (self.x % self.mod) != (other.x % self.mod): 62 | return False 63 | if (self.a % self.mod) != (other.a % self.mod): 64 | return False 65 | else: 66 | return True 67 | 68 | def __add__(self,other): 69 | new_x = (self.x + other.x) % self.mod 70 | new_a = (self.a + other.a) % self.mod 71 | return Share(new_x, new_a, mod=self.mod, inv_3=self.inv_3, fp_prec=self.fp) 72 | 73 | def not_op(self): 74 | new_x = (-self.x) % self.mod 75 | new_a = (-(self.a + 1*self.scale)) % self.mod 76 | return Share(new_x, new_a, mod=self.mod, inv_3=self.inv_3, fp_prec=self.fp) 77 | 78 | def pre_mult(self, other, random_val): 79 | r = ((self.a * self.mod_scale) * other.a) 80 | r -= ((self.x * self.mod_scale) * other.x) 81 | r += random_val 82 | r = r % self.mod 83 | return (r * self.inv_3) % self.mod 84 | 85 | def unshare(self, other, indices=[1,2],neg_representation=True): 86 | ind = indices 87 | if (ind == [1,2]) or (ind == [2,3]) or (ind == [3,1]): 88 | res = (self.x - other.a) % self.mod 89 | else: 90 | #res = (self.a - other.x) % self.mod 91 | res = (other.x - self.a) % self.mod 92 | 93 | if neg_representation: 94 | if res > (self.mod / 2): 95 | res = res - self.mod 96 | return res 97 | 98 | def const_mult(self, const_value, scaled=True): 99 | if scaled: 100 | new_x = (self.x * const_value * self.mod_scale) % self.mod 101 | new_a = (self.a * const_value * self.mod_scale) % self.mod 102 | else: 103 | new_x = (self.x * const_value) % self.mod 104 | new_a = (self.a * const_value) % self.mod 105 | 106 | return Share(new_x, new_a, mod=self.mod, inv_3=self.inv_3, fp_prec=self.fp) 107 | 108 | def const_add(self, const_value, scaled=True): 109 | if scaled: 110 | new_x = self.x 111 | new_a = (self.a - const_value) % self.mod 112 | else: 113 | new_x = self.x 114 | new_a = (self.a - (const_value * self.scale)) % self.mod 115 | return Share(new_x, new_a, mod=self.mod, inv_3=self.inv_3, fp_prec=self.fp) 116 | 117 | def get_x(self): 118 | return self.x 119 | 120 | def get_a(self): 121 | return self.a 122 | -------------------------------------------------------------------------------- /src/util/mod.py: -------------------------------------------------------------------------------- 1 | def mod_inverse(val, mod): 2 | g, x, y = egcd(val, mod) 3 | if g != 1: 4 | raise Exception('modular inverse does not exist') 5 | else: 6 | return x % mod 7 | 8 | def egcd(a,b): 9 | if a == 0: 10 | return (b,0,1) 11 | else: 12 | g, y, x = egcd(b %a, a) 13 | return (g, x - (b //a) * y, y) -------------------------------------------------------------------------------- /src/util/primality_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | # Code taken from https://gist.github.com/Ayrx/5884790 4 | # implements primality testing 5 | 6 | def miller_rabin(n, k=40): 7 | 8 | # Implementation uses the Miller-Rabin Primality Test 9 | # The optimal number of rounds for this test is 40 10 | # See http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes 11 | # for justification 12 | 13 | # If number is even, it's a composite number 14 | 15 | if n == 2 or n == 3: 16 | return True 17 | 18 | if n % 2 == 0: 19 | return False 20 | 21 | r, s = 0, n - 1 22 | while s % 2 == 0: 23 | r += 1 24 | s //= 2 25 | for _ in range(k): 26 | a = random.randrange(2, n - 1) 27 | x = pow(a, s, n) 28 | if x == 1 or x == n - 1: 29 | continue 30 | for _ in range(r - 1): 31 | x = pow(x, 2, n) 32 | if x == n - 1: 33 | break 34 | else: 35 | return False 36 | return True 37 | 38 | if __name__ == "__main__": 39 | 40 | test_vals = [2,3,4,5,6,7,9,43,11003,10233] 41 | 42 | for tv in test_vals: 43 | print("value " + str(tv) + " is prime: " + str(miller_rabin(tv,k=40))) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/tests/__init__.py -------------------------------------------------------------------------------- /tests/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/tests/examples/__init__.py -------------------------------------------------------------------------------- /tests/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/tests/src/__init__.py -------------------------------------------------------------------------------- /tests/src/circuits/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/tests/src/circuits/__init__.py -------------------------------------------------------------------------------- /tests/src/circuits/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from random import randint 3 | from mock import Mock 4 | from _pytest.monkeypatch import MonkeyPatch 5 | from src.circuits.share import Share 6 | from src.util.mod import mod_inverse 7 | import math 8 | 9 | from src.circuits.evaluator import SecureEvaluator 10 | 11 | class mock_mod2: 12 | def __init__(self,val1): 13 | self.val1 = val1 14 | self.cval = None 15 | 16 | def _reveal(self,value): 17 | self.cval = value.unshare(self.val1) 18 | return self.cval 19 | 20 | def get_cval(self): 21 | return self.cval 22 | 23 | @pytest.fixture 24 | def mock_secure_evaluator(): 25 | return Mock(spec=SecureEvaluator) 26 | 27 | @pytest.mark.parametrize( 28 | "value,k,mod,fpp,expected", 29 | [(1,2,19,0,1),(2,2,19,0,0),(3,5,61,1,10),(3,5,11003,2,100),(8,6,1009,1,0),(19,6,131,0,1)] 30 | ) 31 | def test_mod2(value,k,mod,fpp,expected): 32 | scale = 10**fpp 33 | value = value * scale 34 | 35 | b1 = randint(0,1) * scale 36 | b2 = randint(0,1) * scale 37 | b3 = b2 38 | 39 | # make shares of b1, b2, and b3 40 | b1a = randint(0,mod - 1) 41 | b1b = randint(0,mod - 1) 42 | b1c = (- (b1a+b1b)) % mod 43 | 44 | b2a = randint(0,mod - 1) 45 | b2b = randint(0,mod - 1) 46 | b2c = (- (b2a+b2b)) % mod 47 | 48 | b3a = randint(0,mod - 1) 49 | b3b = randint(0,mod - 1) 50 | b3c = (- (b3a+b3b)) % mod 51 | 52 | b1_share1 = Share(b1a,b1c - b1,mod=mod, fp_prec=fpp) 53 | b1_share2 = Share(b1b,b1a - b1,mod=mod, fp_prec=fpp) 54 | 55 | b2_share1 = Share(b2a,b2c - b2,mod=mod, fp_prec=fpp) 56 | b2_share2 = Share(b2b,b2a - b2,mod=mod, fp_prec=fpp) 57 | 58 | b3_share1 = Share(b3a,b3c - b3,mod=mod, fp_prec=fpp) 59 | b3_share2 = Share(b3b,b3a - b3,mod=mod, fp_prec=fpp) 60 | 61 | vala = randint(0,mod -1) 62 | valb = randint(0,mod -1) 63 | valc = (-(vala + valb)) % mod 64 | 65 | val_share1 = Share(vala,valc-value,mod=mod, fp_prec=fpp) 66 | val_share2 = Share(valb,vala-value,mod=mod, fp_prec=fpp) 67 | 68 | b1_share2 = b1_share2.switch_precision(0) 69 | b2_share2 = b2_share2.switch_precision(0) 70 | b3_share2 = b3_share2.switch_precision(0) 71 | 72 | val_share2 = val_share2.switch_precision(0) 73 | #val = val_share2.const_add(2**(k-1)) 74 | val = val_share2 75 | val += b1_share2.const_mult(2) + b3_share2 76 | m = mock_mod2(val) 77 | 78 | rbits = [b1_share1,b2_share1,b3_share1] 79 | 80 | def rand_bits(obj,index,rand): 81 | return rbits 82 | 83 | monkeypatch = MonkeyPatch() 84 | evltr = SecureEvaluator(None,[],[],1,None,mod,fp_precision=fpp) 85 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator.get_truncate_randomness',rand_bits) 86 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._reveal',m._reveal) 87 | evltr_out = evltr._mod2(val_share1,k) 88 | 89 | b3_share1 = b3_share1.switch_precision(0) 90 | 91 | c0 = int(bin(m.get_cval())[-1]) 92 | party2_out = b3_share2.const_add(c0) 93 | party2_out += b3_share2.const_mult(-2*c0) 94 | party2_out = party2_out.switch_precision(fpp) 95 | 96 | 97 | assert (evltr_out.unshare(party2_out) % mod) == expected 98 | 99 | class mock_premul: 100 | def __init__(self,mod,fpp): 101 | self.mod = mod 102 | self.fpp = fpp 103 | 104 | def add_rs_vals(self,r1_vals,r_vals,s1_vals,s_vals): 105 | self.r1_vals = r1_vals 106 | self.r_vals = r_vals 107 | self.s1_vals = s1_vals 108 | self.s_vals = s_vals 109 | 110 | def add_wa_vals(self,w1_vals,w_vals,a1_vals,a_vals): 111 | self.w1_vals = w1_vals 112 | self.w_vals = w_vals 113 | self.a1_vals = a1_vals 114 | self.a_vals = a_vals 115 | 116 | def add_v_vals_dict(self,r1_vals,s1_vals,r3_vals,s3_vals): 117 | self.v_vals_dict = {} 118 | for i in range(len(r1_vals) - 1): 119 | val1 = r1_vals[i+1].get_x() + r1_vals[i+1].get_a() 120 | val2 = s1_vals[i].get_x() + s1_vals[i].get_a() 121 | self.v_vals_dict[(val1,val2)] = r3_vals[i+1].pre_mult(s3_vals[i],0) 122 | 123 | def _reveal(self,value): 124 | for i in range(len(self.r1_vals)): 125 | if value == self.r1_vals[i]: 126 | return self.r_vals[i] 127 | for i in range(len(self.s1_vals)): 128 | if value == self.s1_vals[i]: 129 | return self.s_vals[i] 130 | for i in range(len(self.w1_vals)): 131 | if value == self.w1_vals[i]: 132 | return self.w_vals[i] 133 | for i in range(len(self.a1_vals)): 134 | if value == self.a1_vals[i]: 135 | return self.a_vals[i] 136 | else: 137 | print("value: " + str(value)) 138 | print("couldn't find value !!!!!!!!!!!!!") 139 | 140 | def _multiply(self,value1,value2): 141 | r = value1.pre_mult(value2,0) 142 | val1 = value1.get_x() + value1.get_a() 143 | val2 = value2.get_x() + value2.get_a() 144 | if (val1,val2) in self.v_vals_dict: 145 | new_r = self.v_vals_dict[(val1,val2)] 146 | elif (val2,val1) in self.v_vals_dict: 147 | new_r = self.v_vals_dict[(val2,val1)] 148 | else: 149 | print("couldnt find tuple in v vals dict!!!!!") 150 | 151 | return Share(new_r - r, -2*new_r - r,mod=self.mod,fp_prec=self.fpp) 152 | 153 | @pytest.mark.parametrize( 154 | "inputs,mod,fpp,expected", 155 | [([1,2],11003,0,[1,2]),([1,2,3],11003,0,[1,2,6]), 156 | ([2,2,2,2],11003,1,[2,4,8,16]),([3,5],11003,2,[3,15]), 157 | ([15,2,3],15485867,5,[15,30,90]),([6,1,7],15485867,4,[6,6,42])] 158 | ) 159 | def test_premul(inputs,mod,fpp,expected): 160 | scale = 10**fpp 161 | mod_scale = mod_inverse(scale,mod) 162 | new_rand_mod = math.floor(mod / scale) - 1 163 | 164 | for i,inp in enumerate(inputs): 165 | inputs[i] = inp * scale 166 | a_vals = inputs 167 | 168 | m = mock_premul(mod,fpp) 169 | 170 | k = len(inputs) 171 | r_vals = [] 172 | s_vals = [] 173 | u_vals = [] 174 | u_inv_vals = [] 175 | a1_vals = [] 176 | a2_vals = [] 177 | r1_vals = [] 178 | s1_vals = [] 179 | r2_vals = [] 180 | s2_vals = [] 181 | r3_vals = [] 182 | s3_vals = [] 183 | for i in range(k): 184 | r_val = randint(1,new_rand_mod) * scale 185 | s_val = randint(1,new_rand_mod) * scale 186 | u_val = (r_val * s_val * mod_scale) % mod 187 | print("test vals:") 188 | print("r_" + str(i) + ": " + str(r_val)) 189 | print("s_" + str(i) + ": " + str(s_val)) 190 | print("u_" + str(i) + ": " + str(u_val)) 191 | u_inv = mod_inverse(u_val * mod_scale,mod) * scale 192 | 193 | r_vals.append(r_val) 194 | s_vals.append(s_val) 195 | u_vals.append(u_val) 196 | u_inv_vals.append(u_inv) 197 | 198 | ra = randint(0,new_rand_mod) * scale 199 | rb = randint(0,new_rand_mod) * scale 200 | rc = (-(ra+rb)) % mod 201 | sa = randint(0,new_rand_mod) * scale 202 | sb = randint(0,new_rand_mod) * scale 203 | sc = (-(sa+sb)) % mod 204 | aa = randint(0,new_rand_mod) * scale 205 | ab = randint(0,new_rand_mod) * scale 206 | ac = (-(aa+ab)) % mod 207 | 208 | r1_vals.append(Share(ra, rc - r_val,mod=mod,fp_prec=fpp)) 209 | s1_vals.append(Share(sa, sc - s_val,mod=mod,fp_prec=fpp)) 210 | r2_vals.append(Share(rb, ra - r_val,mod=mod,fp_prec=fpp)) 211 | s2_vals.append(Share(sb, sa - s_val,mod=mod,fp_prec=fpp)) 212 | r3_vals.append(Share(rc, rb - r_val,mod=mod,fp_prec=fpp)) 213 | s3_vals.append(Share(sc, sb - s_val,mod=mod,fp_prec=fpp)) 214 | a1_vals.append(Share(aa, ac - a_vals[i],mod=mod,fp_prec=fpp)) 215 | a2_vals.append(Share(ab, aa - a_vals[i],mod=mod,fp_prec=fpp)) 216 | 217 | m.add_rs_vals(r1_vals,r_vals,s1_vals,s_vals) 218 | m.add_v_vals_dict(r1_vals,s1_vals,r3_vals,s3_vals) 219 | 220 | v_vals = [] 221 | for i in range(len(r_vals) - 1): 222 | v_vals.append(r_vals[i+1]*s_vals[i]*mod_scale) 223 | 224 | v1_vals = [] 225 | for i in range(len(r1_vals) - 1): 226 | r = r1_vals[i+1].pre_mult(s1_vals[i],0) 227 | new_r = r3_vals[i+1].pre_mult(s3_vals[i],0) 228 | v1_vals.append(Share(new_r - r, -2*new_r - r,mod=mod,fp_prec=fpp)) 229 | 230 | v2_vals = [] 231 | for i in range(len(r1_vals) - 1): 232 | r = r2_vals[i+1].pre_mult(s2_vals[i],0) 233 | new_r = r1_vals[i+1].pre_mult(s1_vals[i],0) 234 | v2_vals.append(Share(new_r - r, -2*new_r - r,mod=mod,fp_prec=fpp)) 235 | 236 | print("raw v vals: ") 237 | for i in range(len(v_vals)): 238 | print(v_vals[i] % mod) 239 | print("cooked v vals: ") 240 | for i in range(len(v1_vals)): 241 | print(v1_vals[i].unshare(v2_vals[i]) % mod) 242 | 243 | w_vals = [] 244 | w_vals.append(r_vals[0]) 245 | for i in range(len(v_vals)): 246 | w_vals.append(v_vals[i]*u_inv_vals[i] * mod_scale % mod) 247 | 248 | w1_vals = [] 249 | w1_vals.append(r1_vals[0]) 250 | for i in range(len(v1_vals)): 251 | w1_vals.append(v1_vals[i].const_mult(u_inv_vals[i])) 252 | 253 | m.add_wa_vals(w1_vals,w_vals,a1_vals,a_vals) 254 | 255 | w2_vals = [] 256 | w2_vals.append(r2_vals[0]) 257 | for i in range(len(v2_vals)): 258 | w2_vals.append(v2_vals[i].const_mult(u_inv_vals[i])) 259 | 260 | print("raw w vals: ") 261 | for i in range(len(w_vals)): 262 | print(w_vals[i] % mod) 263 | print("cooked w vals: ") 264 | for i in range(len(w1_vals)): 265 | print(w1_vals[i].unshare(w2_vals[i]) % mod) 266 | 267 | z_vals = [] 268 | for i in range(len(s_vals)): 269 | z_vals.append(s_vals[i] * u_inv_vals[i] * mod_scale % mod) 270 | 271 | z1_vals = [] 272 | for i in range(len(s1_vals)): 273 | z1_vals.append(s1_vals[i].const_mult(u_inv_vals[i])) 274 | 275 | z2_vals = [] 276 | for i in range(len(s2_vals)): 277 | z2_vals.append(s2_vals[i].const_mult(u_inv_vals[i])) 278 | 279 | print("raw z vals: ") 280 | for i in range(len(z_vals)): 281 | print(z_vals[i] % mod) 282 | print("cooked z vals: ") 283 | for i in range(len(z1_vals)): 284 | print(z1_vals[i].unshare(z2_vals[i]) % mod) 285 | 286 | 287 | m2_vals = [] 288 | for i in range(len(w2_vals)): 289 | m2_vals.append(w_vals[i]*a_vals[i] * mod_scale % mod) 290 | 291 | print("test m vals: " + str(m2_vals)) 292 | 293 | p_vals = [] 294 | p_vals.append(a_vals[0]) 295 | for i in range(1,len(z_vals)): 296 | m_prod = 1 * scale 297 | for j in range(i+1): 298 | m_prod *= m2_vals[j] * mod_scale 299 | p_vals.append(z_vals[i] * m_prod * mod_scale % mod) 300 | 301 | p2_vals = [] 302 | p2_vals.append(a2_vals[0]) 303 | for i in range(1,len(z2_vals)): 304 | m_prod = 1 * scale 305 | for j in range(i+1): 306 | m_prod *= m2_vals[j] 307 | m_prod *= mod_scale 308 | p2_vals.append(z2_vals[i].const_mult(m_prod)) 309 | 310 | def rand_vals(obj,index,rand): 311 | return {'r': r1_vals, 's': s1_vals} 312 | 313 | monkeypatch = MonkeyPatch() 314 | evaluator = SecureEvaluator(None,[],[],1,None,mod,fp_precision=fpp) 315 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator.get_truncate_randomness',rand_vals) 316 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._reveal',m._reveal) 317 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._multiply',m._multiply) 318 | 319 | eval_out = evaluator._premul(a1_vals) 320 | p1_vals = eval_out 321 | 322 | print("raw p vals: ") 323 | for i in range(len(p_vals)): 324 | print(p_vals[i] % mod) 325 | print("cooked p vals: ") 326 | for i in range(len(p1_vals)): 327 | print(p1_vals[i].unshare(p2_vals[i]) % mod) 328 | 329 | outs = [] 330 | for i in range(len(eval_out)): 331 | out_val = (eval_out[i].unshare(p2_vals[i])) % mod 332 | outs.append(int(out_val / scale)) 333 | 334 | assert outs == expected 335 | 336 | class mock_bit_lt: 337 | def __init__(self,mod,fpp,a_bits): 338 | self.mod = mod 339 | self.fpp = fpp 340 | self.scale = 10**fpp 341 | self.mod_scale = mod_inverse(self.scale,self.mod) 342 | self.rand_mod = math.floor(mod / self.scale) - 1 343 | self.a_bits = a_bits 344 | 345 | def add_premul_vals(self,vals): 346 | self.premul_vals = vals 347 | 348 | def _premul(self,values): 349 | d_vals = [] 350 | print("cooked d vals:") 351 | for i in range(len(values)): 352 | d_vals.append(values[i].unshare(self.premul_vals[i])) 353 | print(values[i].unshare(self.premul_vals[i])) 354 | 355 | raw_pm_vals = [] 356 | for i in range(len(d_vals)): 357 | prod = 1*self.scale 358 | for j in range(i+1): 359 | prod *= d_vals[j] * self.mod_scale 360 | prod = prod % self.mod 361 | raw_pm_vals.append(prod) 362 | self.premul_raw = raw_pm_vals 363 | 364 | print("cooked p vals: ") 365 | for i in range(len(raw_pm_vals)): 366 | print(raw_pm_vals[i]) 367 | 368 | shrs1 = [] 369 | shrs2 = [] 370 | shrs3 = [] 371 | for i in range(len(raw_pm_vals)): 372 | a = randint(0,self.rand_mod) * self.scale 373 | b = randint(0,self.rand_mod) * self.scale 374 | c = (-(a+b)) % self.mod 375 | 376 | shrs1.append(Share(a,c - raw_pm_vals[i],mod=self.mod,fp_prec=self.fpp)) 377 | shrs2.append(Share(b,a - raw_pm_vals[i],mod=self.mod,fp_prec=self.fpp)) 378 | shrs3.append(Share(c,b - raw_pm_vals[i],mod=self.mod,fp_prec=self.fpp)) 379 | 380 | shrs2.reverse() 381 | shrs3.reverse() 382 | 383 | self.pm_sh1 = shrs1 384 | self.pm_sh2 = shrs2 385 | self.pm_sh3 = shrs3 386 | 387 | self.premul_raw.reverse() 388 | 389 | return shrs1 390 | 391 | def _mod2(self,value,k): 392 | sh1 = self.pm_sh1 393 | s1_vals = [] 394 | for i in range(len(sh1)-1): 395 | s1_vals.append(sh1[i] + sh1[i+1].const_mult(-1,scaled=False)) 396 | s1_vals.append(sh1[-1].const_add(-1,scaled=False)) 397 | 398 | sh2 = self.pm_sh2 399 | s2_vals = [] 400 | for i in range(len(sh2)-1): 401 | s2_vals.append(sh2[i] + sh2[i+1].const_mult(-1,scaled=False)) 402 | s2_vals.append(sh2[-1].const_add(-1,scaled=False)) 403 | 404 | print("cooked s_vals: ") 405 | for i in range(len(s1_vals)): 406 | print(s1_vals[i].unshare(s2_vals[i])) 407 | 408 | #self.a_bits.reverse() 409 | s2len = len(s2_vals) 410 | s1 = Share(0,0,mod=self.mod,fp_prec=self.fpp) 411 | for i in range(len(s1_vals)): 412 | s1 += s1_vals[i].const_mult(self.scale - self.a_bits[i]) 413 | 414 | s_val = Share(0,0,mod=self.mod,fp_prec=self.fpp) 415 | 416 | for i in range(s2len): 417 | s_val += s2_vals[i].const_mult(self.scale - self.a_bits[i]) 418 | 419 | print("cooked a bits: " + str(self.a_bits)) 420 | 421 | print("1st cooked s val: " + str(s1.unshare(s_val))) 422 | print("2nd cooked u val: " + bin(s1.unshare(s_val))[-1]) 423 | 424 | print("cooked s val: " + str(value.unshare(s_val))) 425 | print("cooked u val: " + bin(value.unshare(s_val))[-1]) 426 | 427 | self.u_val = int(bin(value.unshare(s_val) * self.mod_scale % self.mod)[-1]) * self.scale 428 | 429 | a = randint(0,self.rand_mod) * self.scale 430 | b = randint(0,self.rand_mod) * self.scale 431 | c = (-(a+b)) % self.mod 432 | 433 | self.u_sh1 = Share(a,c-self.u_val,mod=self.mod,fp_prec=self.fpp) 434 | self.u_sh2 = Share(b,a-self.u_val,mod=self.mod,fp_prec=self.fpp) 435 | self.u_sh3 = Share(c,b-self.u_val,mod=self.mod,fp_prec=self.fpp) 436 | 437 | return self.u_sh1 438 | 439 | def get_u_sh2(self): 440 | return self.u_sh2 441 | 442 | @pytest.mark.parametrize( 443 | "val1,val2_bits,mod,fpp,expected", 444 | [(3,[0,1],11,0,0),(3,[1,1,1],11,0,1),(15,[1,0,0,1,0,0],11003,0,1), 445 | (50,[1,0,1,0,0],11003,0,0),(200,[1,0,0,1,0,1,1,0,0],11003,0,1), 446 | (6,[1,1,0],11003,0,0),(6,[1,1,1],11003,1,1),(19,[1,0,0,0,1],15485867,3,0), 447 | (19,[1,0,0,1,1],15485867,2,0)] 448 | ) 449 | def test_bit_lt(val1,val2_bits,mod,fpp,expected): 450 | scale = 10**fpp 451 | #val1 = val1 * scale % mod 452 | rand_mod = math.floor(mod / scale) - 1 453 | mod_scale = mod_inverse(scale,mod) 454 | 455 | val1_len = len(bin(val1)[2:]) 456 | print("val1 len: " + str(val1_len)) 457 | print("val2 len: " + str(len(val2_bits))) 458 | 459 | for i,bit in enumerate(val2_bits): 460 | val2_bits[i] = bit*scale 461 | 462 | print("test") 463 | 464 | if len(val2_bits) < val1_len: 465 | print("add bits") 466 | val2_bits = [0]*(val1_len - len(val2_bits)) + val2_bits 467 | 468 | val1_bits = [] 469 | for bit in bin(val1)[2:]: 470 | val1_bits.append(int(bit)*scale) 471 | if len(val1_bits) > len(val2_bits): 472 | print("not enough bits for val 2!!!!!!") 473 | else: 474 | print("adding val1 bits") 475 | val1_bits = [0]*(len(val2_bits) - len(val1_bits)) + val1_bits 476 | print("val1: " + str(val1) + " val1_bits: " + str(val1_bits)) 477 | print("val2_bits: " + str(val2_bits)) 478 | 479 | if len(val1_bits) != len(val2_bits): 480 | print("LENGHT MISMATCH!!") 481 | 482 | raw_d = [] 483 | print("raw d vals: ") 484 | for i in range(len(val1_bits)): 485 | d_val = val1_bits[i] + val2_bits[i] - 2*val1_bits[i]*val2_bits[i]*mod_scale + 1*scale 486 | print(d_val % mod) 487 | raw_d.append(d_val % mod) 488 | 489 | bit_shares1 = [] 490 | bit_shares2 = [] 491 | bit_shares3 = [] 492 | for i in range(len(val2_bits)): 493 | a = randint(0,rand_mod) * scale 494 | b = randint(0,rand_mod) * scale 495 | c = (-(a+b)) % mod 496 | 497 | bit_shares1.append(Share(a,c - val2_bits[i],mod=mod,fp_prec=fpp)) 498 | bit_shares2.append(Share(b,a - val2_bits[i],mod=mod,fp_prec=fpp)) 499 | bit_shares3.append(Share(c,b - val2_bits[i],mod=mod,fp_prec=fpp)) 500 | 501 | m = mock_bit_lt(mod,fpp,val1_bits) 502 | 503 | pm2_vals = [] 504 | for i in range(len(val2_bits)): 505 | d_val = bit_shares2[i].const_add(val1_bits[i]) 506 | d_val += bit_shares2[i].const_mult(-2*val1_bits[i]) 507 | d_val = d_val.const_add(1,scaled=False) 508 | pm2_vals.append(d_val) 509 | 510 | m.add_premul_vals(pm2_vals) 511 | 512 | raw_p = [] 513 | print("raw p vals: ") 514 | for i in range(len(raw_d)): 515 | prod = 1*scale 516 | for j in range(i+1): 517 | prod *= raw_d[j] 518 | prod *= mod_scale 519 | prod = prod % mod 520 | raw_p.append(prod) 521 | print(prod) 522 | 523 | raw_p.reverse() 524 | raw_s_vals = [] 525 | print("raw s vals: ") 526 | for i in range(len(raw_p)-1): 527 | s_vals = raw_p[i] - raw_p[i+1] 528 | s_vals = s_vals % mod 529 | print(s_vals) 530 | raw_s_vals.append(s_vals) 531 | print(raw_p[-1] - 1*scale) 532 | raw_s_vals.append((raw_p[-1] - 1*scale) % mod) 533 | 534 | raw_s = 0 535 | print("raw intermediate s:") 536 | val1_bits.reverse() 537 | print("raw a bits: " + str(val1_bits)) 538 | for i in range(len(raw_s_vals)): 539 | print(raw_s_vals[i] * (1*scale - val1_bits[i]) * mod_scale % mod) 540 | raw_s += raw_s_vals[i]*(1*scale - val1_bits[i]) * mod_scale 541 | raw_s = raw_s % mod 542 | print("raw s: ") 543 | print(raw_s) 544 | 545 | print("raw u: " + bin(raw_s * mod_scale % mod)[-1]) 546 | 547 | monkeypatch = MonkeyPatch() 548 | evaluator = SecureEvaluator(None,[],[],1,None,mod,fp_precision=fpp) 549 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._premul',m._premul) 550 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._mod2',m._mod2) 551 | 552 | out_val = evaluator._bit_lt(val1,bit_shares1) 553 | u_sh2 = m.get_u_sh2() 554 | 555 | assert ((out_val.unshare(u_sh2)) / scale) % mod == expected 556 | 557 | class mock_mod2m: 558 | 559 | def __init__(self,mod,fpp): 560 | self.mod = mod 561 | self.fpp = fpp 562 | self.scale = 10**fpp 563 | self.mod_scale = mod_inverse(self.scale,self.mod) 564 | self.rand_mod = math.floor(mod / self.scale) - 1 565 | 566 | def add_reveal(self,reveal): 567 | self.reveal = reveal 568 | 569 | def add_raw_r1(self, r1): 570 | self.r1 = r1 571 | 572 | def _reveal(self,value): 573 | self.c = value.unshare(self.reveal) % self.mod 574 | return self.c 575 | 576 | def _bit_lt(self, a, b_bits): 577 | self.u = (a < self.r1) * self.scale 578 | 579 | ua = randint(0,self.rand_mod) * self.scale 580 | ub = randint(0,self.rand_mod) * self.scale 581 | uc = (-(ua+ub)) % self.mod 582 | 583 | self.u1 = Share(ua,uc-self.u,mod=self.mod,fp_prec=self.fpp) 584 | self.u2 = Share(ub,ua-self.u,mod=self.mod,fp_prec=self.fpp) 585 | self.u3 = Share(uc,ub-self.u,mod=self.mod,fp_prec=self.fpp) 586 | 587 | return self.u1 588 | 589 | def get_c(self): 590 | return self.c 591 | 592 | def get_u1(self): 593 | return self.u1 594 | 595 | def get_u2(self): 596 | return self.u2 597 | 598 | @pytest.mark.parametrize( 599 | "value,k,fpp,mod_exp,mod,expected", 600 | [(14,2,0,5,11003,14),(20,2,0,4,11003,4), 601 | (104,3,1,6,11003,16),(3,14,3,8,15485867,184), 602 | (6,10,5,1,15458567,0),(2,11,1,3,11003,4)] 603 | ) 604 | def test_mod2m(value,k,fpp,mod_exp,mod,expected): 605 | scale = 10**fpp 606 | mod_scale = mod_inverse(scale,mod) 607 | rand_mod = math.floor(mod / scale) - 1 608 | value = value * scale 609 | rand_val_mod = 2**mod_exp - 1 610 | 611 | vala = randint(0,rand_mod) * scale 612 | valb = randint(0,rand_mod) * scale 613 | valc = (-(vala+valb)) % mod 614 | 615 | val_1 = Share(vala,valc - value,mod=mod,fp_prec=fpp) 616 | val_2 = Share(valb,vala - value,mod=mod,fp_prec=fpp) 617 | val_3 = Share(valc,valb - value,mod=mod,fp_prec=fpp) 618 | 619 | r2 = randint(0,rand_val_mod) 620 | r1 = randint(0,rand_val_mod) 621 | r1_bits = [] 622 | for bit in bin(r1)[2:]: 623 | r1_bits.append(int(bit)*scale) 624 | r1_bits = [0]*(mod_exp - len(r1_bits)) + r1_bits 625 | 626 | print("raw r2: " + str(r2)) 627 | print("raw r1: " + str(r1)) 628 | print("raw r1_bits: " + str(r1_bits)) 629 | 630 | r2_r1_shares1 = [] 631 | r2_r1_shares2 = [] 632 | r2_r1_shares3 = [] 633 | 634 | r2a = randint(0,rand_mod) * scale 635 | r2b = randint(0,rand_mod) * scale 636 | r2c = (-(r2a + r2b)) % mod 637 | r1a = randint(0,rand_mod) * scale 638 | r1b = randint(0,rand_mod) * scale 639 | r1c = (-(r1a + r1b)) % mod 640 | 641 | r2_share1 = Share(r2a,r2c-r2,mod=mod,fp_prec=fpp) 642 | r2_share2 = Share(r2b,r2a-r2,mod=mod,fp_prec=fpp) 643 | r2_share3 = Share(r2c,r2b-r2,mod=mod,fp_prec=fpp) 644 | 645 | r1_share1 = Share(r1a,r1c-r1,mod=mod,fp_prec=fpp) 646 | r1_share2 = Share(r1b,r1a-r1,mod=mod,fp_prec=fpp) 647 | r1_share3 = Share(r1c,r1b-r1,mod=mod,fp_prec=fpp) 648 | 649 | r2_r1_shares1.append(r2_share1) 650 | r2_r1_shares1.append(r1_share1) 651 | r2_r1_shares2.append(r2_share2) 652 | r2_r1_shares2.append(r1_share2) 653 | r2_r1_shares3.append(r2_share3) 654 | r2_r1_shares3.append(r1_share3) 655 | 656 | for i in range(len(r1_bits)): 657 | a = randint(0,rand_mod) * scale 658 | b = randint(0,rand_mod) * scale 659 | c = (-(a+b)) % mod 660 | 661 | shr1 = Share(a,c-r1_bits[i],mod=mod,fp_prec=fpp) 662 | shr2 = Share(b,a-r1_bits[i],mod=mod,fp_prec=fpp) 663 | shr3 = Share(c,b-r1_bits[i],mod=mod,fp_prec=fpp) 664 | 665 | r2_r1_shares1.append(shr1) 666 | r2_r1_shares2.append(shr2) 667 | r2_r1_shares3.append(shr3) 668 | 669 | def rand_vals(obj,index,rand): 670 | return r2_r1_shares1 671 | 672 | r2_1 = r2_r1_shares1[0] 673 | r1_1 = r2_r1_shares1[1] 674 | 675 | r2_2 = r2_r1_shares2[0] 676 | r1_2 = r2_r1_shares2[1] 677 | r1_bits_2 = r2_r1_shares2[2:] 678 | 679 | print("real r2: " + str(r2) + ", chk r2: " + str(r2_1.unshare(r2_2) % mod)) 680 | print("real r1: " + str(r1) + ", chk r1: " + str(r1_1.unshare(r1_2) % mod)) 681 | 682 | raw_c = value + (2**mod_exp)*r2 + r1 683 | raw_c_prime = raw_c % 2**mod_exp 684 | print("raw c: " + str(raw_c)) 685 | print("raw c_prime: " + str(raw_c_prime)) 686 | 687 | raw_u = int(raw_c_prime < r1) 688 | print("raw u: " + str(raw_u)) 689 | 690 | raw_a_prime = raw_c_prime - r1 + (2**mod_exp)*raw_u 691 | print("raw a_prime: " + str(raw_a_prime)) 692 | 693 | #pre2_c = val_2.const_add(2**(k-1),scaled=False) 694 | pre2_c = val_2 695 | pre2_c += r2_2.const_mult(2**mod_exp,scaled=False) 696 | pre2_c += r1_2 697 | 698 | pre1_c = val_1 699 | pre1_c += r2_1.const_mult(2**mod_exp,scaled=False) 700 | pre1_c += r1_1 701 | 702 | print("test c: " + str(pre1_c.unshare(pre2_c) % mod)) 703 | 704 | m = mock_mod2m(mod,fpp) 705 | m.add_reveal(pre2_c) 706 | m.add_raw_r1(r1) 707 | 708 | monkeypatch = MonkeyPatch() 709 | evaluator = SecureEvaluator(None,[],[],1,None,mod,fp_precision=fpp) 710 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator.get_truncate_randomness',rand_vals) 711 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._reveal',m._reveal) 712 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._bit_lt',m._bit_lt) 713 | 714 | real_a1_prime = evaluator._mod2m(val_1,k,mod_exp) 715 | 716 | c = m.get_c() 717 | print("cooked c: " + str(c)) 718 | c_prime = int(c % 2**mod_exp) 719 | print("cooked c_prime: " + str(c_prime)) 720 | u1 = m.get_u1() 721 | u2 = m.get_u2() 722 | #print("cooked u: " + str(u2)) 723 | 724 | a1_prime = r1_1.const_mult(-1,scaled=False) 725 | a2_prime = r1_2.const_mult(-1,scaled=False) 726 | print("raw intermediate a_prime: " + str(-r1 % mod)) 727 | print("cooked intermediate a_prime: " + str(a1_prime.unshare(a2_prime) % mod)) 728 | a1_prime = a1_prime.const_add(c_prime) 729 | a2_prime = a2_prime.const_add(c_prime) 730 | print("raw intermediate a_prime: " + str((-r1 + raw_c_prime) % mod)) 731 | print("cooked intermediate a_prime: " + str(a1_prime.unshare(a2_prime) % mod)) 732 | 733 | print("raw u * 2**mod_exp: " + str(raw_u * (2**mod_exp))) 734 | print("cooked u * 2**mod_exp: " + str(u1.const_mult(2**mod_exp).unshare(u2.const_mult(2**mod_exp)))) 735 | 736 | a1_prime += u1.const_mult(2**mod_exp) 737 | a2_prime += u2.const_mult(2**mod_exp) 738 | print("raw final a_prime: " + str((-r1 + raw_c_prime + raw_u*(2**mod_exp)) % mod)) 739 | print("cooked final a_prime: " + str(a1_prime.unshare(a2_prime) % mod)) 740 | 741 | assert (real_a1_prime.unshare(a2_prime) % mod) == expected 742 | 743 | class mock_truncate: 744 | 745 | def __init__(self,mod,fpp): 746 | self.mod = mod 747 | self.fpp = fpp 748 | self.scale = 10**fpp 749 | self.mod_scale = mod_inverse(self.scale,self.mod) 750 | self.rand_mod = math.floor(mod / self.scale) - 1 751 | 752 | def add_value(self, value): 753 | self.value = value 754 | 755 | def _mod2m(self, val, k, m): 756 | self.raw_mod2m = self.value % 2**m 757 | 758 | a = randint(0,self.rand_mod) * self.scale 759 | b = randint(0,self.rand_mod) * self.scale 760 | c = (-(a+b)) % self.mod 761 | 762 | self.shr1 = Share(a,c - self.raw_mod2m,mod=self.mod,fp_prec=self.fpp) 763 | self.shr2 = Share(b,a - self.raw_mod2m,mod=self.mod,fp_prec=self.fpp) 764 | self.shr3 = Share(c,b - self.raw_mod2m,mod=self.mod,fp_prec=self.fpp) 765 | 766 | return self.shr1 767 | 768 | def get_shr1(self): 769 | return self.shr1 770 | 771 | def get_shr2(self): 772 | return self.shr2 773 | 774 | @pytest.mark.parametrize( 775 | "value,k,fpp,trunc_exp,mod,expected", 776 | [(3,0,0,1,43,1),(15,0,0,3,43,1),(15,0,0,2,43,3), 777 | (12,0,0,2,43,3),(401,0,0,2,11003,100), 778 | (7,0,3,6,11003,109),(12,0,2,5,11003,37)] 779 | ) 780 | def test_truncate(value,k,fpp,trunc_exp,mod,expected): 781 | scale = 10**fpp 782 | mod_scale = mod_inverse(scale,mod) 783 | rand_mod = math.floor(mod / scale) - 1 784 | value = value * scale 785 | trunc = 2**trunc_exp 786 | 787 | raw_a_prime = value % trunc 788 | print("raw a_prime: " + str(raw_a_prime)) 789 | 790 | raw_d = ((value - raw_a_prime) * mod_inverse(trunc,mod)) % mod 791 | print("raw d: " + str(raw_d)) 792 | print("value: " + str(value)) 793 | print("trunc: " + str(trunc)) 794 | print("expected: " + str(expected)) 795 | 796 | a = randint(0,rand_mod) * scale 797 | b = randint(0,rand_mod) * scale 798 | c = (-(a+b)) % mod 799 | 800 | val1 = Share(a,c-value,mod=mod,fp_prec=fpp) 801 | val2 = Share(b,a-value,mod=mod,fp_prec=fpp) 802 | val3 = Share(c,b-value,mod=mod,fp_prec=fpp) 803 | 804 | m = mock_truncate(mod,fpp) 805 | m.add_value(value) 806 | 807 | monkeypatch = MonkeyPatch() 808 | evaluator = SecureEvaluator(None,[],[],1,None,mod,fp_precision=fpp) 809 | monkeypatch.setattr('src.circuits.evaluator.SecureEvaluator._mod2m',m._mod2m) 810 | 811 | d1_val = evaluator._truncate(val1,k,trunc_exp) 812 | 813 | a2_prime = m.get_shr2() 814 | d2 = val2 + a2_prime.const_mult(-1,scaled=False) 815 | d2 = d2.const_mult(mod_inverse(trunc,mod),scaled=False) 816 | 817 | assert (d1_val.unshare(d2) % mod) == expected 818 | -------------------------------------------------------------------------------- /tests/src/circuits/test_share.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from random import randint 3 | 4 | from src.circuits.share import Share 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "share,const,mod,expected", 9 | [(2,2,13,4),(1,2,11,3),(14,15,43,29), 10 | (21,6,23,4),(100,200,5,0),(14,-6,43,8), 11 | (51,-53,11003,11001)] 12 | ) 13 | def test_const_add(share,const,mod,expected): 14 | a = randint(0,mod-1) 15 | b = randint(0,mod-1) 16 | c = (-(a+b)) % mod 17 | 18 | share1 = Share(a, c - share, mod=mod, fp_prec=0) 19 | share2 = Share(b, a - share, mod=mod, fp_prec=0) 20 | 21 | share1_add = share1.const_add(const) 22 | share2_add = share2.const_add(const) 23 | 24 | assert expected == (share1_add.unshare(share2_add) % mod) 25 | 26 | @pytest.mark.parametrize( 27 | "share,const,mod,fpp,expected", 28 | [(2,2,11003,2,4),(1,2,11003,1,3),(14,15,43,0,29), 29 | (21,6,23,0,4),(100,200,5,0,0),(14,-6,43,0,8), 30 | (51,-53,11003,1,11001)] 31 | ) 32 | def test_const_add_scale(share,const,mod,fpp,expected): 33 | scale = 10**fpp 34 | share = share * scale 35 | 36 | a = randint(0,mod-1) 37 | b = randint(0,mod-1) 38 | c = (-(a+b)) % mod 39 | 40 | share1 = Share(a, c - share, mod=mod, fp_prec=fpp) 41 | share2 = Share(b, a - share, mod=mod, fp_prec=fpp) 42 | 43 | share1_add = share1.const_add(const,scaled=False) 44 | share2_add = share2.const_add(const,scaled=False) 45 | 46 | assert expected * scale == (share1_add.unshare(share2_add) % mod) 47 | 48 | @pytest.mark.parametrize( 49 | "share,const,mod,expected", 50 | [(2,2,13,4),(1,2,11,2),(3,3,17,9),(14,2,5,3)] 51 | ) 52 | def test_const_mult(share,const,mod,expected): 53 | 54 | a = randint(0,mod-1) 55 | b = randint(0,mod-1) 56 | c = (-(a+b)) % mod 57 | 58 | share1 = Share(a, c - share, mod=mod, fp_prec=0) 59 | share2 = Share(b, a - share, mod=mod, fp_prec=0) 60 | 61 | share1_mult = share1.const_mult(const) 62 | share2_mult = share2.const_mult(const) 63 | 64 | assert expected == (share1_mult.unshare(share2_mult) % mod) 65 | 66 | @pytest.mark.parametrize( 67 | "share,old_prec,new_prec,mod,expected", 68 | [(20,1,0,43,2),(300,2,1,1009,30),(1400,2,1,11003,140),(20,1,3,11003,2000)] 69 | ) 70 | def test_switch_precision(share,old_prec,new_prec,mod,expected): 71 | a = randint(0,mod-1) 72 | b = randint(0,mod-1) 73 | c = (-(a+b)) % mod 74 | 75 | share1 = Share(a, c - share, mod=mod, fp_prec=old_prec) 76 | share2 = Share(b, a - share, mod=mod, fp_prec=old_prec) 77 | 78 | share1_new = share1.switch_precision(new_prec) 79 | share2_new = share2.switch_precision(new_prec) 80 | 81 | assert expected == (share1_new.unshare(share2_new) % mod) 82 | 83 | @pytest.mark.parametrize( 84 | "share1,share2,mod,fpp,expected", 85 | [((1,1),(1,1),11,0,True),((1,2),(12,13),11,0,True),((1,5),(1,4),11,0,False)] 86 | ) 87 | def test_eq(share1,share2,mod,fpp,expected): 88 | shr1 = Share(share1[0],share1[1],mod=mod,fp_prec=fpp) 89 | shr2 = Share(share2[0],share2[1],mod=mod,fp_prec=fpp) 90 | 91 | assert (shr1 == shr2) == expected -------------------------------------------------------------------------------- /tests/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trailofbits/mpc-learning/7fa64451f9d6d5a4bb5fe2465762f2734952fc4c/tests/util/__init__.py -------------------------------------------------------------------------------- /tests/util/test_mod.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import src.util.mod as mod 4 | 5 | @pytest.mark.parametrize( 6 | "int1,int2,expected", 7 | [(1,5,1), (2,5,3), (3,17,6), (2,21,11)] 8 | ) 9 | def test_mod_inverse(int1,int2,expected): 10 | assert mod.mod_inverse(int1,int2) == expected --------------------------------------------------------------------------------