├── CONTRIBUTING
├── LICENSE
├── README
├── check_solution_milp.py
├── check_solution_svd.py
├── extract.py
├── models
│   └── .gitkeep
├── src
│   ├── find_witnesses.py
│   ├── global_vars.py
│   ├── hyperplane_normal.py
│   ├── layer_recovery.py
│   ├── refine_precision.py
│   ├── sign_recovery.py
│   └── utils.py
└── train_models.py

/CONTRIBUTING:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
CRYPTANALYTIC EXTRACTION OF NEURAL NETWORK MODELS

This repository contains an implementation of the model extraction attack in our CRYPTO'20 paper:

Cryptanalytic Extraction of Neural Network Models
https://arxiv.org/abs/2003.04884
Nicholas Carlini, Matthew Jagielski, Ilya Mironov


INSTALLING

To get started, you will need to install some dependencies. It should suffice to run

> pip install numpy scipy jax jaxlib matplotlib networkx

Sometimes JAX (or, more precisely, XLA) puts up a fight during the install,
but if the above works, then everything should run properly.
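If you want to confirm the install before going further, a small import check can help. This is a sketch that is not part of the repository; the helper name missing_dependencies is made up for illustration. It catches any exception rather than only ImportError, on the assumption that a mismatched jax/jaxlib pair may fail with a different error type.

```python
import importlib

# Import names for the packages in the pip command above
# (all six match their pip package names).
REQUIRED_MODULES = ["numpy", "scipy", "jax", "jaxlib", "matplotlib", "networkx"]


def missing_dependencies(modules=REQUIRED_MODULES):
    """Return (name, error) pairs for every module that fails to import."""
    failures = []
    for name in modules:
        try:
            importlib.import_module(name)
        except Exception as err:  # ImportError, or e.g. a jax/jaxlib mismatch
            failures.append((name, repr(err)))
    return failures


if __name__ == "__main__":
    problems = missing_dependencies()
    for name, err in problems:
        print("cannot import", name, "->", err)
    if not problems:
        print("all dependencies import cleanly")
```

An empty result only means the packages import; it does not exercise XLA compilation, so the example run below is still the real test.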


EXTRACTING EXAMPLE MODELS

First, generate a model that we will extract by running

> python3 train_models.py 10-15-15-1 42

and then extract it with

> python3 extract.py 10-15-15-1 42

This should extract quickly. Then check the quality of the extraction with

> python3 check_solution_svd.py 10-15-15-1

or, if you have a MILP solver installed, you can run

> python3 check_solution_milp.py 10-15-15-1

and then run the solver on /tmp/test.mod.

By default, the code is set up so that it won't cheat and look at the weights of the
actual neural network we're extracting (and will throw ugly errors if we try).
Some logging looks better if we're allowed to cheat, though (e.g., to catch errors
earlier in the process).

To enable this, set CHEATING=True in src/global_vars.py.


EXTRACTING YOUR OWN MODELS

The code can currently extract only fully-connected neural networks.

To extract a model, save it as a numpy array in the format [weights, biases]. For
example, a 20-10-1 network could be saved to models/UID_20-10-1.npy as

[[np.random.normal(size=(20,10)), np.random.normal(size=(10, 1))], [np.zeros((10,)), np.zeros((1,))]]

and then run

> python3 extract.py 20-10-1 UID


CITING THIS WORK

If you find this code useful, you can cite

@inproceedings{carlini2020cryptanalytic,
  title={Cryptanalytic Extraction of Neural Network Models},
  author={Carlini, Nicholas and Jagielski, Matthew and Mironov, Ilya},
  booktitle={Annual International Cryptology Conference},
  year={2020}
}

--------------------------------------------------------------------------------
/check_solution_milp.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
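
# How dump_net (below) encodes each network as a mixed-integer program, as
# read from the constraints it emits. For a neuron with pre-activation
#
#     z = sum_k A[i][k, j] * prev_k + B[i][j]
#
# it creates x >= 0 (the neuron's output), s >= 0 (a slack) and a binary r:
#
#     z = x - s,    x <= 1000 * r,    s <= 1000 * (1 - r)
#
# If r = 1 then s = 0, so x = z and z >= 0; if r = 0 then x = 0 and s = -z >= 0.
# Either way x = max(z, 0). The constant 1000 is a big-M bound: inputs that
# would need |z| > 1000 at some neuron violate the constraints and are
# effectively excluded from the search.
#
# The single output neuron has no ReLU: its two big-M constraints are dropped
# (computation[:-2]), its slack is fixed to 0 and its x is declared free, so
# x = z. Both networks share one input vector, each coordinate in [0, 1],
# and the objective maximizes (real output - extracted output); only that
# direction of the difference is searched.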

import numpy as np
import sys
import pickle

np.random.seed(0)

NN = 0

QQ = 0
def n():
    # Fresh integer, used to give every MILP constraint a unique name.
    global QQ
    QQ += 1
    return QQ

var = 0
allvars = []

def nvar(p='x'):
    # Fresh variable name with prefix p: 'x' (neuron output), 's' (slack),
    # 'r' (binary ReLU indicator) or 'i' (network input).
    global var
    res = p+str(var)
    allvars.append(res)
    var += 1
    return res

def dump_net(previous_outputs):
    # Emit declarations and constraints for the currently loaded network (A, B).
    global NN
    computation = []

    # Start a fresh variable list for this network.
    while len(allvars):
        allvars.pop()

    for i in range(len(A)):
        next_outputs = []
        for out_neuron in range(A[i].shape[1]):
            x_var = nvar()
            s_var = nvar('s')
            r_var = nvar('r')
            next_outputs.append(x_var)
            combination = " + ".join("%.17f * %s"%(a,b) for a,b in zip(A[i][:,out_neuron],previous_outputs))
            computation.append("s.t. ok_%d: "%n() + combination + " + " + str(B[i][out_neuron]) + " = " + x_var + " - " + s_var + ";")
            computation.append("s.t. relu_%d: %s <= 1000 * %s;" % (n(), x_var, r_var))
            computation.append("s.t. relu_%d: %s <= 1000 * (1 - %s);" % (n(), s_var, r_var))

        previous_outputs = next_outputs

    finalx = [x for x in allvars if x[0] == 'x'][-1]

    prefix = []

    for v in allvars:
        if v == finalx:
            prefix.append('var '+v+';')
        elif v[0] == 'x':
            prefix.append('var '+v+' >= 0;')
        elif v[0] == 's':
            prefix.append('var '+v+' >= 0;')
        elif v[0] == 'i':
            prefix.append('var '+v+' >= -1;')
        elif v[0] == 'r':
            prefix.append('var '+v+' binary;')
        else:
            raise ValueError("Unexpected variable name: %s" % v)

    NN += 1
    return prefix, computation[:-2] + ["s.t. final%d: %s = 0;"%(NN,[x for x in allvars if x[0] == 's'][-1])], finalx



name = sys.argv[1]

A, B = pickle.load(open("/tmp/real-%s.p"%name,"rb"))

sizes = [x.shape[0] for x in A] + [1]

inputs = [nvar('i') for _ in range(sizes[0])]

prefix1, rest1, outvar1 = dump_net(inputs)

A, B = pickle.load(open("/tmp/extracted-%s.p"%name,"rb"))

prefix2, rest2, outvar2 = dump_net(inputs)



# Everything printed from here on is written to the model file.
sys.stdout=open("/tmp/test.mod","w")

print("\n".join('var '+v+' >= 0;' for v in inputs))

print("\n".join(prefix1))
print("\n\n")
print("\n".join(prefix2))

print("var slack;")
print("var which binary;")
print("maximize obj: %s-%s;"%(outvar1,outvar2))


print("\n".join(rest1))
print("\n".join(rest2))

for v in inputs:
    if v[0] == 'i':
        print("s.t. bounded%s: %s <= 1;"%(v,v))

print("solve;")
print('display %s;'%(", ".join(x for x in inputs)))
print('display %s;'%outvar1)
print('display %s;'%outvar2)
print('display slack;')

# Now it's on you. Go and run this model.
# You can export to MPS with the following command:

# glpsol --check --wfreemps /tmp/o.mps --model /tmp/test.mod

--------------------------------------------------------------------------------
/check_solution_svd.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import sys
import numpy as onp
import jax.numpy as jnp
import numpy as np
import numpy.linalg
import networkx as nx
import matplotlib.pyplot as plt
import scipy.optimize
import pickle

# Enable 64-bit floats in JAX (jax.config.update replaces the older
# `from jax.config import config` idiom, which recent JAX versions reject).
jax.config.update("jax_enable_x64", True)

def relu(x):
    return x * (x>0)

#@jax.jit
def run(x,A,B,debug=True, np=np):
    # Forward pass of a ReLU network; no ReLU after the final layer.
    for i,(a,b) in enumerate(zip(A,B)):
        x = np.dot(x,a)+b
        if i < len(A)-1:
            x = x*(x>0)
    return x


name = sys.argv[1] if len(sys.argv) > 1 else "40-20-10-10-1"

prefix = "/tmp/"

A1, B1 = pickle.load(open(prefix+"real-%s.p"%name,"rb"))
A2, B2 = pickle.load(open(prefix+"extracted-%s.p"%name,"rb"))

A1 = [np.array(x,dtype=np.float64) for x in A1]
A2 = [np.array(x,dtype=np.float64) for x in A2]

B1 = [np.array(x,dtype=np.float64) for x in B1]
B2 = [np.array(x,dtype=np.float64) for x in B2]


print("Compute the matrix alignment for the SVD upper bound")
# Match each real neuron i to the extracted neuron j whose weight column is
# closest to a scalar multiple of it (lowest std of the elementwise ratio),
# then undo that positive scale and the permutation.
for layer in range(len(A1)-1):
    M_real = np.copy(A1[layer])
    M_fake = np.copy(A2[layer])

    scores = []

    for i in range(M_real.shape[1]):
        vec = M_real[:,i:i+1]
        ratio = np.abs(M_fake/vec)

        scores.append(np.std(A2[layer]/vec,axis=0))


    i_s, j_s = scipy.optimize.linear_sum_assignment(scores)

    for i,j in zip(i_s, j_s):
        vec = M_real[:,i:i+1]
        ratio = np.abs(M_fake/vec)

        ratio = np.median(ratio[:,j])
        #print("Map from", i, j, ratio)

        gap = np.abs(M_fake[:,j]/ratio - M_real[:,i])

        A2[layer][:,j] /= ratio
        B2[layer][j] /= ratio
        A2[layer+1][j,:] *= ratio

    A2[layer] = A2[layer][:,j_s]
    B2[layer] = B2[layer][j_s]

    A2[layer+1] = A2[layer+1][j_s,:]

A2[1] *= np.sign(A2[1][0])
A2[1] *= np.sign(A1[1][0])

B2[1] *= np.sign(B2[1])
B2[1] *= np.sign(B1[1])

print("Finished alignment. Now compute the max error in the matrix.")
max_err = 0
for l in range(len(A1)):
    print("Matrix diff", np.sum(np.abs(A1[l]-A2[l])))
    print("Bias diff", np.sum(np.abs(B1[l]-B2[l])))
    max_err = max(max_err, np.max(np.abs(A1[l]-A2[l])))
    max_err = max(max_err, np.max(np.abs(B1[l]-B2[l])))

print("Number of bits of precision in the weight matrix",
      -np.log(max_err)/np.log(2))

print("\nComputing SVD upper bound")
high = np.ones(A1[0].shape[0])
low = -np.ones(A1[0].shape[0])
input_bound = np.sum((high-low)**2)**.5
prev_bound = 0
for i in range(len(A1)):
    largest_value = np.linalg.svd(A1[i]-A2[i])[1][0] * input_bound
    largest_value += np.linalg.svd(A1[i])[1][0] * prev_bound
    largest_value += np.sum((B1[i]-B2[i])**2)**.5
    prev_bound = largest_value
    print("\tAt layer", i, "loss is bounded by", largest_value)

print('Upper bound on number of bits of precision in the output through SVD', -np.log(largest_value)/np.log(2))

print("\nFinally estimate it through random samples to make sure we haven't made a mistake") # not that that would ever happen
def loss(x):
    return np.abs(run(x, A=A1, B=B1)-run(x, A=A2, B=B2))

ls = []
for _ in range(100):
    if _%10 == 0:
        print("Iter %d/100"%_)
    # Random directions, rescaled to norm 2*sqrt(d) (the same radius as input_bound).
    inp = onp.random.normal(0, 1, (int(1000000/A1[0].shape[0]), A1[0].shape[0]))
    inp /= np.sum(inp**2,axis=1,keepdims=True)**.5
    inp *= np.sum((np.ones(A1[0].shape[0])*2)**2)**.5
    ell = loss(inp).flatten()
    ls.extend(ell)

ls = onp.array(ls).flatten()

print("Fewest number of bits of precision over", len(ls), "random samples:", -np.log(np.max(ls))/np.log(2))

# Finally, plot the distribution of the errors
plt.hist(ls,30)
plt.semilogy()
plt.savefig("/tmp/a.pdf")
sys.exit(0)

--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import random
import traceback
import time
import numpy.linalg
import pickle
import multiprocessing as mp
import os
import signal

import numpy as np

from src.utils import matmul, KnownT, check_quality, SAVED_QUERIES, run
from src.find_witnesses import sweep_for_critical_points
import src.refine_precision as refine_precision
import src.layer_recovery as layer_recovery
import src.sign_recovery as sign_recovery
from src.global_vars import *

#####################################################################
## MAIN FUNCTION. This is where it all happens.                    ##
#####################################################################

def run_full_attack():
    global query_count, SAVED_QUERIES

    extracted_normals = []
    extracted_biases = []

    known_T = KnownT(extracted_normals, extracted_biases)

    for layer_num in range(0,len(A)-1):
        # For each layer of the network ...

        # First, set up the critical-points generator
        critical_points = sweep_for_critical_points(PARAM_SEARCH_AT_LOCATION, known_T)

        # Extract weights corresponding to those critical points
        extracted_normal, extracted_bias, mask = layer_recovery.compute_layer_values(critical_points,
                                                                                     known_T,
                                                                                     layer_num)

        # Report how well we're doing
        check_quality(layer_num, extracted_normal, extracted_bias)

        # Now, make them more precise
        extracted_normal, extracted_bias = refine_precision.improve_layer_precision(layer_num,
                                                                                    known_T, extracted_normal, extracted_bias)
        print("Query count", query_count)

        # And print how well we're doing
        check_quality(layer_num, extracted_normal, extracted_bias)

        # New generator
        critical_points = sweep_for_critical_points(1e1)

        # Solve for signs
        if layer_num == 0 and sizes[1] <= sizes[0]:
            extracted_sign = sign_recovery.solve_contractive_sign(known_T, extracted_normal, extracted_bias, layer_num)
        elif layer_num > 0 and sizes[1] <= sizes[0] and all(sizes[x+1] <= sizes[x]/2 for x in range(1,len(sizes)-1)):
            try:
                extracted_sign = sign_recovery.solve_contractive_sign(known_T, extracted_normal, extracted_bias, layer_num)
            except AcceptableFailure as e:
                print("Contractive solving failed; fall back to noncontractive method")
                if layer_num == len(A)-2:
                    print("Solve final two")
                    break

                extracted_sign, _ = sign_recovery.solve_layer_sign(known_T, extracted_normal, extracted_bias, critical_points,
                                                                   layer_num,
                                                                   l1_mask=np.int32(np.sign(mask)))

        else:
            if layer_num == len(A)-2:
                print("Solve final two")
                break

            extracted_sign, _ = sign_recovery.solve_layer_sign(known_T, extracted_normal, extracted_bias, critical_points,
                                                               layer_num,
                                                               l1_mask=np.int32(np.sign(mask)))

        print("Extracted", extracted_sign)
        print('real sign', np.int32(np.sign(mask)))

        print("Total query count", query_count)

        # Correct signs
        extracted_normal *= extracted_sign
        extracted_bias *= extracted_sign
        extracted_bias = np.array(extracted_bias, dtype=np.float64)

        # Report how we're doing
        extracted_normal, extracted_bias = check_quality(layer_num, extracted_normal, extracted_bias, do_fix=True)

        extracted_normals.append(extracted_normal)
        extracted_biases.append(extracted_bias)

        known_T = KnownT(extracted_normals, extracted_biases)

        for a,b in sorted(query_count_at.items(),key=lambda x: -x[1]):
            print('count', b, '\t', 'line:', a, ':', self_lines[a-1].strip())

    # And then finish up
    if len(extracted_normals) == len(sizes)-2:
        print("Just solve final layer")
        N = int(len(SAVED_QUERIES)/1000) or 1
        ins, outs = zip(*SAVED_QUERIES[::N])
        solve_final_layer(known_T, np.array(ins), np.array(outs))
    else:
        print("Solve final two")
        solve_final_two_layers(known_T, extracted_normal, extracted_bias)


def solve_final_two_layers(known_T, known_A0, known_B0):
    ## Recover the final two layers through brute forcing signs, then least squares
    ## Yes, this is mostly a copy of solve_layer_sign. I am repeating myself. Sorry.
    LAYER = len(sizes)-2
    filtered_inputs = []
    filtered_outputs = []

    # How many unique points to use. This seems to work. Tweak if needed...
    # (In checking consistency of the final layer signs)
    N = int(len(SAVED_QUERIES)/100) or 1
    ins, outs = zip(*SAVED_QUERIES[::N])
    filtered_inputs, filtered_outputs = zip(*SAVED_QUERIES[::N])
    print('Total query count', len(SAVED_QUERIES))
    print("Solving on", len(filtered_inputs))

    inputs, outputs = np.array(filtered_inputs), np.array(filtered_outputs)
    known_hidden_so_far = known_T.forward(inputs, with_relu=True)

    K = sizes[LAYER]
    print("K IS", K)
    shuf = list(range(1<