├── CONTRIBUTING
├── LICENSE
├── README
├── check_solution_milp.py
├── check_solution_svd.py
├── extract.py
├── models
│   └── .gitkeep
├── src
│   ├── find_witnesses.py
│   ├── global_vars.py
│   ├── hyperplane_normal.py
│   ├── layer_recovery.py
│   ├── refine_precision.py
│   ├── sign_recovery.py
│   └── utils.py
└── train_models.py
/CONTRIBUTING:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | CRYPTANALYTIC EXTRACTION OF NEURAL NETWORK MODELS
2 |
3 | This repository contains an implementation of the model extraction attack in our CRYPTO'20 paper
4 |
5 | Cryptanalytic Extraction of Neural Network Models
6 | https://arxiv.org/abs/2003.04884
7 | Nicholas Carlini, Matthew Jagielski, Ilya Mironov
8 |
9 |
10 | INSTALLING
11 |
12 | To get started you will need to install some dependencies. It should suffice to run
13 |
14 | > pip install numpy scipy jax jaxlib matplotlib networkx
15 |
16 | Sometimes JAX (or, more correctly, XLA) puts up a fight during the install,
17 | but if the above works then everything should run properly.
18 |
19 |
20 | EXTRACTING EXAMPLE MODELS
21 |
22 | First, generate a model that we will extract by running
23 |
24 | > python3 train_models.py 10-15-15-1 42
25 |
26 | and then extract it with
27 |
28 | > python3 extract.py 10-15-15-1 42
29 |
30 | The extraction should be quick. Then check the quality of the extraction with
31 |
32 | > python3 check_solution_svd.py 10-15-15-1
33 |
34 | or, if you have a MILP solver installed, you can run
35 |
36 | > python3 check_solution_milp.py 10-15-15-1
37 |
38 | and then run the solver on /tmp/test.mod.
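One way to do that, assuming GLPK's glpsol is on your PATH (this mirrors the
command in the comment at the bottom of check_solution_milp.py):

> glpsol --model /tmp/test.mod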
39 |
40 | By default, the code is set up so that it won't cheat and look at the weights of the
41 | actual neural network we're extracting (and will throw ugly errors if we try).
42 | Some logging looks better if we're allowed to cheat though (e.g., to catch errors
43 | earlier in the process).
44 |
45 | To enable this, set CHEATING=True in src/global_vars.py.
46 |
47 |
48 | EXTRACTING YOUR OWN MODELS
49 |
50 | The code can currently extract only fully-connected neural networks.
51 |
52 | To extract a model, save it as a numpy array in the format [weights, biases]. For
53 | example, a 20-10-1 network could be saved to models/UID_20-10-1.npy as
54 | [[np.random.normal(size=(20,10)), np.random.normal(size=(10, 1))], [np.zeros((10,)), np.zeros((1,))]]
55 |
56 | and then run
57 |
58 | > python extract.py 20-10-1 UID
59 |
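As a concrete sketch (the weight values below are random placeholders, and
note that src/global_vars.py parses the second argument as an integer seed,
so the UID should be numeric):

  import numpy as np
  weights = [np.random.normal(size=(20, 10)), np.random.normal(size=(10, 1))]
  biases = [np.zeros((10,)), np.zeros((1,))]
  np.save("models/UID_20-10-1.npy", np.array([weights, biases], dtype=object))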
60 |
61 | CITING THIS WORK
62 |
63 | If you find this code useful, you can cite
64 |
65 | @inproceedings{carlini2020cryptanalytic,
66 | title={Cryptanalytic Extraction of Neural Network Models},
67 | author={Carlini, Nicholas and Jagielski, Matthew and Mironov, Ilya},
68 | booktitle={Annual International Cryptology Conference},
69 | year={2020}
70 | }
71 |
--------------------------------------------------------------------------------
/check_solution_milp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import sys
17 | import pickle
18 |
19 | np.random.seed(0)
20 |
21 | NN = 0
22 |
23 | QQ = 0
24 | def n():
25 | global QQ
26 | QQ += 1
27 | return QQ
28 |
29 | var = 0
30 | allvars = []
31 |
32 | def nvar(p='x'):
33 | global var
34 | res = p+str(var)
35 | allvars.append(res)
36 | var += 1
37 | return res
38 |
39 | def dump_net(previous_outputs):
40 | global NN
41 | computation = []
42 |
43 | while len(allvars):
44 | allvars.pop()
45 |
46 | for i in range(len(A)):
47 | next_outputs = []
48 | for out_neuron in range(A[i].shape[1]):
49 | x_var = nvar()
50 | s_var = nvar('s')
51 | r_var = nvar('r')
52 | next_outputs.append(x_var)
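# Big-M encoding of x_var = relu(pre-activation): the equality constraint
# below writes pre = x_var - s_var with x_var, s_var >= 0, and the binary
# r_var forces at most one of x_var / s_var to be nonzero (1000 is the
# big-M constant, assuming no activation ever exceeds it).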
53 | combination = " + ".join("%.17f * %s"%(a,b) for a,b in zip(A[i][:,out_neuron],previous_outputs))
54 | computation.append("s.t. ok_%d: "%n() + combination + " + " + str(B[i][out_neuron]) + " = " + x_var + " - " + s_var + ";")
55 | computation.append("s.t. relu_%d: %s <= 1000 * %s;" % (n(), x_var, r_var))
56 | computation.append("s.t. relu_%d: %s <= 1000 * (1 - %s);" % (n(), s_var, r_var))
57 |
58 | previous_outputs = next_outputs
59 |
60 | finalx = [x for x in allvars if x[0] == 'x'][-1]
61 |
62 | prefix = []
63 |
64 | for v in allvars:
65 | if v == finalx:
66 | prefix.append('var '+v+';')
67 | elif v[0] == 'x':
68 | prefix.append('var '+v+' >= 0;')
69 | elif v[0] == 's':
70 | prefix.append('var '+v+' >= 0;')
71 | elif v[0] == 'i':
72 | prefix.append('var '+v+' >= -1;')
73 | elif v[0] == 'r':
74 | prefix.append('var '+v+' binary;')
75 | else:
76 | raise ValueError("unknown variable prefix: " + v)
77 |
78 | NN += 1
79 | return prefix, computation[:-2] + ["s.t. final%d: %s = 0;"%(NN,[x for x in allvars if x[0] == 's'][-1])], finalx
80 |
81 |
82 |
83 | name = sys.argv[1]
84 |
85 | A, B = pickle.load(open("/tmp/real-%s.p"%name,"rb"))
86 |
87 | sizes = [x.shape[0] for x in A] + [1]
88 |
89 | inputs = [nvar('i') for _ in range(sizes[0])]
90 |
91 | prefix1, rest1, outvar1 = dump_net(inputs)
92 |
93 | A, B = pickle.load(open("/tmp/extracted-%s.p"%name,"rb"))
94 |
95 | prefix2, rest2, outvar2 = dump_net(inputs)
96 |
97 |
98 |
99 |
100 | # Redirect stdout so everything printed below lands in the model file.
101 | sys.stdout = open("/tmp/test.mod","w")
102 |
103 | print("\n".join('var '+v+' >= 0;' for v in inputs))
104 |
105 | print("\n".join(prefix1))
106 | print("\n\n")
107 | print("\n".join(prefix2))
108 |
109 | print("var slack;")
110 | print("var which binary;")
111 | print("maximize obj: %s-%s;"%(outvar1,outvar2))
112 |
113 |
114 | print("\n".join(rest1))
115 | print("\n".join(rest2))
116 |
117 | for v in inputs:
118 | if v[0] == 'i':
119 | print("s.t. bounded%s: %s <= 1;"%(v,v))
120 |
121 | print("solve;")
122 | print('display %s;'%(", ".join(x for x in inputs)))
123 | print('display %s;'%outvar1)
124 | print('display %s;'%outvar2)
125 | print('display slack;')
126 |
127 | # Now it's on you. Go and run this model.
128 | # You can export it to MPS format with the following command:
129 |
130 | # glpsol --check --wfreemps /tmp/o.mps --model /tmp/test.mod
131 |
--------------------------------------------------------------------------------
/check_solution_svd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import jax
16 | import sys
17 | import numpy as onp
18 | import jax.numpy as jnp
19 | import numpy as np
20 | import numpy.linalg
21 | import networkx as nx
22 | import matplotlib.pyplot as plt
23 | import scipy.optimize
24 | import pickle
25 |
26 | from jax.config import config
27 | config.update("jax_enable_x64", True)
28 |
29 | def relu(x):
30 | return x * (x>0)
31 |
32 | #@jax.jit
33 | def run(x,A,B,debug=True, np=np):
34 | for i,(a,b) in enumerate(zip(A,B)):
35 | x = np.dot(x,a)+b
36 | if i < len(A)-1:
37 | x = x*(x>0)
38 | return x
39 |
40 |
41 | name = sys.argv[1] if len(sys.argv) > 1 else "40-20-10-10-1"
42 |
43 | prefix = "/tmp/"
44 |
45 | A1, B1 = pickle.load(open(prefix+"real-%s.p"%name,"rb"))
46 | A2, B2 = pickle.load(open(prefix+"extracted-%s.p"%name,"rb"))
47 |
48 | A1 = [np.array(x,dtype=np.float64) for x in A1]
49 | A2 = [np.array(x,dtype=np.float64) for x in A2]
50 |
51 | B1 = [np.array(x,dtype=np.float64) for x in B1]
52 | B2 = [np.array(x,dtype=np.float64) for x in B2]
53 |
54 |
55 | print("Compute the matrix alignment for the SVD upper bound")
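# The extracted network can only match the true one up to a positive
# rescaling of each hidden neuron and a permutation of neurons within a
# layer. The loop below undoes both: linear_sum_assignment pairs up the
# neurons, and each neuron's scale is divided out of this layer and
# multiplied into the next, leaving the computed function unchanged.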
56 | for layer in range(len(A1)-1):
57 | M_real = np.copy(A1[layer])
58 | M_fake = np.copy(A2[layer])
59 |
60 | scores = []
61 |
62 | for i in range(M_real.shape[1]):
63 | vec = M_real[:,i:i+1]
64 | ratio = np.abs(M_fake/vec)
65 |
66 | scores.append(np.std(A2[layer]/vec,axis=0))
67 |
68 |
69 | i_s, j_s = scipy.optimize.linear_sum_assignment(scores)
70 |
71 | for i,j in zip(i_s, j_s):
72 | vec = M_real[:,i:i+1]
73 | ratio = np.abs(M_fake/vec)
74 |
75 | ratio = np.median(ratio[:,j])
76 | #print("Map from", i, j, ratio)
77 |
78 | gap = np.abs(M_fake[:,j]/ratio - M_real[:,i])
79 |
80 | A2[layer][:,j] /= ratio
81 | B2[layer][j] /= ratio
82 | A2[layer+1][j,:] *= ratio
83 |
84 | A2[layer] = A2[layer][:,j_s]
85 | B2[layer] = B2[layer][j_s]
86 |
87 | A2[layer+1] = A2[layer+1][j_s,:]
88 |
89 | A2[-1] *= np.sign(A2[-1][0])
90 | A2[-1] *= np.sign(A1[-1][0])
91 |
92 | B2[-1] *= np.sign(B2[-1])
93 | B2[-1] *= np.sign(B1[-1])
94 |
95 | print("Finished alignment. Now compute the max error in the matrix.")
96 | max_err = 0
97 | for l in range(len(A1)):
98 | print("Matrix diff", np.sum(np.abs(A1[l]-A2[l])))
99 | print("Bias diff", np.sum(np.abs(B1[l]-B2[l])))
100 | max_err = max(max_err, np.max(np.abs(A1[l]-A2[l])))
101 | max_err = max(max_err, np.max(np.abs(B1[l]-B2[l])))
102 |
103 | print("Number of bits of precision in the weight matrix",
104 | -np.log(max_err)/np.log(2))
105 |
106 | print("\nComputing SVD upper bound")
107 | high = np.ones(A1[0].shape[0])
108 | low = -np.ones(A1[0].shape[0])
109 | input_bound = np.sum((high-low)**2)**.5
110 | prev_bound = 0
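# Per-layer error recursion: the bound at layer i combines (1) the top
# singular value of the weight-matrix difference times the input-norm bound,
# (2) the top singular value of the true weights times the error carried in
# from the previous layer, and (3) the norm of the bias difference.
# ReLU is 1-Lipschitz, so applying it never increases the gap.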
111 | for i in range(len(A1)):
112 | largest_value = np.linalg.svd(A1[i]-A2[i])[1][0] * input_bound
113 | largest_value += np.linalg.svd(A1[i])[1][0] * prev_bound
114 | largest_value += np.sum((B1[i]-B2[i])**2)**.5
115 | prev_bound = largest_value
116 | print("\tAt layer", i, "loss is bounded by", largest_value)
117 |
118 | print('Upper bound on number of bits of precision in the output through SVD', -np.log(largest_value)/np.log(2))
119 |
120 | print("\nFinally estimate it through random samples to make sure we haven't made a mistake") # not that that would ever happen
121 | def loss(x):
122 | return np.abs(run(x, A=A1, B=B1)-run(x, A=A2, B=B2))
123 |
124 | ls = []
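# Sample batches of points uniformly from a sphere of radius 2*sqrt(DIM)
# (normalize Gaussian draws, then rescale) and record the gap between the
# two networks at each point.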
125 | for _ in range(100):
126 | if _%10 == 0:
127 | print("Iter %d/100"%_)
128 | inp = onp.random.normal(0, 1, (int(1000000/A1[0].shape[0]), A1[0].shape[0]))
129 | inp /= np.sum(inp**2,axis=1,keepdims=True)**.5
130 | inp *= np.sum((np.ones(A1[0].shape[0])*2)**2)**.5
131 | ell = loss(inp).flatten()
132 | ls.extend(ell)
133 |
134 | ls = onp.array(ls).flatten()
135 |
136 | print("Fewest number of bits of precision over", len(ls), "random samples:", -np.log(np.max(ls))/np.log(2))
137 |
138 | # Finally plot a distribution of the values to see
139 | plt.hist(ls,30)
140 | plt.semilogy()
141 | plt.savefig("/tmp/a.pdf")
142 | exit(0)
143 |
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | import random
17 | import traceback
18 | import time
19 | import numpy.linalg
20 | import pickle
21 | import multiprocessing as mp
22 | import os
23 | import signal
24 |
25 | import numpy as np
26 |
27 | from src.utils import matmul, KnownT, check_quality, SAVED_QUERIES, run
28 | from src.find_witnesses import sweep_for_critical_points
29 | import src.refine_precision as refine_precision
30 | import src.layer_recovery as layer_recovery
31 | import src.sign_recovery as sign_recovery
32 | from src.global_vars import *
33 |
34 | #####################################################################
35 | ## MAIN FUNCTION. This is where it all happens. ##
36 | #####################################################################
37 |
38 | def run_full_attack():
39 | global query_count, SAVED_QUERIES
40 |
41 | extracted_normals = []
42 | extracted_biases = []
43 |
44 | known_T = KnownT(extracted_normals, extracted_biases)
45 |
46 | for layer_num in range(0,len(A)-1):
47 | # For each layer of the network ...
48 |
49 | # First setup the critical points generator
50 | critical_points = sweep_for_critical_points(PARAM_SEARCH_AT_LOCATION, known_T)
51 |
52 | # Extract weights corresponding to those critical points
53 | extracted_normal, extracted_bias, mask = layer_recovery.compute_layer_values(critical_points,
54 | known_T,
55 | layer_num)
56 |
57 | # Report how well we're doing
58 | check_quality(layer_num, extracted_normal, extracted_bias)
59 |
60 | # Now, make them more precise
61 | extracted_normal, extracted_bias = refine_precision.improve_layer_precision(layer_num,
62 | known_T, extracted_normal, extracted_bias)
63 | print("Query count", query_count)
64 |
65 | # And print how well we're doing
66 | check_quality(layer_num, extracted_normal, extracted_bias)
67 |
68 | # New generator
69 | critical_points = sweep_for_critical_points(1e1)
70 |
71 | # Solve for signs
72 | if layer_num == 0 and sizes[1] <= sizes[0]:
73 | extracted_sign = sign_recovery.solve_contractive_sign(known_T, extracted_normal, extracted_bias, layer_num)
74 | elif layer_num > 0 and sizes[1] <= sizes[0] and all(sizes[x+1] <= sizes[x]/2 for x in range(1,len(sizes)-1)):
75 | try:
76 | extracted_sign = sign_recovery.solve_contractive_sign(known_T, extracted_normal, extracted_bias, layer_num)
77 | except AcceptableFailure as e:
78 | print("Contractive solving failed; fall back to noncontractive method")
79 | if layer_num == len(A)-2:
80 | print("Solve final two")
81 | break
82 |
83 | extracted_sign, _ = sign_recovery.solve_layer_sign(known_T, extracted_normal, extracted_bias, critical_points,
84 | layer_num,
85 | l1_mask=np.int32(np.sign(mask)))
86 |
87 | else:
88 | if layer_num == len(A)-2:
89 | print("Solve final two")
90 | break
91 |
92 | extracted_sign, _ = sign_recovery.solve_layer_sign(known_T, extracted_normal, extracted_bias, critical_points,
93 | layer_num,
94 | l1_mask=np.int32(np.sign(mask)))
95 |
96 | print("Extracted", extracted_sign)
97 | print('real sign', np.int32(np.sign(mask)))
98 |
99 | print("Total query count", query_count)
100 |
101 | # Correct signs
102 | extracted_normal *= extracted_sign
103 | extracted_bias *= extracted_sign
104 | extracted_bias = np.array(extracted_bias, dtype=np.float64)
105 |
106 | # Report how we're doing
107 | extracted_normal, extracted_bias = check_quality(layer_num, extracted_normal, extracted_bias, do_fix=True)
108 |
109 | extracted_normals.append(extracted_normal)
110 | extracted_biases.append(extracted_bias)
111 |
112 | known_T = KnownT(extracted_normals, extracted_biases)
113 |
114 | for a,b in sorted(query_count_at.items(),key=lambda x: -x[1]):
115 | print('count', b, '\t', 'line:', a, ':', self_lines[a-1].strip())
116 |
117 | # And then finish up
118 | if len(extracted_normals) == len(sizes)-2:
119 | print("Just solve final layer")
120 | N = int(len(SAVED_QUERIES)/1000) or 1
121 | ins, outs = zip(*SAVED_QUERIES[::N])
122 | solve_final_layer(known_T, np.array(ins), np.array(outs))
123 | else:
124 | print("Solve final two")
125 | solve_final_two_layers(known_T, extracted_normal, extracted_bias)
126 |
127 |
128 | def solve_final_two_layers(known_T, known_A0, known_B0):
129 | ## Recover the final two layers through brute forcing signs, then least squares
130 | ## Yes, this is mostly a copy of solve_layer_sign. I am repeating myself. Sorry.
131 | LAYER = len(sizes)-2
132 | filtered_inputs = []
133 | filtered_outputs = []
134 |
135 | # How many unique points to use. This seems to work. Tweak if needed...
136 | # (In checking consistency of the final layer signs)
137 | N = int(len(SAVED_QUERIES)/100) or 1
138 | ins, outs = zip(*SAVED_QUERIES[::N])
139 | filtered_inputs, filtered_outputs = zip(*SAVED_QUERIES[::N])
140 | print('Total query count', len(SAVED_QUERIES))
141 | print("Solving on", len(filtered_inputs))
142 |
143 | inputs, outputs = np.array(filtered_inputs), np.array(filtered_outputs)
144 | known_hidden_so_far = known_T.forward(inputs, with_relu=True)
145 |
146 | K = sizes[LAYER]
147 | print("K IS", K)
148 | shuf = list(range(1<<K))
--------------------------------------------------------------------------------
/src/find_witnesses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | from src.global_vars import *
18 | from src.utils import run, cheat_num_relu_crosses, cheat_get_inner_layers
19 |
20 |
21 | def do_slow_sweep(offset=None, direction=None, low=-1e3, high=1e3,
22 | return_upto_one=False):
23 | """
24 | Binary search for critical points along a single line through input space.
25 | This is the simple reference implementation; do_better_sweep below is faster.
26 |
27 | Sweep from (offset + direction * low) to (offset + direction * high)
28 |
29 | If return_upto_one is true then only return one solution which is the first
30 | solution that is closest to low.
31 | """
32 |
33 | if offset is None:
34 | offset = np.random.normal(0,1,size=(DIM))
35 | if direction is None:
36 | direction = np.random.normal(0,1,size=(DIM))
37 |
38 | c = {}
39 | def memo_forward_pass(x):
40 | if x not in c:
41 | c[x] = run((offset+direction*x)[np.newaxis,:])
42 | return c[x]
43 |
44 | relus = []
45 |
46 | def search(low, high):
47 | mid = (low+high)/2
48 |
49 | y1 = f_low = memo_forward_pass(low)
50 | f_mid = memo_forward_pass(mid)
51 | y2 = f_high = memo_forward_pass(high)
52 |
53 | if CHEATING:
54 | ncross = cheat_num_relu_crosses((offset+direction*low)[np.newaxis,:], (offset+direction*high)[np.newaxis,:])
55 |
56 | # We want to write (f_low + f_high)/2 == f_mid but numerical problems are evil
57 | if np.abs(f_mid - (f_high + f_low)/2)/(high-low) < 1e-8:
58 | # We're in a linear region
59 | if CHEATING:
60 | print("Skip linear", ncross)
61 | print(f_mid - (f_high + f_low)/2, f_mid, (f_high + f_low)/2)
62 | return
63 | elif high-low < 1e-6:
64 | if CHEATING:
65 | print("Find solution", ncross)
66 |
67 | relus.append(offset + direction*mid)
68 | return
69 |
70 |
71 | search(low, mid)
72 | if return_upto_one and len(relus) > 0:
73 | # we're done because we just want the left-most solution; don't do more searching
74 | return
75 | search(mid, high)
76 |
77 | search(np.float64(low),
78 | np.float64(high))
79 |
80 | return relus
81 |
82 |
83 | def do_better_sweep(offset=None, direction=None, low=-1e3, high=1e3, return_upto_one=False,
84 | debug=False, debug2=False, known_T=None, run=run, return_scalar=False):
85 | """
86 | A much more efficient implementation of searching for critical points.
87 | Has the same interface as do_slow_sweep.
88 |
89 | Nearly identical, except that when we are in a region with only one critical
90 | point, does some extra math to identify where exactly the critical point is
91 | and returns it all in one go.
92 | In practice this is both much more efficient and much more accurate.
93 |
94 | """
95 | debug = debug and CHEATING
96 | debug2 = debug2 and CHEATING
97 |
98 | if offset is None:
99 | offset = np.random.normal(0,1,size=(DIM))
100 | if direction is None:
101 | direction = np.random.normal(0,1,size=(DIM))
102 |
103 | def memo_forward_pass(x, c={}):
104 | if x not in c:
105 | c[x] = run((offset+direction*x)[np.newaxis,:])
106 | return c[x]
107 |
108 | relus = []
109 |
110 | def search(low, high):
111 | if debug:
112 | print('low high',low,high)
113 | mid = (low+high)/2
114 |
115 | y1 = f_low = memo_forward_pass(low)
116 | f_mid = memo_forward_pass(mid)
117 | y2 = f_high = memo_forward_pass(high)
118 |
119 | if debug:
120 | ncross = cheat_num_relu_crosses((offset+direction*low)[np.newaxis,:], (offset+direction*high)[np.newaxis,:])
121 | print("ncross", ncross)
122 |
123 |
124 | if debug:
125 | print('aa',f_mid, f_high, f_low)
126 | print('compare', np.abs(f_mid - (f_high + f_low)/2), SKIP_LINEAR_TOL*((high-low)**.5))
127 | print("really", ncross)
128 |
129 | if np.abs(f_mid - (f_high + f_low)/2) < SKIP_LINEAR_TOL*((high-low)**.5):
130 | # We're in a linear region
131 | if debug:
132 | print("Skip linear", sum(ncross), ncross)
133 | return
134 | elif high-low < 1e-8:
135 | # Too close to each other
136 | if debug:
137 | print('wat', ncross)
138 | return
139 | else:
140 | # Check if there is exactly one ReLU switching sign, or if there are multiple.
141 | # To do this, use the 2-linear test from Jagielski et al. 2019
142 | #
143 | #
144 | # /\ <---- real_h_at_x
145 | # / \
146 | # / \
147 | # / \
148 | # / \
149 | # / \
150 | # / \
151 | # low q1 x_s_b q3 high
152 | #
153 | # Use (low,q1) to estimate the direction of the first line
154 | # Use (high,q3) to estimate the direction of the second line
155 | # They should in theory intersect at (x_should_be, y_should_be)
156 | # Query to compute real_h_at_x and then check if that's what we get
157 | # Then check that we're linear from x_should_be to low, and
158 | # linear from x_should_be to high.
159 | # If it all checks out, then return the solution.
160 | # Otherwise recurse again.
161 |
162 |
163 | q1 = (low+mid)*.5
164 | q3 = (high+mid)*.5
165 |
166 | f_q1 = memo_forward_pass(q1)
167 | f_q3 = memo_forward_pass(q3)
168 |
169 |
170 | m1 = (f_q1-f_low)/(q1-low)
171 | m2 = (f_q3-f_high)/(q3-high)
172 |
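# Intersect the two estimated lines y1 + m1*(x - low) and
# y2 + m2*(x - high): with d = high - low, equating them gives
# x* = low + (y2 - y1 - d*m2)/(m1 - m2), and alpha = (x* - low)/d is
# the fraction of the way across the interval (sanity-checked below).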
173 | if m1 != m2:
174 | d = (high-low)
175 | alpha = (y2 - y1 - d * m2) / (d * m1 - d * m2)
176 |
177 | x_should_be = low + (y2 - y1 - d * m2) / (m1 - m2)
178 | height_should_be = y1 + m1*(y2 - y1 - d * m2) / (m1 - m2)
179 |
180 | if m1 == m2:
181 | # If the slopes on both directions are the same (e.g., the function is flat)
182 | # then we need to split and can't learn anything
183 | pass
184 | elif np.all(.25+1e-5 < alpha) and np.all(alpha < .75-1e-5) and np.max(x_should_be)-np.min(x_should_be) < 1e-5:
185 | x_should_be = np.median(x_should_be)
186 | real_h_at_x = memo_forward_pass(x_should_be)
187 |
188 | if np.all(np.abs(real_h_at_x - height_should_be) < SKIP_LINEAR_TOL*100):
189 | # Compute gradient on each side and check for linearity
190 |
191 |
192 | eighth_left = x_should_be-1e-4
193 | eighth_right = x_should_be+1e-4
194 | grad_left = (memo_forward_pass(eighth_left)-real_h_at_x)/(eighth_left-x_should_be)
195 | grad_right = (memo_forward_pass(eighth_right)-real_h_at_x)/(eighth_right-x_should_be)
196 |
197 | if np.all(np.abs(grad_left-m1)>SKIP_LINEAR_TOL*10) or np.all(np.abs(grad_right-m2)>SKIP_LINEAR_TOL*10):
198 | if debug:
199 | print("it's nonlinear")
200 | pass
201 | else:
202 |
203 | if debug:
204 | print("OK", ncross)
205 | vals = cheat_get_inner_layers((offset+direction*x_should_be))
206 | smallest = min([np.min(np.abs(v)) for v in vals])
207 | print("Small", smallest, vals)
208 | if smallest > .01:
209 |
210 | raise
211 | if debug and sum(ncross) > 1:
212 | print("BADNESS")
213 | if return_scalar:
214 | relus.append(x_should_be)
215 | else:
216 | relus.append(offset + direction*x_should_be)
217 | return
218 |
219 |
220 | search(low, mid)
221 | if return_upto_one and len(relus) > 0:
222 | # we're done because we just want the left-most solution; don't do more searching
223 | return
224 | search(mid, high)
225 |
226 | if debug2 or debug:
227 | print("Sweeping", cheat_num_relu_crosses((offset+direction*low)[np.newaxis,:], (offset+direction*high)[np.newaxis,:]))
228 |
229 | # If we know what some of the earlier layers look like, then don't waste compute
230 | # to recover those early layer critical points again.
231 | # Just find the ones on the deeper layers and then add the early-layer ones in
232 | # (where they should be).
233 | # WARNING: this assumes that known_T is high precision. If it is not, then
234 | # it will add in the WRONG locations and that is very bad.
235 | if known_T is not None and False:
236 | def fwd(x):
237 | return np.sum(known_T.forward(x, with_relu=True),axis=1)
238 | prev_solns = do_better_sweep(offset, direction, low, high, run=fwd,
239 | return_scalar=True)
240 | prev_solns = [low]+prev_solns+[high]
241 | for l, h in zip(prev_solns, prev_solns[1:]):
242 | search(l, h)
243 | if h != high:
244 | relus.append(offset+direction*h)
245 | return relus
246 |
247 | search(low,
248 | high)
249 |
250 | return relus
251 |
252 | def sweep_for_critical_points(std=1, known_T=None):
253 | while True:
254 | print("Start another sweep")
255 | qs = query_count
256 | sweep = do_better_sweep(
257 | offset=np.random.normal(0, np.random.uniform(std/10,std), size=DIM),
258 | known_T=known_T,
259 | low=-std*1e3, high=std*1e3, debug=False)
260 | print("Total intersections found", len(sweep))
261 | print('delta queries', query_count - qs)
262 | for point in sweep:
263 | yield point
264 |
265 |
--------------------------------------------------------------------------------
/src/global_vars.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | import multiprocessing as mp
17 | import numpy as np
18 | import random
19 |
20 | #####################################################################
21 | ## GLOBAL VARIABLES. I am a bad person and use globals. I'm sorry. ##
22 | #####################################################################
23 |
24 | from jax.config import config
25 | config.update("jax_enable_x64", True)
26 |
27 | # To ensure reproducible results to help debugging, set seeds for randomness.
28 | seed = int(sys.argv[2]) if len(sys.argv) > 2 else 42 # for luck
29 | np.random.seed(seed)
30 | random.seed(seed)
31 |
32 | # sizes is the number of relus in each layer
33 | sizes = list(map(int,sys.argv[1].split("-")))
34 | dimensions = [tuple([x]) for x in sizes]
35 | neuron_count = sizes
36 |
37 | DIM = sizes[0]
38 |
39 | __cheat_A, __cheat_B = np.load("models/" + str(seed) + "_" + "-".join(map(str,sizes))+".npy", allow_pickle=True)
40 |
41 | # In order to help debugging, we're going to log what lines of code
42 | # cause lots of queries to be generated. Use this to improve things.
43 | query_count = 0
44 | query_count_at = {}
45 |
46 | # HYPERPARAMETERS. Change these at your own risk. It may all die.
47 |
48 | PARAM_SEARCH_AT_LOCATION = 1e2
49 | GRAD_EPS = 1e-4
50 | SKIP_LINEAR_TOL = 1e-8
51 | BLOCK_ERROR_TOL = 1e-3
52 | BLOCK_MULTIPLY_FACTOR = 2
53 | DEAD_NEURON_THRESHOLD = 1000
54 | MIN_SAME_SIZE = 4 # this is most likely what should be changed
55 |
56 | if len(sizes) == 3:
57 | PARAM_SEARCH_AT_LOCATION = 1e4
58 | GRAD_EPS = 1e1
59 | SKIP_LINEAR_TOL = 1e-7
60 | BLOCK_MULTIPLY_FACTOR = 8
61 |
62 | # When we save the results, we're going to use this to make sure that
63 | # (a) we don't clobber old results, but
64 | # (b) we don't keep stale results around
65 | name_hash = "-".join(map(str,sizes))+str(hash(tuple(np.random.get_state()[1])))
66 |
67 | # CHEAT MODE. Turning on lets you read the actual weight matrix.
68 |
69 | # Enable IDDQD mode
70 | # In order to debug sometimes it helps to be able to look at the actual values of the
71 | # true weight matrix.
72 | # When we're allowed to do that, assign them from __cheat_A and __cheat_B
73 | # When we're not, then just give them a constant 0 so the code doesn't crash
74 | CHEATING = False
75 |
76 | if CHEATING:
77 | A = [np.array(x) for x in __cheat_A]
78 | B = [np.array(x) for x in __cheat_B]
79 | else:
80 | A = [np.zeros_like(x) for x in __cheat_A]
81 | B = [np.zeros_like(x) for x in __cheat_B]
82 |
83 | MPROC_THREADS = max(mp.cpu_count(),1)
84 | pool = []
85 |
--------------------------------------------------------------------------------
/src/hyperplane_normal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | from src.global_vars import *
18 | from src.utils import run, basis, AcceptableFailure, cheat_get_inner_layers, which_is_zero
19 |
20 | def get_grad(x, direction, eps=1e-6):
21 | """
22 | Finite differences to estimate the gradient.
23 | Uses just two coordinates---that's sufficient for most of the code.
24 |
25 | Can fail if we're right at a critical point and we get the left and right side.
26 | /
27 | X
28 | /
29 | -X--/
30 |
31 | """
32 | x = x[np.newaxis,:]
33 | a = run(x-eps*direction)
34 | b = run(x)
35 | g1 = (b-a)/eps
36 | return g1
37 |
38 | def get_second_grad_unsigned(x, direction, eps, eps2):
39 | """
40 | Compute the second derivative by computing the first derivative twice.
41 | """
42 | grad_value = get_grad(x + direction*eps, direction, eps2)+get_grad(x - direction*eps, -direction, eps2)
43 |
44 | return grad_value[0]
45 |
46 | MASK = np.array([1,-1,1,-1])
47 | def get_second_grad_unsigned(x, direction, eps, eps2):
48 | """
49 | Same as the above but batched so it's more efficient.
50 | """
51 | x = np.array([x + direction * (eps - eps2),
52 | x + direction * (eps),
53 | x - direction * (eps - eps2),
54 | x - direction * (eps)])
55 |
56 | out = run(x)
57 |
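# With MASK = [1,-1,1,-1] this dot product evaluates
# (f(x+(eps-eps2)d) - f(x+eps*d)) + (f(x-(eps-eps2)d) - f(x-eps*d)),
# i.e. eps2 times (slope left of x) minus (slope right of x) along d:
# an unsigned second-derivative signal, up to the scaling by eps below.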
58 | return np.dot(out.flatten(), MASK)/eps
59 |
60 |
61 |
62 | def get_ratios(critical_points, N, with_sign=True, eps=1e-5):
63 | """
64 | Compute the input weights to one neuron on the first layer.
65 | One of the core algorithms described in the paper.
66 |
67 | Given a set of critical points, compute the gradient for the first N directions.
68 | In practice N = range(DIM)
69 |
70 | Compute the second partial derivative along each of the axes. This gives
71 | us the unsigned ratios corresponding to the ratio of the weights.
72 |
73 | /
74 | ^ /
75 | | /
76 | |/
77 | <----X----> direction_1
78 | /|
79 | / |
80 | / V
81 | / direction_2
82 |
83 | If we want to recover signs then we should also query on direction_1+direction_2
84 | And check to see if we get the correct solution.
85 | """
86 | ratios = []
87 | for j,point in enumerate(critical_points):
88 | ratio = []
89 | for i in N[j]:
90 | ratio.append(get_second_grad_unsigned(point, basis(i), eps, eps/3))
91 |
92 | if with_sign:
93 | both_ratio = []
94 | for i in N[j]:
95 | both_ratio.append(get_second_grad_unsigned(point, (basis(i) + basis(N[j][0]))/2, eps, eps/3))
96 |
97 | signed_ratio = []
98 | for i in range(len(ratio)):
99 | # When we have at least one y value already we need to orient this one
100 | # so that they point the same way.
101 | # We are given |f(x+d1)| and |f(x+d2)|
102 | # Compute |f(x+d1+d2)|.
103 | # Then either
104 | # |f(x+d1+d2)| = |f(x+d1)| + |f(x+d2)|
105 | # or
106 | # |f(x+d1+d2)| = |f(x+d1)| - |f(x+d2)|
107 | # or
108 | # |f(x+d1+d2)| = |f(x+d2)| - |f(x+d1)|
109 | positive_error = abs(abs(ratio[0]+ratio[i])/2 - abs(both_ratio[i]))
110 | negative_error = abs(abs(ratio[0]-ratio[i])/2 - abs(both_ratio[i]))
111 |
112 | if positive_error > 1e-4 and negative_error > 1e-4:
113 | print("Probably something is borked")
114 | print("d^2(e(i))+d^2(e(j)) != d^2(e(i)+e(j))", positive_error, negative_error)
115 | raise AcceptableFailure()
116 |
117 | if positive_error < negative_error:
118 | signed_ratio.append(ratio[i])
119 | else:
120 | signed_ratio.append(-ratio[i])
121 | else:
122 | signed_ratio = ratio
123 |
124 | ratio = np.array(signed_ratio)
125 |
126 | #print(ratio)
127 | ratios.append(ratio)
128 |
129 | return ratios
130 |
131 | def get_ratios_lstsq(LAYER, critical_points, N, known_T, debug=False, eps=1e-5):
132 | """
133 | Do the same thing as get_ratios, but works when we can't directly control where we want to query.
134 |
135 | This means we can't directly choose orthogonal directions, and so we're going
136 | to just pick random ones and then use least-squares to do it
137 | """
138 | #pickle.dump((LAYER, critical_points, N, known_T, debug, eps),
139 | # open("/tmp/save.p","wb"))
140 | ratios = []
141 | for i,point in enumerate(critical_points):
142 | if CHEATING:
143 | layers = cheat_get_inner_layers(point)
144 | layer_vals = [np.min(np.abs(x)) for x in layers]
145 | which_layer = np.argmin(layer_vals)
146 | #print("real zero", np.argmin(np.abs(layers[0])))
147 | which_neuron = which_is_zero(which_layer, layers)
148 | #print("Which neuron?", which_neuron)
149 |
150 | real = A[which_layer][:,which_neuron]/A[which_layer][0,which_neuron]
151 |
152 | # We're going to create a system of linear equations
153 | # d_matrix is going to hold the inputs,
154 | # and ys is going to hold the resulting learned outputs
155 | d_matrix = []
156 | ys = []
157 |
158 | # Query on N+2 random points, so that we have redundancy
159 | # for the least squares solution.
160 | for i in range(np.sum(known_T.forward(point) != 0)+2):
161 | # 1. Choose a random direction
162 | d = np.sign(np.random.normal(0,1,point.shape))
163 | d_matrix.append(d)
164 |
165 | # 2. See what the second partial derivative at this value is
166 | ratio_val = get_second_grad_unsigned(point, d, eps, eps/3)
167 |
168 | # 3. Get the sign correct
169 | if len(ys) > 0:
170 | # When we have at least one y value already we need to orient this one
171 | # so that they point the same way.
172 | # We are given |f(x+d1)| and |f(x+d2)|
173 | # Compute |f(x+d1+d2)|.
174 | # Then either
175 | # |f(x+d1+d2)| = |f(x+d1)| + |f(x+d2)|
176 | # or
177 | # |f(x+d1+d2)| = |f(x+d1)| - |f(x+d2)|
178 | # or
179 | # |f(x+d1+d2)| = |f(x+d2)| - |f(x+d1)|
180 | both_ratio_val = get_second_grad_unsigned(point, (d+d_matrix[0])/2, eps, eps/3)
181 |
182 | positive_error = abs(abs(ys[0]+ratio_val)/2 - abs(both_ratio_val))
183 | negative_error = abs(abs(ys[0]-ratio_val)/2 - abs(both_ratio_val))
184 |
185 | if positive_error > 1e-4 and negative_error > 1e-4:
186 | print("Probably something is borked")
187 | print("d^2(e(i))+d^2(e(j)) != d^2(e(i)+e(j))", positive_error, negative_error)
188 | raise AcceptableFailure()
189 |
190 |
191 | if negative_error < positive_error:
192 | ratio_val *= -1
193 |
194 | ys.append(ratio_val)
195 |
196 | d_matrix = np.array(d_matrix)
197 | # Now we need to compute the system of equations
198 | # We have to figure out what the vectors look like in hidden space,
199 | # so compute that precisely
200 | h_matrix = np.array(known_T.forward_at(point, d_matrix))
201 |
202 |
203 | # Which dimensions do we lose?
204 | column_is_zero = np.mean(np.abs(h_matrix)<1e-8,axis=0) > .5
205 | assert np.all((known_T.forward(point, with_relu=True) == 0) == column_is_zero)
206 |
207 | #print(h_matrix.shape)
208 |
209 | # Solve the least squares problem and get the solution
210 | # This is equal to solving for the ratios of the weight vector
211 | soln, *rest = np.linalg.lstsq(np.array(h_matrix, dtype=np.float32),
212 | np.array(ys, dtype=np.float32), 1e-5)
213 |
214 | # Set the columns we know to be wrong to NaN so that it's obvious.
215 | # This isn't important, but it helps us distinguish genuine errors from
216 | # the kind we can't avoid because of zero gradients.
217 | soln[column_is_zero] = np.nan
218 |
219 | ratios.append(soln)
220 |
221 | return ratios
222 |
--------------------------------------------------------------------------------
/src/layer_recovery.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import jax
16 | import jax.numpy as jnp
17 | import numpy as np
18 | import time
19 | import networkx as nx
20 | import collections
21 |
22 | from src.global_vars import *
23 | from src.find_witnesses import do_better_sweep
24 | from src.hyperplane_normal import get_ratios_lstsq
25 |
26 | from src.utils import AcceptableFailure, GatherMoreData, matmul, KnownT, cheat_get_inner_layers, which_is_zero
27 | import src.sign_recovery as sign_recovery
28 |
29 |
30 | @jax.jit
31 | def process_block(ratios, other_ratios):
32 | """
33 | Let jax efficiently compute pairwise similarity by blocking things.
34 | """
35 | differences = jnp.abs(ratios[:,jnp.newaxis,:] - other_ratios[jnp.newaxis,:,:])
36 | differences = differences / jnp.abs(ratios[:,jnp.newaxis,:]) + differences / jnp.abs(other_ratios[jnp.newaxis,:,:])
37 |
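# A coordinate "matches" when this symmetric relative difference is small;
# a pair of rows is accepted once enough coordinates match. Both thresholds
# below scale with the log of the input dimension.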
38 | close = differences < BLOCK_ERROR_TOL * jnp.log(ratios.shape[1])
39 |
40 | pairings = jnp.sum(close, axis=2) >= max(MIN_SAME_SIZE,BLOCK_MULTIPLY_FACTOR*(np.log(ratios.shape[1])-2))
41 |
42 | return pairings
43 |
44 | def graph_solve(all_ratios, all_criticals, expected_neurons, LAYER, debug=False):
45 | # 1. Load the critical points and ratios we precomputed
46 |
47 | all_ratios = np.array(all_ratios, dtype=np.float64)
48 | all_ratios_f32 = np.array(all_ratios, dtype=np.float32)
49 | all_criticals = np.array(all_criticals, dtype=np.float64)
50 |
51 | # Batch them to be sensibly sized
52 | ratios_group = [all_ratios_f32[i:i+1000] for i in range(0,len(all_ratios),1000)]
53 | criticals_group = [all_criticals[i:i+1000] for i in range(0,len(all_criticals),1000)]
54 |
55 | # 2. Compute the similarity pairwise between the ratios we've computed
56 |
57 | print("Go up to", len(criticals_group))
58 | now = time.time()
59 | all_pairings = [[] for _ in range(sum(map(len,ratios_group)))]
60 | for batch_index,(criticals,ratios) in enumerate(zip(criticals_group, ratios_group)):
61 | print(batch_index)
62 |
63 | # Compute the all-pairs similarity
64 | axis = list(range(all_ratios.shape[1]))
65 | random.shuffle(axis)
66 | axis = axis[:20]
67 | for dim in axis:
68 | # We may have an error on one of the directions, so let's try all of them
69 | scaled_all_ratios = all_ratios_f32 / all_ratios_f32[:,dim:dim+1]
70 | scaled_ratios = ratios / ratios[:,dim:dim+1]
71 |
72 | batch_pairings = process_block(scaled_ratios, scaled_all_ratios)
73 |
74 | # To get the offset, compute the cumsum of the lengths up to batch_index
75 | batch_offset = sum(map(len,ratios_group[:batch_index]))
76 | # And now create the graph matching ratios that are similar
77 | for this_batch_i,global_j in zip(*np.nonzero(np.array(batch_pairings))):
78 | all_pairings[this_batch_i + batch_offset].append(global_j)
79 | print(time.time()-now)
80 |
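# 3. Cluster the witnesses: critical points whose ratio vectors agree (up
# to scale) witness the same hidden neuron, so each connected component of
# the pairing graph should correspond to one neuron.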
81 | graph = nx.Graph()
82 | # Add the edges to the graph, removing self-loops
83 | graph.add_edges_from([(i,j) for i,js in enumerate(all_pairings) for j in js if abs(i-j) > 1])
84 | components = list(nx.connected_components(graph))
85 |
86 | sorted_components = sorted(components, key=lambda x: -len(x))
87 |
88 | if CHEATING:
89 | print('Total (unmatched) examples found:', sorted(collections.Counter(which_is_zero(LAYER, cheat_get_inner_layers(all_criticals))).items()))
90 |
91 | #for crit,rat in zip(all_criticals,all_ratios):
92 | # if which_is_zero(LAYER, cheat_get_inner_layers(crit)) == 6:
93 | # print(" ".join("%.6f"%abs(x) if not np.isnan(x) else " nan" for x in rat))
94 |
95 | #cc = which_is_zero(LAYER, cheat_get_inner_layers(all_criticals))
96 | #print("THREES")
97 | #
98 | #threes = []
99 | #print("Pair", process_block
100 | # [all_ratios[x] for x in range(len(all_criticals)) if cc[x] == 3]
101 |
102 |
103 |
104 | if len(components) == 0:
105 | print("No components found")
106 | raise AcceptableFailure()
107 | print("Graph search found", len(components), "different components with the following counts", list(map(len,sorted_components)))
108 |
109 | if CHEATING:
110 | which_neurons = [collections.Counter(which_is_zero(LAYER, cheat_get_inner_layers(all_criticals[list(orig_component)]))) for orig_component in sorted_components]
111 | first_index_of = [-1]*expected_neurons
112 |
113 | for i,items in enumerate(which_neurons):
114 | for item in items.keys():
115 | if first_index_of[item] == -1:
116 | first_index_of[item] = i
117 |
118 | print('These components correspond to', which_neurons)
119 | print("With the corresponding index in the list:", first_index_of)
120 |
121 | previous_num_components = np.inf
122 |
123 | while previous_num_components > len(sorted_components):
124 | previous_num_components = len(sorted_components)
125 | candidate_rows = []
126 | candidate_components = []
127 |
128 | datas = [all_ratios[list(component)] for component in sorted_components]
129 | results = pool[0].map(ratio_normalize, datas)
130 |
131 | candidate_rows = [x[0] for x in results]
132 | candidate_components = sorted_components
133 |
134 | candidate_rows = np.array(candidate_rows)
135 |
136 | new_pairings = [[] for _ in range(len(candidate_rows))]
137 |
138 | # Re-do the pairings
139 | for dim in range(all_ratios.shape[1]):
140 | scaled_ratios = candidate_rows / candidate_rows[:,dim:dim+1]
141 |
142 | batch_pairings = process_block(scaled_ratios, scaled_ratios)
143 |
144 | # And now create the graph matching ratios that are similar
145 | for this_batch_i,global_j in zip(*np.nonzero(np.array(batch_pairings))):
146 | new_pairings[this_batch_i].append(global_j)
147 |
148 | graph = nx.Graph()
149 | # Add the edges to the graph, ALLOWING self-loops this time
150 | graph.add_edges_from([(i,j) for i,js in enumerate(new_pairings) for j in js])
151 | components = list(nx.connected_components(graph))
152 |
153 | components = [sum([list(candidate_components[y]) for y in comp],[]) for comp in components]
154 |
155 | sorted_components = sorted(components, key=lambda x: -len(x))
156 |
157 | print("After re-doing the graph, the component count is", len(components), "with items", list(map(len,sorted_components)))
158 |
159 | if CHEATING:
160 | which_neurons = [collections.Counter(which_is_zero(LAYER, cheat_get_inner_layers(all_criticals[list(orig_component)]))) for orig_component in sorted_components]
161 | first_index_of = [-1]*expected_neurons
162 |
163 | for i,items in enumerate(which_neurons):
164 | for item in items.keys():
165 | if first_index_of[item] == -1:
166 | first_index_of[item] = i
167 |
168 | print('Corresponding to', which_neurons)
169 | print("First index:", first_index_of)
170 |
171 | print("Expected neurons", expected_neurons)
172 |
173 |
174 | print("Processing each connected component in turn.")
175 |
176 | resulting_examples = []
177 | resulting_rows = []
178 |
179 | skips_because_of_nan = 0
180 | failure = None
181 |
182 | for c_count, component in enumerate(sorted_components):
183 | if debug:
184 | print("\n")
185 | if c_count >= expected_neurons:
186 | print("WARNING: This one might be a duplicate!")
187 | print("On component", c_count, "with indices", component)
188 | if debug and CHEATING:
189 | inner = cheat_get_inner_layers(all_criticals[list(component)])
190 | print('Corresponding to (cheating) ', which_is_zero(LAYER, inner))
191 |
192 | possible_matrix_rows = all_ratios[list(component)]
193 |
194 | guessed_row, normalize_axis, normalize_error = ratio_normalize(possible_matrix_rows)
195 |
196 | print('The guessed error in the computation is',normalize_error, 'with', len(component), 'witnesses')
197 | if normalize_error > .01 and len(component) <= 5:
198 | print("Component size less than 5 with high error; this isn't enough to be sure")
199 | continue
200 |
201 | print("Normalize on axis", normalize_axis)
202 |
203 | if len(resulting_rows):
204 | scaled_resulting_rows = np.array(resulting_rows)
205 | #print(scaled_resulting_rows.shape)
206 | scaled_resulting_rows /= scaled_resulting_rows[:,normalize_axis:normalize_axis+1]
207 | delta = np.abs(scaled_resulting_rows - guessed_row[np.newaxis,:])
208 | if min(np.nanmax(delta, axis=1)) < 1e-2:
209 | print("Likely have found this node before")
210 | raise
211 |
212 |
213 | if CHEATING:
214 | # Check our work against the ground truth entries in the corresponding matrix
215 | layers = cheat_get_inner_layers(all_criticals[list(component)[0]])
216 | layer_vals = [np.min(np.abs(x)) for x in layers]
217 | which_layer = np.argmin(layer_vals)
218 |
219 | M = A[which_layer]
220 | which_neuron = which_is_zero(which_layer, layers)
221 | print("Neuron corresponds to", which_neuron)
222 | if which_layer != LAYER:
223 | which_neuron = 0
224 | normalize_axis = 0
225 |
226 | actual_row = M[:,which_neuron]/M[normalize_axis,which_neuron]
227 | actual_row = actual_row[:guessed_row.shape[0]]
228 |
229 | do_print_err = np.any(np.isnan(guessed_row))
230 |
231 | if which_layer == LAYER:
232 | error = np.max(np.abs(np.abs(guessed_row)-np.abs(actual_row)))
233 | else:
234 | error = 1e6
235 | print('max error', "%0.8f"%error, len(component))
236 | if (error > 1e-4 * len(guessed_row) and debug) or do_print_err:
237 | print('real ', " ".join("%2.3f"%x for x in actual_row))
238 | print('guess', " ".join("%2.3f"%x for x in guessed_row))
239 | print('gap', " ".join("%2.3f"%(np.abs(x-y)) for x,y in zip(guessed_row,actual_row)))
240 | #print("scale", scale)
241 | print("--")
242 | for row in possible_matrix_rows:
243 | print('posbl', " ".join("%2.3f"%x for x in row/row[normalize_axis]))
244 | print("--")
245 |
246 | scale = 10**int(np.round(np.log(np.nanmedian(np.abs(possible_matrix_rows)))/np.log(10)))
247 | possible_matrix_rows /= scale
248 | for row in possible_matrix_rows:
249 | print('posbl', " ".join("%2.3f"%x for x in row))
250 | if np.any(np.isnan(guessed_row)) and c_count < expected_neurons:
251 | print("Got NaN, need more data",len(component)/sum(map(len,components)),1/sizes[LAYER+1])
252 | if len(component) >= 3:
253 | if c_count < expected_neurons:
254 | failure = GatherMoreData([all_criticals[x] for x in component])
255 | skips_because_of_nan += 1
256 | continue
257 |
258 | guessed_row[np.isnan(guessed_row)] = 0
259 |
260 | if c_count < expected_neurons and len(component) >= 3:
261 | resulting_rows.append(guessed_row)
262 | resulting_examples.append([all_criticals[x] for x in component])
263 | else:
264 | print("Don't add it to the set")
265 |
266 |
267 | # We set failure when something went wrong but we want to defer crashing
268 | # (so that we can use the partial solution)
269 |
270 | if len(resulting_rows)+skips_because_of_nan < expected_neurons and len(all_ratios) < DEAD_NEURON_THRESHOLD:
271 | print("We have not explored all neurons. Do more random search", len(resulting_rows), skips_because_of_nan, expected_neurons)
272 | raise AcceptableFailure(partial_solution=(np.array(resulting_rows), np.array(resulting_examples)))
273 | else:
274 | print("At this point, we just assume the neuron must be dead")
275 | while len(resulting_rows) < expected_neurons:
276 | resulting_rows.append(np.zeros_like((resulting_rows[0])))
277 | resulting_examples.append([np.zeros_like(resulting_examples[0][0])])
278 |
279 | # Here we know it's a GatherMoreData failure, but we want to only do this
280 | # if there was enough data for everything else
281 | if failure is not None:
282 | print("Need to raise a previously generated failure.")
283 | raise failure
284 |
285 |
286 | print("Successfully returning a solution attempt.\n")
287 | return resulting_examples, resulting_rows
288 |
289 | def ratio_normalize(possible_matrix_rows):
290 | # We get a set of a bunch of numbers
291 | # a1 b1 c1 d1 e1 f1 g1
292 | # a2 b2 c2 d2 e2 f2 g2
293 | # such that some of them are nan
294 | # We want to compute the pairwise ratios ignoring the nans
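295 |     # For example, from rows (2, 4, nan) and (3, nan, 9) we get evidence
296 |     # [2/4] for the ratio col0/col1 and [3/9] for col0/col2; taking the
297 |     # nanmedian over all witnesses gives a robust estimate of each ratio.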
295 |
296 | now = time.time()
297 | ratio_evidence = [[[] for _ in range(possible_matrix_rows.shape[1])] for _ in range(possible_matrix_rows.shape[1])]
298 |
299 | for row in possible_matrix_rows:
300 | for i in range(len(row)):
301 | for j in range(len(row)):
302 | ratio_evidence[i][j].append(row[i]/row[j])
303 |
304 | if len(ratio_evidence) > 100:
305 | ratio_evidence = np.array(ratio_evidence, dtype=np.float32)
306 | else:
307 | ratio_evidence = np.array(ratio_evidence, dtype=np.float64)
308 |
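309 |     # Robust estimate of each pairwise ratio: the nanmedian across all
310 |     # witnesses, with uncertainty roughly the standard error of the mean
311 |     # (std / sqrt(count)), expressed relative to the median's magnitude.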
309 | medians = np.nanmedian(ratio_evidence, axis=2)
310 | errors = np.nanstd(ratio_evidence, axis=2) / np.sum(~np.isnan(ratio_evidence), axis=2)**.5
311 | errors += 1e-2 * (np.sum(~np.isnan(ratio_evidence), axis=2) == 1)
312 | errors /= np.abs(medians)
313 | errors[np.isnan(errors)] = 1e6
314 |
315 | ratio_evidence = medians
316 |
317 | last_nan_count = 1e8
318 | last_total_cost = 1e8
319 |
320 | while (np.sum(np.isnan(ratio_evidence)) < last_nan_count or last_total_cost < np.sum(errors)*.9) and False:
321 | last_nan_count = np.sum(np.isnan(ratio_evidence))
322 | last_total_cost = np.sum(errors)
323 | print('.')
324 | print("Takenc", time.time()-now)
325 | print('nan count', last_nan_count)
326 | print('total cost', last_total_cost)
327 |
328 | cost_i_over_j = ratio_evidence[:,:,np.newaxis]
329 | cost_j_over_k = ratio_evidence
330 | cost_i_over_k = cost_i_over_j * cost_j_over_k
331 | del cost_i_over_j, cost_j_over_k
332 | print(cost_i_over_k.shape, cost_i_over_k.dtype)
333 |
334 | error_i_over_j = errors[:,:,np.newaxis]
335 | error_j_over_k = errors
336 | error_i_over_k = error_i_over_j + error_j_over_k
337 |
338 | best_indexs = np.nanargmin(error_i_over_k,axis=1)
339 | best_errors = np.nanmin(error_i_over_k,axis=1)
340 | del error_i_over_j, error_j_over_k, error_i_over_k
341 |
342 | cost_i_over_k_new = []
343 | for i in range(len(best_indexs)):
344 | cost_i_over_k_new.append(cost_i_over_k[i].T[np.arange(len(best_indexs)),best_indexs[i]])
345 |
346 | cost_i_over_k = np.array(cost_i_over_k_new)
347 |
348 |         which = best_errors < errors
349 |         ratio_evidence[which] = cost_i_over_k[which]
350 |         errors[which] = best_errors[which]
351 | 
352 |     # Pick the column to normalize by: restrict to the columns with the
353 |     # fewest nan entries, and of those take the one with the lowest error.
354 |     nancount = np.sum(np.isnan(ratio_evidence), axis=0)
355 | 
356 |     column_ok = np.min(nancount) == nancount
357 | 
358 |     best = (None, np.inf)
359 | 
360 |     for column in range(len(column_ok)):
361 |         if not column_ok[column]:
362 |             continue
363 | 
364 |         quality = np.nansum(np.abs(errors[:, column]))
365 | 
366 |         if quality < best[1]:
367 |             best = (column, quality)
368 | 
369 |     column, best_error = best
370 | 
371 |     return ratio_evidence[:, column], column, best_error
372 | 
373 | 
374 | def gather_ratios(critical_points, known_T, check_fn, LAYER, COUNT):
375 |     """
376 |     Collect up to COUNT witnesses, and for each one estimate the ratios
377 |     of the weights feeding into the neuron that is critical there (i.e.,
378 |     one row of the next weight matrix, up to scaling).
379 |     """
380 |     this_layer_critical_points = []
381 | 
382 |     print("Gathering", COUNT, "critical points")
383 | 
384 |     for point in critical_points:
385 | 
386 |         if LAYER > 0:
387 | if any(np.any(np.abs(x) < 1e-5) for x in known_T.get_hidden_layers(point)):
388 | continue
389 | if CHEATING:
390 | if np.any(np.abs(cheat_get_inner_layers(point)[0]) < 1e-10):
391 | print(cheat_get_inner_layers(point))
392 | print("Looking at one I don't need to")
393 |
394 |
395 | if LAYER > 0 and np.sum(known_T.forward(point) != 0) <= 1:
396 | print("Not enough hidden values are active to get meaningful data")
397 | continue
398 |
399 | if not check_fn(point):
400 | #print("Check function rejected it")
401 | continue
402 | if CHEATING:
403 | print("What layer is this neuron on (by cheating)?",
404 | [(np.min(np.abs(x)), np.argmin(np.abs(x))) for x in cheat_get_inner_layers(point)])
405 |
406 | tmp = query_count
407 | for EPS in [GRAD_EPS, GRAD_EPS/10, GRAD_EPS/100]:
408 | try:
409 | normal = get_ratios_lstsq(LAYER, [point], [range(DIM)], known_T, eps=EPS)[0].flatten()
410 | #normal = get_ratios([point], [range(DIM)], eps=EPS)[0].flatten()
411 | break
412 | except AcceptableFailure:
413 | print("Try again with smaller eps")
414 | pass
415 | #print("LSTSQ Delta queries", query_count-tmp)
416 |
417 | this_layer_critical_points.append((normal, point))
418 |
419 |         # coupon collector: we need n log n points.
420 | print("Up to", len(this_layer_critical_points), 'of', COUNT)
421 | if len(this_layer_critical_points) >= COUNT:
422 | break
423 |
424 | return this_layer_critical_points
425 |
426 | def compute_layer_values(critical_points, known_T, LAYER):
427 | if LAYER == 0:
428 | COUNT = neuron_count[LAYER+1] * 3
429 | else:
430 | COUNT = neuron_count[LAYER+1] * np.log(sizes[LAYER+1]) * 3
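431 |         # (Coupon collector: to see each of the n neurons on this layer at
432 |         # least once we expect to need on the order of n log n witnesses.)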
431 |
432 |
433 | # type: [(ratios, critical_point)]
434 | this_layer_critical_points = []
435 |
436 | partial_weights = None
437 | partial_biases = None
438 |
439 | def check_fn(point):
440 | if partial_weights is None:
441 | return True
442 | hidden = matmul(known_T.forward(point, with_relu=True), partial_weights.T, partial_biases)
443 | if np.any(np.abs(hidden) < 1e-4):
444 | return False
445 |
446 | return True
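447 |     # check_fn rejects candidate witnesses that lie on the hyperplane of a
448 |     # neuron we have already recovered in a partial solution, so queries
449 |     # are only spent hunting for the still-missing neurons.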
447 |
448 |
449 | print()
450 | print("Start running critical point search to find neurons on layer", LAYER)
451 | while True:
452 | print("At this iteration I have", len(this_layer_critical_points), "critical points")
453 |
454 | def reuse_critical_points():
455 | for witness in critical_points:
456 | yield witness
457 |
458 | this_layer_critical_points.extend(gather_ratios(reuse_critical_points(), known_T, check_fn,
459 | LAYER, COUNT))
460 |
461 | print("Query count after that search:", query_count)
462 | print("And now up to ", len(this_layer_critical_points), "critical points")
463 |
464 | ## filter out duplicates
465 | filtered_points = []
466 |
467 |         # Let's not add points that are identical to ones we've already done.
468 | for i,(ratio1,point1) in enumerate(this_layer_critical_points):
469 | for ratio2,point2 in this_layer_critical_points[i+1:]:
470 | if np.sum((point1 - point2)**2)**.5 < 1e-10:
471 | break
472 | else:
473 | filtered_points.append((ratio1, point1))
474 |
475 | this_layer_critical_points = filtered_points
476 |
477 | print("After filtering duplicates we're down to ", len(this_layer_critical_points), "critical points")
478 |
479 |
480 | print("Start trying to do the graph solving")
481 | try:
482 | critical_groups, extracted_normals = graph_solve([x[0] for x in this_layer_critical_points],
483 | [x[1] for x in this_layer_critical_points],
484 | neuron_count[LAYER+1],
485 | LAYER=LAYER,
486 | debug=True)
487 | break
488 | except GatherMoreData as e:
489 | print("Graph solving failed because we didn't explore all sides of at least one neuron")
490 | print("Fall back to the hyperplane following algorithm in order to get more data")
491 |
492 | def mine(r):
493 | while len(r) > 0:
494 | print("Yielding a point")
495 | yield r[0]
496 | r = r[1:]
497 | print("No more to give!")
498 |
499 | prev_T = KnownT(known_T.A[:-1], known_T.B[:-1])
500 |
501 | _, more_critical_points = sign_recovery.solve_layer_sign(prev_T, known_T.A[-1], known_T.B[-1], mine(e.data),
502 | LAYER-1, already_checked_critical_points=True,
503 | only_need_positive=True)
504 |
505 | print("Add more", len(more_critical_points))
506 | this_layer_critical_points.extend(gather_ratios(more_critical_points, known_T, check_fn,
507 | LAYER, 1e6))
508 | print("Done adding")
509 |
510 | COUNT = neuron_count[LAYER+1]
511 | except AcceptableFailure as e:
512 | print("Graph solving failed; get more points")
513 | COUNT = neuron_count[LAYER+1]
514 | if 'partial_solution' in dir(e):
515 |
516 | if len(e.partial_solution[0]) > 0:
517 | partial_weights, corresponding_examples = e.partial_solution
518 | print("Got partial solution with shape", partial_weights.shape)
519 | if CHEATING:
520 | print("Corresponding to", np.argmin(np.abs(cheat_get_inner_layers([x[0] for x in corresponding_examples])[LAYER]),axis=1))
521 |
522 | partial_biases = []
523 | for weight, examples in zip(partial_weights, corresponding_examples):
524 |
525 | hidden = known_T.forward(examples, with_relu=True)
526 | print("hidden", np.array(hidden).shape)
527 | bias = -np.median(np.dot(hidden, weight))
528 | partial_biases.append(bias)
529 | partial_biases = np.array(partial_biases)
530 |
531 |
532 | print("Number of critical points per cluster", [len(x) for x in critical_groups])
533 |
534 | point_per_class = [x[0] for x in critical_groups]
535 |
536 | extracted_normals = np.array(extracted_normals).T
537 |
538 | # Compute the bias because we know wx+b=0
539 | extracted_bias = [matmul(known_T.forward(point_per_class[i], with_relu=True), extracted_normals[:,i], c=None) for i in range(neuron_count[LAYER+1])]
540 |
541 | # Don't forget to negate it.
542 | # That's important.
543 | # No, I definitely didn't forget this line the first time around.
544 | extracted_bias = -np.array(extracted_bias)
545 |
546 | # For the failed-to-identify neurons, set the bias to zero
547 | extracted_bias *= np.any(extracted_normals != 0,axis=0)[:,np.newaxis]
548 |
549 |
550 | if CHEATING:
551 |         # Compute how far off we are from the true matrix
552 | real_scaled = A[LAYER]/A[LAYER][0]
553 | extracted_scaled = extracted_normals/extracted_normals[0]
554 |
555 | mask = []
556 | reorder_rows = []
557 | for i in range(len(extracted_bias)):
558 | which_idx = np.argmin(np.sum(np.abs(real_scaled - extracted_scaled[:,[i]]),axis=0))
559 | reorder_rows.append(which_idx)
560 | mask.append((A[LAYER][0,which_idx]))
561 |
562 | print('matrix norm difference', np.sum(np.abs(extracted_normals*mask - A[LAYER][:,reorder_rows])))
563 | else:
564 | mask = [1]*len(extracted_bias)
565 |
566 |
567 | return extracted_normals, extracted_bias, mask
568 |
--------------------------------------------------------------------------------
/src/refine_precision.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | import jax
18 | import jax.experimental.optimizers
19 | import jax.numpy as jnp
20 |
21 | from src.global_vars import *
22 | from src.utils import matmul, which_is_zero
23 | from src.find_witnesses import do_better_sweep
24 |
25 | def trim(hidden_layer, out, num_good):
26 | """
27 | Compute least squares in a robust-statistics manner.
28 | See Jagielski et al. 2018 S&P
29 | """
30 | lst, *rest = np.linalg.lstsq(hidden_layer, out)
31 | old = lst
32 | for _ in range(20):
33 | errs = np.abs(np.dot(hidden_layer, lst) - out)
34 | best_errs = np.argsort(errs)[:num_good]
35 | lst, *rest = np.linalg.lstsq(hidden_layer[best_errs], out[best_errs])
36 | if np.linalg.norm(old-lst) < 1e-9:
37 | return lst, best_errs
38 | old = lst
39 | return lst, best_errs
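40 | # For example (hypothetical data): if 80 of 100 rows fit one linear map,
41 | # trim(X, y, 80) repeatedly re-fits on the 80 rows with the smallest
42 | # residuals, so the 20 outlier rows cannot skew the recovered solution.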
40 |
41 | def improve_row_precision(args):
42 | """
43 | Improve the precision of an extracted row.
44 | We think we know where it is, but let's actually figure it out for sure.
45 |
46 | To do this, start by sampling a bunch of points near where we expect the line to be.
47 | This gives us a picture like this
48 |
49 | X
50 | X
51 |
52 | X
53 | X
54 | X
55 | X
56 |
57 | Where some are correct and some are wrong.
58 | With some robust statistics, try to fit a line that fits through most of the points
59 | (in high dimension!)
60 |
61 | X
62 | / X
63 | /
64 | X
65 | X /
66 | /
67 | X
68 |
69 | This solves the equation and improves the point for us.
70 | """
71 | (LAYER, known_T, known_A, known_B, row, did_again) = args
72 | print("Improve the extracted neuron number", row)
73 |
74 | print(np.sum(np.abs(known_A[:,row])))
75 | if np.sum(np.abs(known_A[:,row])) < 1e-8:
76 | return known_A[:,row], known_B[row]
77 |
78 |
79 | def loss(x, r):
80 | hidden = known_T.forward(x, with_relu=True, np=jnp)
81 | dotted = matmul(hidden, jnp.array(known_A)[:,r], jnp.array(known_B)[r], np=jnp)
82 |
83 | return jnp.sum(jnp.square(dotted))
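84 |     # Minimizing this loss drives each point onto the hyperplane where
85 |     # neuron `r` of the next layer has pre-activation exactly zero.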
84 |
85 | loss_grad = jax.jit(jax.grad(loss))
86 | loss = jax.jit(loss)
87 |
88 | extended_T = known_T.extend_by(known_A, known_B)
89 |
90 | def get_more_points(NUM):
91 | """
92 | Gather more points. This procedure is really kind of ugly and should probably be fixed.
93 | We want to find points that are near where we expect them to be.
94 |
95 | So begin by finding preimages to points that are on the line with gradient descent.
96 | This should be completely possible, because we have d_0 input dimensions but
97 | only want to control one inner layer.
98 | """
99 | print("Gather some more actual critical points on the plane")
100 | stepsize = .1
101 | critical_points = []
102 | while len(critical_points) <= NUM:
103 | print("On this iteration I have ", len(critical_points), "critical points on the plane")
104 | points = np.random.normal(0, 1e3, size=(100,DIM,))
105 |
106 | lr = 10
107 | for step in range(5000):
108 | # Use JaX's built in optimizer to do this.
109 | # We want to adjust the LR so that we get a better solution
110 | # as we optimize. Probably there is a better way to do this,
111 | # but this seems to work just fine.
112 |
113 |                 # No queries involved here.
114 | if step%1000 == 0:
115 | lr *= .5
116 | init, opt_update, get_params = jax.experimental.optimizers.adam(lr)
117 |
118 | @jax.jit
119 | def update(i, opt_state, batch):
120 | params = get_params(opt_state)
121 | return opt_update(i, loss_grad(batch, row), opt_state)
122 | opt_state = init(points)
123 |
124 | if step%100 == 0:
125 | ell = loss(points, row)
126 | if CHEATING:
127 | # This isn't cheating, but makes things prettier
128 | print(ell)
129 | if ell < 1e-5:
130 | break
131 | opt_state = update(step, opt_state, points)
132 | points = opt_state.packed_state[0][0]
133 |
134 | for point in points:
135 | # For each point, try to see where it actually is.
136 |
137 | # First, if optimization failed, then abort.
138 | if loss(point, row) > 1e-5:
139 | continue
140 |
141 | if LAYER > 0:
142 |                     # If we're on a deeper layer, and if a prior layer is zero, then abort
143 | if min(np.min(np.abs(x)) for x in known_T.get_hidden_layers(point)) < 1e-4:
144 | print("is on prior")
145 | continue
146 |
147 |
148 | #print("Stepsize", stepsize)
149 | tmp = query_count
150 | solution = do_better_sweep(offset=point,
151 | low=-stepsize,
152 | high=stepsize,
153 | known_T=known_T)
154 | #print("qs", query_count-tmp)
155 | if len(solution) == 0:
156 | stepsize *= 1.1
157 | elif len(solution) > 1:
158 | stepsize /= 2
159 | elif len(solution) == 1:
160 | stepsize *= 0.98
161 | potential_solution = solution[0]
162 |
163 | hiddens = extended_T.get_hidden_layers(potential_solution)
164 |
165 |
166 | this_hidden_vec = extended_T.forward(potential_solution)
167 | this_hidden = np.min(np.abs(this_hidden_vec))
168 | if min(np.min(np.abs(x)) for x in this_hidden_vec) > np.abs(this_hidden)*0.9:
169 | critical_points.append(potential_solution)
170 | else:
171 | print("Reject it")
172 | print("Finished with a total of", len(critical_points), "critical points")
173 | return critical_points
174 |
175 |
176 | critical_points_list = []
177 | for _ in range(1):
178 | NUM = sizes[LAYER]*2
179 | critical_points_list.extend(get_more_points(NUM))
180 |
181 | critical_points = np.array(critical_points_list)
182 |
183 | hidden_layer = known_T.forward(np.array(critical_points), with_relu=True)
184 |
185 | if CHEATING:
186 | out = np.abs(matmul(hidden_layer, A[LAYER],B[LAYER]))
187 | which_neuron = int(np.median(which_is_zero(0, [out])))
188 | print("NEURON NUM", which_neuron)
189 |
190 | crit_val_0 = out[:,which_neuron]
191 |
192 | print(crit_val_0)
193 |
194 | #print(list(np.sort(np.abs(crit_val_0))))
195 | print('probability ok',np.mean(np.abs(crit_val_0)<1e-8))
196 |
197 | crit_val_1 = matmul(hidden_layer, known_A[:,row], known_B[row])
198 |
199 | best = (None, 1e6)
200 | upto = 100
201 |
202 | for iteration in range(upto):
203 | if iteration%1000 == 0:
204 | print("ITERATION", iteration, "OF", upto)
205 | if iteration%2 == 0 or True:
206 |
207 | # Try 1000 times to make sure that we get at least one non-zero per axis
208 | for _ in range(1000):
209 | randn = np.random.choice(len(hidden_layer), NUM+2, replace=False)
210 | if np.all(np.any(hidden_layer[randn] != 0, axis=0)):
211 | break
212 |
213 | hidden = hidden_layer[randn]
214 | soln,*rest = np.linalg.lstsq(hidden, np.ones(hidden.shape[0]))
215 |
216 |
217 | else:
218 | randn = np.random.choice(len(hidden_layer), min(len(hidden_layer),hidden_layer.shape[1]+20), replace=False)
219 | soln,_ = trim(hidden_layer[randn], np.ones(hidden_layer.shape[0])[randn], hidden_layer.shape[1])
220 |
221 |
222 | crit_val_2 = matmul(hidden_layer, soln, None)-1
223 |
224 | quality = np.median(np.abs(crit_val_2))
225 |
226 | if iteration%100 == 0:
227 | print('quality', quality, best[1])
228 |
229 | if quality < best[1]:
230 | best = (soln, quality)
231 |
232 | if quality < 1e-10: break
233 | if quality < 1e-10 and iteration > 1e4: break
234 | if quality < 1e-8 and iteration > 1e5: break
235 |
236 | soln, _ = best
237 |
238 | if CHEATING:
239 | print("Compare", np.median(np.abs(crit_val_0)))
240 | print("Compare",
241 | np.median(np.abs(crit_val_1)),
242 | best[1])
243 |
244 | if np.all(np.abs(soln) > 1e-10):
245 | break
246 |
247 | print('soln',soln)
248 |
249 | if np.any(np.abs(soln) < 1e-10):
250 | print("THIS IS BAD. FIX ME NOW.")
251 | exit(1)
252 |
253 | rescale = np.median(soln/known_A[:,row])
254 | soln[np.abs(soln) < 1e-10] = known_A[:,row][np.abs(soln) < 1e-10] * rescale
255 |
256 | if CHEATING:
257 | other = A[LAYER][:,which_neuron]
258 | print("real / mine / diff")
259 | print(other/other[0])
260 | print(soln/soln[0])
261 | print(known_A[:,row]/known_A[:,row][0])
262 | print(other/other[0] - soln/soln[0])
263 |
264 |
265 | if best[1] < np.mean(np.abs(crit_val_1)) or True:
266 | return soln, -1
267 | else:
268 | print("FAILED TO IMPROVE ACCURACY OF ROW", row)
269 | print(np.mean(np.abs(crit_val_2)), 'vs', np.mean(np.abs(crit_val_1)))
270 | return known_A[:,row], known_B[row]
271 |
272 |
273 | def improve_layer_precision(LAYER, known_T, known_A, known_B):
274 | new_A = []
275 | new_B = []
276 |
277 | out = map(improve_row_precision,
278 | [(LAYER, known_T, known_A, known_B, row, False) for row in range(neuron_count[LAYER+1])])
279 | new_A, new_B = zip(*out)
280 |
281 | new_A = np.array(new_A).T
282 | new_B = np.array(new_B)
283 |
284 | print("HAVE", new_A, new_B)
285 |
286 | return new_A, new_B
287 |
--------------------------------------------------------------------------------
/src/sign_recovery.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import jax
17 | import jax.numpy as jnp
18 | import scipy.linalg
19 | import scipy.optimize, scipy.signal
20 | import time
21 |
22 | from src.global_vars import *
23 | from src.utils import run, get_polytope_at, get_hidden_at, AcceptableFailure, KnownT, matmul, cheat_get_inner_layers, which_is_zero
24 | from src.hyperplane_normal import get_ratios_lstsq, get_ratios
25 | from src.find_witnesses import do_better_sweep
26 |
27 |
28 | def sign_to_int(signs):
29 | """
30 | Convert a list to an integer.
31 | [-1, 1, 1, -1], -> 0b0110 -> 6
32 | """
33 | return int("".join('0' if x == -1 else '1' for x in signs),2)
34 |
35 | def is_on_following_layer(known_T, known_A, known_B, point):
36 |
37 | print("Check if the critical point is on the next layer")
38 |
39 | def is_on_prior_layer(query):
40 | print("Hidden think", known_T.get_hidden_layers(query))
41 | if CHEATING:
42 | print("Hidden real", cheat_get_inner_layers(query))
43 | if any(np.min(np.abs(layer)) < 1e-5 for layer in known_T.get_hidden_layers(query)):
44 | return True
45 | next_hidden = known_T.extend_by(known_A, known_B).forward(query)
46 | print(next_hidden)
47 | if np.min(np.abs(next_hidden)) < 1e-4:
48 | return True
49 | return False
50 |
51 | if is_on_prior_layer(point):
52 | print("It's not, because it's on an earlier layer")
53 | return False
54 |
55 | if CHEATING:
56 | ls = ([np.min(np.abs(x)) for x in cheat_get_inner_layers(point)])
57 |
58 | initial_signs = get_polytope_at(known_T, known_A, known_B, point)
59 |
60 | normal = get_ratios([point], [range(DIM)], eps=GRAD_EPS)[0].flatten()
61 | normal = normal / np.sum(normal**2)**.5
62 |
63 | for tol in range(10):
64 |
65 | random_dir = np.random.normal(size=DIM)
66 | perp_component = np.dot(random_dir,normal)/(np.dot(normal, normal)) * normal
67 | parallel_dir = random_dir - perp_component
68 |
69 | go_direction = parallel_dir/np.sum(parallel_dir**2)**.5
70 |
71 | _, high = binary_search_towards(known_T,
72 | known_A, known_B,
73 | point,
74 | initial_signs,
75 | go_direction)
76 |
77 | if CHEATING:
78 | print(cheat_get_inner_layers(point + go_direction * high/2)[np.argmin(ls)])
79 |
80 | point_in_same_polytope = point + (high * .999 - 1e-4) * go_direction
81 |
82 | print("high", high)
83 |
84 | solutions = do_better_sweep(point_in_same_polytope,
85 | normal,
86 | -1e-4 * high, 1e-4 * high,
87 | known_T=known_T)
88 | if len(solutions) >= 1:
89 | print("Correctly found", len(solutions))
90 | else:
91 | return False
92 |
93 | point_in_different_polytope = point + (high * 1.1 + 1e-1) * go_direction
94 |
95 | solutions = do_better_sweep(point_in_different_polytope,
96 | normal,
97 | -1e-4 * high, 1e-4 * high,
98 | known_T=known_T)
99 | if len(solutions) == 0:
100 | print("Correctly found", len(solutions))
101 | else:
102 | return False
103 |
104 |
105 | #print("I THINK IT'S ON THE NEXT LAYER")
106 | if CHEATING:
107 | soln = [np.min(np.abs(x)) for x in cheat_get_inner_layers(point)]
108 | print(soln)
109 | assert np.argmin(soln) == len(known_T.A)+1
110 |
111 | return True
112 |
113 | def find_plane_angle(known_T,
114 | known_A, known_B,
115 | multiple_intersection_point,
116 | sign_at_init,
117 | init_step,
118 | exponential_base=1.5):
119 | """
120 | Given an input that's at the multiple intersection point, figure out how
121 | to continue along the path after it bends.
122 |
123 |
124 | / X : multiple intersection point
125 | ......../.. ---- : layer N hyperplane
126 | . / . | : layer N+1 hyperplane that bends
127 | . / .
128 | --------X-----------
129 | . | .
130 | . | .
131 | .....|.....
132 | |
133 | |
134 |
135 | We need to make sure to bend, and not turn onto the layer N hyperplane.
136 |
137 | To do this we will draw a box around the X and intersect with the planes
138 | and determine the four coordinates. Then draw another box twice as big.
139 |
140 | The first layer plane will be the two points at a consistent angle.
141 | The second layer plane will have an inconsistent angle.
142 |
143 | Choose the inconsistent angle plane, and make sure we move to a new
144 |     polytope and don't just go backwards to where we've already been.
145 | """
146 | success = None
147 | camefrom = None
148 |
149 | prev_iter_intersections = []
150 |
151 | while True:
152 | x_dir_base = np.sign(np.random.normal(size=DIM))/DIM**.5
153 | y_dir_base = np.sign(np.random.normal(size=DIM))/DIM**.5
154 | # When the input dimension is odd we can't have two orthogonal
155 | # vectors from {-1,1}^DIM
156 | if np.abs(np.dot(x_dir_base, y_dir_base)) <= DIM%2 + 1e-8:
157 | break
158 |
159 | MAX = 35
160 |
161 | start = [10] if init_step > 10 else []
162 | for stepsize in start + list(range(init_step, MAX)):
163 | print("\tTry stepping away", stepsize)
164 | x_dir = x_dir_base * (exponential_base**(stepsize-10))
165 | y_dir = y_dir_base * (exponential_base**(stepsize-10))
166 |
167 | # Draw the box as shown in the diagram above, and compute where
168 | # the critical points are.
169 | top = do_better_sweep(multiple_intersection_point + x_dir,
170 | y_dir, -1, 1,
171 | known_T=known_T)
172 | bot = do_better_sweep(multiple_intersection_point - x_dir,
173 | y_dir, -1, 1,
174 | known_T=known_T)
175 | left = do_better_sweep(multiple_intersection_point + y_dir,
176 | x_dir, -1, 1,
177 | known_T=known_T)
178 | right = do_better_sweep(multiple_intersection_point - y_dir,
179 | x_dir, -1, 1,
180 | known_T=known_T)
181 |
182 | intersections = top + bot + left + right
183 |
184 | # If we only have two critical points, and we're taking a big step,
185 | # then something is seriously messed up.
186 | # This is not an acceptable error. Just abort out and let's try to
187 | # do the whole thing again.
188 | if len(intersections) == 2 and stepsize >= 10:
189 | raise AcceptableFailure()
190 |
191 | if CHEATING:
192 | print("\tHAVE BOX INTERSECT COUNT", len(intersections))
193 | print("\t",len(left), len(right), len(top), len(bot))
194 |
195 | if (len(intersections) == 0 and stepsize > 15):# or (len(intersections) == 3 and stepsize > 5):
196 | # Probably we're in just a constant flat 0 region
197 | # At this point we're basically dead in the water.
198 | # Just fail up and try again.
199 | print("\tIt looks like we're in a flat region, raise failure")
200 | raise AcceptableFailure()
201 |
202 | # If we somehow went from almost no critical points to more than 4,
203 | # then we've really messed up.
204 | # Just fail out and let's hope next time it doesn't happen.
205 | if len(intersections) > 4 and len(prev_iter_intersections) < 2:
206 | print("\tWe didn't get enough inner points")
207 | if exponential_base == 1.2:
208 | print("\tIt didn't work a second time")
209 | return None, None, 0
210 | else:
211 | print("\tTry with smaller step")
212 | return find_plane_angle(known_T,
213 | known_A, known_B,
214 | multiple_intersection_point,
215 | sign_at_init,
216 | init_step,
217 | exponential_base=1.2)
218 |
219 | # This is the good, expected code-path.
220 | # We've seen four intersections at least twice before, and now
221 | # we're seeing more than 4.
222 | if (len(intersections) > 4 or stepsize > 20) and len(prev_iter_intersections) >= 2:
223 | next_intersections = np.array(prev_iter_intersections[-1])
224 | intersections = np.array(prev_iter_intersections[-2])
225 |
226 | # Let's first figure out what points are responsible for the prior-layer neurons
227 | # being zero, and which are from the current-layer neuron being zero
228 | candidate = []
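229 |             # For each pair (a, b) of the small box's intersections, test
230 |             # whether the big box's intersections lie on the line through a
231 |             # and b: the per-coordinate spread of (next - a) / direction is
232 |             # zero exactly when they do, marking the straight prior-layer plane.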
229 | for i,a in enumerate(intersections):
230 | for j,b in enumerate(intersections):
231 | if i == j: continue
232 | score = np.sum(((a+b)/2-multiple_intersection_point)**2)
233 | a_to_b = b-a
234 | a_to_b /= np.sum(a_to_b**2)**.5
235 |
236 | variance = np.std((next_intersections-a)/a_to_b,axis=1)
237 | best_variance = np.min(variance)
238 |
239 | #print(i,j,score, best_variance)
240 |
241 | candidate.append((best_variance, i, j))
242 |
243 | if sorted(candidate)[3][0] < 1e-8:
244 | # It looks like both lines are linear here
245 | # We can't distinguish what way is the next best way to go.
246 | print("\tFailed the box continuation finding procedure. (1)")
247 | print("\t",candidate)
248 | raise AcceptableFailure()
249 |
250 | # Sometimes life is just ugly, and nothing wants to work.
251 | # Just abort.
252 | err, index_0, index_1 = min(candidate)
253 | if err/max(candidate)[0] > 1e-5:
254 | return None, None, 0
255 |
256 |             prior_layer_near_zero = np.zeros(4, dtype=bool)
257 | prior_layer_near_zero[index_0] = True
258 | prior_layer_near_zero[index_1] = True
259 |
260 | # Now let's walk through each of these points and check that everything looks sane.
261 | should_fail = False
262 | for critical_point, is_prior_layer_zero in zip(intersections,prior_layer_near_zero):
263 | vs = known_T.extend_by(known_A,known_B).get_hidden_layers(critical_point)
264 | #print("VS IS", vs)
265 | #print("Is prior", is_prior_layer_zero)
266 | #if CHEATING:
267 | # print(cheat_get_inner_layers(critical_point))
268 |
269 | if is_prior_layer_zero:
270 | # We expect the prior layer to be zero.
271 | if all([np.min(np.abs(x)) > 1e-5 for x in vs]):
272 | # If it looks like it's not actually zero, then brutally fail.
273 | print("\tAbort 1: failed to find a valid box")
274 | should_fail = True
275 | if any([np.min(np.abs(x)) < 1e-10 for x in vs]):
276 |                     # The prior layer actually is (nearly) zero here.
277 |                     if not is_prior_layer_zero:
278 |                         # If we didn't expect it to be zero, then brutally fail.
279 | print("\tAbort 2: failed to find a valid box")
280 | should_fail = True
281 | if should_fail:
282 | return None, None, 0
283 |
284 |
285 |
286 | # Done with error checking, life is good here.
287 | # Find the direction that corresponds to the next direction we can move in
288 | # and continue our search from that point.
289 | for critical_point, is_prior_layer_zero in zip(intersections,prior_layer_near_zero):
290 | sign_at_crit = sign_to_int(get_polytope_at(known_T,
291 | known_A, known_B,
292 | critical_point))
293 | print("\tMove to", sign_at_crit, 'versus', sign_at_init, is_prior_layer_zero)
294 | if not is_prior_layer_zero:
295 | if sign_at_crit != sign_at_init:
296 | success = critical_point
297 | if CHEATING:
298 | print('\tinner at success', cheat_get_inner_layers(success))
299 | print("\tSucceeded")
300 | else:
301 | camefrom = critical_point
302 |
303 | # If we didn't get a solution, then abort out.
304 | # Probably what happened here is that we got more than four points
305 | # on the box but didn't see exactly four points on the box twice before
306 | # this means we should decrease the initial step size and try again.
307 | if success is None:
308 | print("\tFailed the box continuation finding procedure. (2)")
309 | raise AcceptableFailure()
310 | #assert success is not None
311 | break
312 | if len(intersections) == 4:
313 | prev_iter_intersections.append(intersections)
314 | return success, camefrom, min(stepsize, MAX-3)
315 |
316 | def binary_search_towards_slow(known_T, known_A, known_B, start_point, initial_signs, go_direction, maxstep=1e6):
317 | low = 0
318 | high = maxstep
319 | while high-low > 1e-8:
320 | mid = (high+low)/2
321 | query_point = start_point + mid * go_direction
322 |
323 | next_signs = get_polytope_at(known_T, known_A, known_B,
324 | query_point)
325 | if initial_signs == next_signs:
326 | low = mid
327 | else:
328 | high = mid
329 |
330 |
331 | #print('check',np.abs(mid - can_go_dist))
332 |
333 | next_signs = get_polytope_at(known_T, known_A, known_B,
334 | start_point + low * go_direction)
335 | if next_signs != initial_signs:
336 | # It is extremely unlikely, but possible, for us to end up
337 | # skipping over the region of interest.
338 | # If this happens then don't step as far and try again.
339 |         # This has only ever happened once, but just in case....
340 | print("Well this is awkward")
341 | return binary_search_towards(known_T, known_A, known_B, start_point, initial_signs, go_direction, maxstep=maxstep/10)
342 |
343 |
344 | # If mid is at the end, it means it never binary searched.
345 | if mid > 1e6-1:
346 | return None, None
347 | else:
348 | a_bit_further = start_point + (high+1e-4)*go_direction
349 | return a_bit_further, high
350 |
351 |
352 | PREV_GRAD = None
353 |
354 | def binary_search_towards(known_T, known_A, known_B, start_point, initial_signs, go_direction, maxstep=1e6):
355 | """
356 | Compute how far we can walk along the hyperplane until it is in a
357 | different polytope from a prior layer.
358 |
359 |     It is okay if it's in a different polytope in a *later* layer, because
360 | it will still have the same angle.
361 |
362 | (but do it analytically by looking at the signs of the first layer)
363 | this requires no queries and could be done with math but instead
364 | of thinking I'm just going to run binary search.
365 | """
366 | global PREV_GRAD
367 |
368 | #_, slow_ans = binary_search_towards_slow(known_T, known_A, known_B, start_point, initial_signs, go_direction, maxstep)
369 |
370 | plus_T = known_T.extend_by(known_A, known_B)
371 | # this is the hidden state
372 | initial_hidden = np.array(plus_T.get_hidden_layers(start_point, flat=True))
373 | delta_hidden_np = (np.array(plus_T.get_hidden_layers(start_point + 1e-6 * go_direction, flat=True)) - initial_hidden) * 1e6
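374 |     # (Finite-difference estimate of the same rates; the jitted gradient
375 |     # path below computes them exactly and is what actually gets used.)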
374 | #
375 | #can_go_dist_all = initial_hidden / delta_hidden
376 |
377 | if PREV_GRAD is None or PREV_GRAD[0] is not known_T or PREV_GRAD[1] is not known_A or PREV_GRAD[2] is not known_B:
378 | def get_grad(x, i):
379 | initial_hidden = plus_T.get_hidden_layers(x, flat=True, np=jnp)
380 | return initial_hidden[i]
381 | g = jax.jit(jax.grad(get_grad))
382 |
383 |
384 | def grads(start_point, go_direction):
385 | return jnp.array([jnp.dot(g(start_point, i), go_direction) for i in range(initial_hidden.shape[0])])
386 |
387 | PREV_GRAD = (known_T, known_A, known_B, jax.jit(grads))
388 | else:
389 | grads = PREV_GRAD[3]
390 |
391 | delta_hidden = grads(start_point, go_direction)
392 |
393 | can_go_dist_all = np.array(initial_hidden / delta_hidden)
394 |
395 | can_go_dist = -can_go_dist_all[can_go_dist_all<0]
396 |
397 | if len(can_go_dist) == 0:
398 | print("Can't go anywhere at all")
399 | raise AcceptableFailure()
400 |
401 | can_go_dist = np.min(can_go_dist)
402 |
403 | a_bit_further = start_point + (can_go_dist+1e-4)*go_direction
404 | return a_bit_further, can_go_dist
405 |
406 | def follow_hyperplane(LAYER, start_point, known_T, known_A, known_B,
407 | history=[], MAX_POINTS=1e3, only_need_positive=False):
408 | """
409 | This is the ugly algorithm that will let us recover sign for expansive networks.
410 | Assumes we have extracted up to layer K-1 correctly, and layer K up to sign.
411 |
412 | start_point is a neuron on layer K+1
413 |
414 | known_T is the transformation that computes up to layer K-1, with
415 | known_A and known_B being the layer K matrix up to sign.
416 |
417 | We're going to come up with a bunch of different inputs,
418 | each of which has the same critical point held constant at zero.
419 | """
420 |
421 | def choose_new_direction_from_minimize(previous_axis):
422 | """
423 | Given the current point which is at a critical point of the next
424 | layer neuron, compute which direction we should travel to continue
425 | with finding more points on this hyperplane.
426 |
427 | Our goal is going to be to pick a direction that lets us explore
428 | a new part of the space we haven't seen before.
429 | """
430 |
431 | print("Choose a new direction to travel in")
432 | if len(history) == 0:
433 | which_to_change = 0
434 | new_perp_dir = perp_dir
435 | new_start_point = start_point
436 | initial_signs = get_polytope_at(known_T, known_A, known_B, start_point)
437 |
438 | # If we're in the 1 region of the polytope then we try to make it smaller
439 | # otherwise make it bigger
440 | fn = min if initial_signs[0] == 1 else max
441 | else:
442 | neuron_values = np.array([x[1] for x in history])
443 |
444 | neuron_positive_count = np.sum(neuron_values>1,axis=0)
445 | neuron_negative_count = np.sum(neuron_values<-1,axis=0)
446 |
447 | mean_plus_neuron_value = neuron_positive_count/(neuron_positive_count + neuron_negative_count + 1)
448 | mean_minus_neuron_value = neuron_negative_count/(neuron_positive_count + neuron_negative_count + 1)
449 |
450 | # we want to find values that are consistently 0 or 1
451 | # So map 0 -> 0 and 1 -> 0 and the middle to higher values
452 | if only_need_positive:
453 | neuron_consistency = mean_plus_neuron_value
454 | else:
455 | neuron_consistency = mean_plus_neuron_value * mean_minus_neuron_value
456 |
457 | # Print out how much progress we've made.
458 | # This estimate is probably worse than Windows 95's estimated time remaining.
459 | # At least it's monotonic. Be thankful for that.
460 | print("Progress", "%.1f"%int(np.mean(neuron_consistency!=0)*100)+"%")
461 | print("Counts on each side of each neuron")
462 | print(neuron_positive_count)
463 | print(neuron_negative_count)
464 |
465 |
466 | # Choose the smallest value, which is the most consistent
467 | which_to_change = np.argmin(neuron_consistency)
468 |
469 | print("Try to explore the other side of neuron", which_to_change)
470 |
471 | if which_to_change != previous_axis:
472 | if previous_axis is not None and neuron_consistency[previous_axis] == neuron_consistency[which_to_change]:
473 | # If the previous thing we were working towards has the same value as this one
474 |                     # then don't change our mind and just keep going at that one
475 | # (almost always--sometimes we can get stuck, let us get unstuck)
476 | which_to_change = previous_axis
477 | new_start_point = start_point
478 | new_perp_dir = perp_dir
479 | else:
480 | valid_axes = np.where(neuron_consistency == neuron_consistency[which_to_change])[0]
481 |
482 | best = (np.inf, None, None)
483 |
484 | for _, potential_hidden_vector, potential_point in history[-1:]:
485 | for potential_axis in valid_axes:
486 | value = potential_hidden_vector[potential_axis]
487 | if np.abs(value) < best[0]:
488 | best = (np.abs(value), potential_axis, potential_point)
489 |
490 | _, which_to_change, new_start_point = best
491 | new_perp_dir = perp_dir
492 |
493 | else:
494 | new_start_point = start_point
495 | new_perp_dir = perp_dir
496 |
497 |
498 | # If we're in the 1 region of the polytope then we try to make it smaller
499 | # otherwise make it bigger
500 | fn = min if neuron_positive_count[which_to_change] > neuron_negative_count[which_to_change] else max
501 | arg_fn = np.argmin if neuron_positive_count[which_to_change] > neuron_negative_count[which_to_change] else np.argmax
502 | print("Changing", which_to_change, 'to flip sides because mean is', mean_plus_neuron_value[which_to_change])
503 |
504 |
505 | val = matmul(known_T.forward(new_start_point, with_relu=True), known_A, known_B)[which_to_change]
506 |
507 | initial_signs = get_polytope_at(known_T, known_A, known_B, new_start_point)
508 |
509 | # Now we're going to figure out what direction makes this biggest/smallest
510 | # this doesn't take any queries
511 | # There's probably an analytical way to do this.
512 | # But thinking is hard. Just try 1000 random angles.
513 | # There are no queries involved in this process.
514 |
515 | choices = []
516 | for _ in range(1000):
517 | random_dir = np.random.normal(size=DIM)
518 | perp_component = np.dot(random_dir,new_perp_dir)/(np.dot(new_perp_dir, new_perp_dir)) * new_perp_dir
519 | parallel_dir = random_dir - perp_component
520 |
521 | # This is the direction we're going to travel in.
522 | go_direction = parallel_dir/np.sum(parallel_dir**2)**.5
523 |
524 | try:
525 | a_bit_further, high = binary_search_towards(known_T,
526 | known_A, known_B,
527 | new_start_point,
528 | initial_signs,
529 | go_direction)
530 | except AcceptableFailure:
531 | continue
532 | if a_bit_further is None:
533 | continue
534 |
535 | # choose a direction that makes the Kth value go down by the most
536 | val = matmul(known_T.forward(a_bit_further[np.newaxis,:], with_relu=True), known_A, known_B)[0][which_to_change]
537 |
538 | #print('\t', val, high)
539 |
540 | choices.append([val,
541 | new_start_point + high*go_direction])
542 |
543 |
544 | best_value, multiple_intersection_point = fn(choices, key=lambda x: x[0])
545 |
546 | print('Value', best_value)
547 | return new_start_point, multiple_intersection_point, which_to_change
548 |
549 | ###################################################
550 | ### Actual code to do the sign recovery starts. ###
551 | ###################################################
552 |
553 | start_box_step = 0
554 | points_on_plane = []
555 |
556 | if CHEATING:
557 | layer = np.abs(cheat_get_inner_layers(np.array(start_point))[LAYER+1])
558 | print("Layer", layer)
559 | which_is_zero = np.argmin(layer)
560 |
561 | current_change_axis = 0
562 |
563 | while True:
564 | print("\n\n")
565 | print("-----"*10)
566 |
567 | if CHEATING:
568 | layer = np.abs(cheat_get_inner_layers(np.array(start_point))[LAYER+1])
569 | #print('layer',LAYER+1, layer)
570 | #print('all inner layers')
571 | #for e in cheat_get_inner_layers(np.array(start_point)):
572 | # print(e)
573 | which_is_zero_2 = np.argmin(np.abs(layer))
574 |
575 | if which_is_zero_2 != which_is_zero:
576 | print("STARTED WITH", which_is_zero, "NOW IS", which_is_zero_2)
577 | print(layer)
578 | raise
579 |
580 | # Keep track of where we've been, so we can go to new places.
581 | which_polytope = get_polytope_at(known_T, known_A, known_B, start_point, False) # [-1 1 -1]
582 | hidden_vector = get_hidden_at(known_T, known_A, known_B, LAYER, start_point, False)
583 | sign_at_init = sign_to_int(which_polytope) # 0b010 -> 2
584 |
585 | print("Number of collected points", len(points_on_plane))
586 | if len(points_on_plane) > MAX_POINTS:
587 | return points_on_plane, False
588 |
589 | neuron_values = np.array([x[1] for x in history])
590 |
591 | neuron_positive_count = np.sum(neuron_values>1,axis=0)
592 | neuron_negative_count = np.sum(neuron_values<-1,axis=0)
593 |
594 | if (np.all(neuron_positive_count > 0) and np.all(neuron_negative_count > 0)) or \
595 | (only_need_positive and np.all(neuron_positive_count > 0)):
596 | print("Have all the points we need (1)")
597 | print(query_count)
598 | print(neuron_positive_count)
599 | print(neuron_negative_count)
600 |
601 | neuron_values = np.array([get_hidden_at(known_T, known_A, known_B, LAYER, x, False) for x in points_on_plane])
602 |
603 | neuron_positive_count = np.sum(neuron_values>1,axis=0)
604 | neuron_negative_count = np.sum(neuron_values<-1,axis=0)
605 |
606 | print(neuron_positive_count)
607 | print(neuron_negative_count)
608 |
609 | return points_on_plane, True
610 |
611 | # 1. find a way to move along the hyperplane by computing the normal
612 | # direction using the ratios function. Then find a parallel direction.
613 |
614 | try:
615 | #perp_dir = get_ratios([start_point], [range(DIM)], eps=1e-4)[0].flatten()
616 | perp_dir = get_ratios_lstsq(0, [start_point], [range(DIM)], KnownT([], []), eps=1e-5)[0].flatten()
617 |
618 | except AcceptableFailure:
619 | print("Failed to compute ratio at start point. Something very bad happened.")
620 | return points_on_plane, False
621 |
622 | # Record these points.
623 | history.append((which_polytope,
624 | hidden_vector,
625 | np.copy(start_point)))
626 |
627 | # We can't just pick any parallel direction. If we did, then we would
628 | # not end up covering much of the input space.
629 |
630 | # Instead, we're going to figure out which layer-1 hyperplanes are "visible"
631 | # from the current point. Then we're going to try and go reach all of them.
632 |
633 | # This is the point at which the first and second layers intersect.
634 | start_point, multiple_intersection_point, new_change_axis = choose_new_direction_from_minimize(current_change_axis)
635 |
636 | if new_change_axis != current_change_axis:
637 | start_point, multiple_intersection_point, current_change_axis = choose_new_direction_from_minimize(None)
638 |
639 | #if CHEATING:
640 | # print("INIT MULTIPLE", cheat_get_inner_layers(multiple_intersection_point))
641 |
642 | # Refine the direction we're going to travel in---stay numerically stable.
643 | towards_multiple_direction = multiple_intersection_point - start_point
644 | step_distance = np.sum(towards_multiple_direction**2)**.5
645 |
646 | print("Distance we need to step:", step_distance)
647 |
648 | if step_distance > 1 or True:
649 | mid_point = 1e-4 * towards_multiple_direction/np.sum(towards_multiple_direction**2)**.5 + start_point
650 |
651 | random_dir = np.random.normal(size=DIM)
652 |
653 | mid_points = do_better_sweep(mid_point, perp_dir/np.sum(perp_dir**2)**.5,
654 | low=-1e-3,
655 | high=1e-3,
656 | known_T=known_T)
657 |
658 | if len(mid_points) > 0:
659 | mid_point = mid_points[np.argmin(np.sum((mid_point-mid_points)**2,axis=1))]
660 |
661 | towards_multiple_direction = mid_point - start_point
662 | towards_multiple_direction = towards_multiple_direction/np.sum(towards_multiple_direction**2)**.5
663 |
664 | initial_signs = get_polytope_at(known_T, known_A, known_B, start_point)
665 | _, high = binary_search_towards(known_T,
666 | known_A, known_B,
667 | start_point,
668 | initial_signs,
669 | towards_multiple_direction)
670 |
671 | multiple_intersection_point = towards_multiple_direction * high + start_point
672 |
673 |
674 | # Find the angle of the next hyperplane
675 | # First, take random steps away from the intersection point
676 | # Then run the search algorithm to find some intersections
677 | # what we find will either be a layer-1 or layer-2 intersection.
678 |
679 | print("Now try to find the continuation direction")
680 | success = None
681 | while success is None:
682 | if start_box_step < 0:
683 | start_box_step = 0
684 | print("VERY BAD FAILURE")
685 | print("Choose a new random point to start from")
686 | which_point = np.random.randint(0, len(history))
687 | start_point = history[which_point][2]
688 | print("New point is", which_point)
689 | current_change_axis = np.random.randint(0, sizes[LAYER+1])
690 | print("New axis to change", current_change_axis)
691 | break
692 |
693 | print("\tStart the box step with size", start_box_step)
694 | try:
695 | success, camefrom, stepsize = find_plane_angle(known_T,
696 | known_A, known_B,
697 | multiple_intersection_point,
698 | sign_at_init,
699 | start_box_step)
700 | except AcceptableFailure:
701 | # Go back to the top and try with a new start point
702 | print("\tOkay we need to try with a new start point")
703 | start_box_step = -10
704 |
705 | start_box_step -= 2
706 |
707 | if success is None:
708 | continue
709 |
710 | val = matmul(known_T.forward(multiple_intersection_point, with_relu=True), known_A, known_B)[new_change_axis]
711 | print("Value at multiple:", val)
712 | val = matmul(known_T.forward(success, with_relu=True), known_A, known_B)[new_change_axis]
713 | print("Value at success:", val)
714 |
715 | if stepsize < 10:
716 | new_move_direction = success - multiple_intersection_point
717 |
718 | # We don't want to be right next to the multiple intersection point.
719 | # So let's binary search to find how far away we can go while remaining in this polytope.
720 | # Then we'll go half as far as we can maximally go.
721 |
722 | initial_signs = get_polytope_at(known_T, known_A, known_B, success)
723 | print("polytope at initial", sign_to_int(initial_signs))
724 | low = 0
725 | high = 1
726 | while high-low > 1e-2:
727 | mid = (high+low)/2
728 | query_point = multiple_intersection_point + mid * new_move_direction
729 | next_signs = get_polytope_at(known_T, known_A, known_B, query_point)
730 | print("polytope at", mid, sign_to_int(next_signs), "%x"%(sign_to_int(next_signs)^sign_to_int(initial_signs)))
731 | if initial_signs == next_signs:
732 | low = mid
733 | else:
734 | high = mid
735 | print("GO TO", mid)
736 |
737 | success = multiple_intersection_point + (mid/2) * new_move_direction
738 |
739 | val = matmul(known_T.forward(success, with_relu=True), known_A, known_B)[new_change_axis]
740 | print("Value at moved success:", val)
741 |
742 | print("Adding the points to the set of known good points")
743 |
744 | points_on_plane.append(start_point)
745 |
746 | if camefrom is not None:
747 | points_on_plane.append(camefrom)
748 | #print("Old start point", start_point)
749 | #print("Set to success", success)
750 | start_point = success
751 | start_box_step = max(stepsize-1,0)
752 |
753 | return points_on_plane, False
754 |
755 | def is_solution_map(args):
756 | bounds, extra_tuple = args
757 | r = []
758 | for i in range(bounds[0], bounds[1]):
759 | r.append(is_solution((i, extra_tuple)))
760 | return r
761 |
762 |
763 | def is_solution(input_tuple):
764 | signs, (known_A0, known_B0, LAYER, known_hidden_so_far, K, responses) = input_tuple
765 |     new_signs = np.array([-1 if x == '0' else 1 for x in bin((1<<K) + signs)[3:]])
766 | 
767 |     # Apply this candidate sign assignment to the rows of the layer
768 |     # matrix that we have only recovered up to sign.
769 |     guessed_A0 = known_A0 * new_signs
770 |     guessed_B0 = known_B0 * new_signs
771 | 
772 |     # Push every witness through the candidate layer.
773 |     hidden = matmul(known_hidden_so_far, guessed_A0, guessed_B0)
774 |     hidden = hidden * (hidden > 0)
775 | 
776 |     # Append a constant column so that the least-squares solve below can
777 |     # also absorb the bias of the following layer.
778 |     hidden = np.concatenate([hidden, np.ones((hidden.shape[0], 1))], axis=1)
779 | 
780 |     # If the guessed signs are right then the observed responses are a
781 |     # linear function of this hidden vector, so the least-squares fit is
782 |     # essentially exact; a wrong guess moves the ReLU boundaries around
783 |     # and leaves a large residual.
784 |     solution, _, _, _ = scipy.linalg.lstsq(hidden, responses)
785 | 
786 |     # (run() always returns a 2D array, so `solution` is 2D as well, with
787 |     # one column per output coordinate of the network.)
788 | 
789 |     # Per-witness estimates of the following layer's bias; if the signs
790 |     # are right these all agree, and we average them below.
791 |     bias = responses - matmul(hidden[:, :-1], solution[:-1], None)
792 | 
793 |     res = np.mean(np.abs(matmul(hidden, solution, None) - responses))
794 | 
795 |     if CHEATING:
796 |         print("Sign guess", new_signs, "residual", res)
797 | 
798 |     # A residual this large means this sign assignment is inconsistent
799 |     # with the model's observed responses: reject it.
800 |     if np.isnan(res) or res > 1e-2:
801 | return (res, new_signs, solution), 0
802 |
803 | bias = bias.mean(axis=0)
804 |
805 | #solution = np.concatenate([solution, [-bias]])[:, np.newaxis]
806 | mat = (solution/solution[0][0])[:-1,:]
807 | if np.any(np.isnan(mat)) or np.any(np.isinf(mat)):
808 | print("Invalid solution")
809 | return (res, new_signs, solution), 0
810 | else:
811 | s = solution/solution[0][0]
812 | s[np.abs(s)<1e-14] = 0
813 |
814 | return (res, new_signs, solution), 1
815 |
816 | def solve_contractive_sign(known_T, weight, bias, LAYER):
817 |
818 | print("Solve the extraction problem for contractive networks")
819 |
820 | def get_preimage(hidden):
821 | preimage = hidden
822 |
823 | for i,(my_A,my_B) in reversed(list(enumerate(zip(known_T.A+[weight], known_T.B+[bias])))):
824 | if i == 0:
825 | res = scipy.optimize.lsq_linear(my_A.T, preimage-my_B,
826 | bounds=(-np.inf, np.inf))
827 | else:
828 | res = scipy.optimize.lsq_linear(my_A.T, preimage-my_B,
829 | bounds=(0, np.inf))
830 |
831 | preimage = res.x
832 | return preimage[np.newaxis,:]
833 |
834 | hidden = np.zeros((sizes[LAYER+1]))
835 |
836 | preimage = get_preimage(hidden)
837 |
838 | extended_T = known_T.extend_by(weight,bias)
839 |
840 | standard_out = run(preimage)
841 |
842 | signs = []
843 |
844 | for axis in range(len(hidden)):
845 | h = np.array(hidden)
846 | h[axis] = 10
847 | preimage_plus = get_preimage(h)
848 | h[axis] = -10
849 | preimage_minus = get_preimage(h)
850 |
851 | print("Confirm preimage")
852 |
853 | if np.any(extended_T.forward(preimage) > 1e-5):
854 | raise AcceptableFailure()
855 |
856 | out_plus = run(preimage_plus)
857 | out_minus = run(preimage_minus)
858 |
859 | print(standard_out, out_plus, out_minus)
860 |
861 | inverted_if_small = np.sum(np.abs(out_plus-standard_out))
862 | not_inverted_if_small = np.sum(np.abs(out_minus-standard_out))
863 |
864 | print("One of these should be small",
865 | inverted_if_small,
866 | not_inverted_if_small)
867 |
868 | if inverted_if_small < not_inverted_if_small:
869 | signs.append(-1)
870 | else:
871 | signs.append(1)
872 | return signs
873 |
874 | def solve_layer_sign(known_T, known_A0, known_B0, critical_points, LAYER,
875 | already_checked_critical_points=False,
876 | only_need_positive=False, l1_mask=None):
877 | """
878 | Compute the signs for one layer of the network.
879 |
880 | known_T is the transformation that computes up to layer K-1, with
881 | known_A and known_B being the layer K matrix up to sign.
882 | """
883 |
884 | def get_critical_points():
885 | print("Init")
886 | print(critical_points)
887 | for point in critical_points:
888 | print("Tick")
889 | if already_checked_critical_points or is_on_following_layer(known_T, known_A0, known_B0, point):
890 | print("Found layer N point at ", point, already_checked_critical_points)
891 | yield point
892 |
893 | get_critical_point = get_critical_points()
894 |
895 |
896 | print("Start looking for critical point")
897 | MAX_POINTS = 200
898 | which_point = next(get_critical_point)
899 | print("Done looking for critical point")
900 |
901 | initial_points = []
902 | history = []
903 | pts = []
904 | if already_checked_critical_points:
905 | for point in get_critical_point:
906 | initial_points.append(point)
907 | pts.append(point)
908 | which_polytope = get_polytope_at(known_T, known_A0, known_B0, point, False) # [-1 1 -1]
909 | hidden_vector = get_hidden_at(known_T, known_A0, known_B0, LAYER, point, False)
910 | if CHEATING:
911 | layers = cheat_get_inner_layers(point)
912 | print('have',[(np.argmin(np.abs(x)),np.min(np.abs(x))) for x in layers])
913 | history.append((which_polytope,
914 | hidden_vector,
915 | np.copy(point)))
916 |
917 |
918 | while True:
919 | if not already_checked_critical_points:
920 | history = []
921 | pts = []
922 |
923 | prev_count = -10
924 | good = False
925 | while len(pts) > prev_count+2:
926 | print("======"*10)
927 | print("RESTART SEARCH", len(pts), prev_count)
928 | print(which_point)
929 | prev_count = len(pts)
930 | more_points, done = follow_hyperplane(LAYER, which_point,
931 | known_T,
932 | known_A0, known_B0,
933 | history=history,
934 | only_need_positive=only_need_positive)
935 | pts.extend(more_points)
936 | if len(pts) >= MAX_POINTS:
937 | print("Have enough; break")
938 | break
939 |
940 | if len(pts) == 0:
941 | break
942 |
943 | neuron_values = known_T.extend_by(known_A0, known_B0).forward(pts)
944 |
945 | neuron_positive_count = np.sum(neuron_values>1,axis=0)
946 | neuron_negative_count = np.sum(neuron_values<-1,axis=0)
947 | print("Counts")
948 | print(neuron_positive_count)
949 | print(neuron_negative_count)
950 |
951 | print("SHOULD BE DONE?", done, only_need_positive)
952 | if done and only_need_positive:
953 | good = True
954 | break
955 | if (np.all(neuron_positive_count > 0) and np.all(neuron_negative_count > 0)) or \
956 | (only_need_positive and np.all(neuron_positive_count > 0)):
957 | print("Have all the points we need (2)")
958 | good = True
959 | break
960 |
961 | if len(pts) < MAX_POINTS/2 and not good:
962 | print("======="*10)
963 | print("Select a new point to start from")
964 | print("======="*10)
965 | if already_checked_critical_points:
966 | print("CHOOSE FROM", len(initial_points), initial_points)
967 | which_point = initial_points[np.random.randint(0,len(initial_points))]
968 | else:
969 | which_point = next(get_critical_point)
970 | else:
971 | print("Abort")
972 | break
973 |
974 | critical_points = np.array(pts)#sorted(list(set(map(tuple,pts))))
975 |
976 |
977 | print("Now have critical points", len(critical_points))
978 |
979 | if CHEATING:
980 | layer = [[np.min(np.abs(x)) for x in cheat_get_inner_layers(x[np.newaxis,:])][LAYER+1] for x in critical_points]
981 |
982 | #print("Which layer is zero?", sorted(layer))
983 | layer = np.abs(cheat_get_inner_layers(np.array(critical_points))[LAYER+1])
984 |
985 | print(layer)
986 |
987 | which_is_zero = np.argmin(layer,axis=1)
988 | print("Which neuron is zero?", which_is_zero)
989 |
990 | which_is_zero = which_is_zero[0]
991 |
992 | print("Query count", query_count)
993 |
994 |
995 | K = neuron_count[LAYER+1]
996 | MAX = (1<<K)
--------------------------------------------------------------------------------
/src/global_vars.py:
--------------------------------------------------------------------------------
31 |
32 | # Okay so this is an ugly hack
33 | # I want to track where the queries come from.
34 | # So in order to pretty-print line number -> code,
35 | # open up the current file and use this as a lookup.
36 |
37 | TRACK_LINES = False
38 | self_lines = open(sys.argv[0]).readlines()
39 |
40 | # We're going to keep track of all queries we've generated so that we can use them later on
41 | # (in order to save on query efficiency)
42 | # Format: [(x, f(x))]
43 | SAVED_QUERIES = []
44 |
45 |
46 | def run(x,inner_A=__cheat_A,inner_B=__cheat_B):
47 | """
48 | Run the neural network forward on the input x using the matrix A,B.
49 |
50 | Log the result as having happened so that we can debug errors and
51 | improve query efficiency.
52 | """
53 | global query_count
54 | query_count += x.shape[0]
55 | assert len(x.shape) == 2
56 |
57 | orig_x = x
58 |
59 | for i,(a,b) in enumerate(zip(inner_A,inner_B)):
60 | # Compute the matrix product.
61 | # This is a right-matrix product which means that rows/columns are flipped
62 | # from the definitions in the paper.
63 | # This was the first method I wrote and it doesn't make sense.
64 | # Please forgive me.
65 | x = matmul(x,a,b)
66 | if i < len(sizes)-2:
67 | x = x*(x>0)
68 | SAVED_QUERIES.extend(zip(orig_x,x))
69 |
70 | if TRACK_LINES:
71 | for line in traceback.format_stack():
72 | if 'repeated' in line: continue
73 | line_no = int(line.split("line ")[1].split()[0][:-1])
74 | if line_no not in query_count_at:
75 | query_count_at[line_no] = 0
76 | query_count_at[line_no] += x.shape[0]
77 |
78 | return x
79 |
80 |
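# A small usage sketch of the query accounting above (assuming the module
# globals are initialized as in this file): every row passed to run() is
# charged to query_count, and the (input, output) pairs land in
# SAVED_QUERIES for later reuse.
#
#   before = query_count
#   _ = run(np.random.normal(size=(5, sizes[0])))
#   assert query_count == before + 5
#   assert len(SAVED_QUERIES) >= 5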
81 | class NoCheatingError(Exception):
82 | """
83 | This error is thrown by functions that cheat if we're in no-cheating mode.
84 |
85 | To debug code it's helpful to be able to look at the weights directly,
86 | and inspect the inner activations of the model.
87 |
88 | But sometimes debug code can be left in by accident and we might pollute
89 | the actual results of the paper by cheating. This error is thrown by all
90 | functions that cheat so that we can't possibly do it by accident.
91 | """
92 |
93 | class AcceptableFailure(Exception):
94 | """
95 | Sometimes things fail for entirely acceptable reasons (e.g., we haven't
96 | queried enough points to have seen all the hyperplanes, or we get stuck
97 | in a constant zero region). When that happens we throw an AcceptableFailure
98 | because life is tough but we should just back out and try again after
99 | making the appropriate correction.
100 | """
101 | def __init__(self, *args, **kwargs):
102 | for k,v in kwargs.items():
103 | setattr(self, k, v)
104 |
105 | class GatherMoreData(AcceptableFailure):
106 | """
107 | When gathering witnesses to hyperplanes, sometimes we don't have
108 | enough and need more witnesses to *this particular neuron*.
109 | This error says that we should gather more examples of that one.
110 | """
111 | def __init__(self, data, **kwargs):
112 | super(GatherMoreData, self).__init__(data=data, **kwargs)
113 |
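# A sketch of the intended control flow for these exceptions (the worker
# function name here is hypothetical): catch GatherMoreData to harvest the
# extra witnesses it carries, and AcceptableFailure to back out and retry.
#
#   try:
#       witnesses = find_witnesses_for_neuron(...)   # hypothetical
#   except GatherMoreData as e:
#       extra.extend(e.data)    # payload attached via __init__ above
#   except AcceptableFailure:
#       pass                    # acceptable; correct and try again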
114 | def _cheat_get_inner_layers(x,A=__cheat_A,B=__cheat_B, as_list=False):
115 | """
116 | Cheat to get the inner layers of the neural network.
117 | """
118 | region = []
119 | for i,(a,b) in enumerate(zip(A,B)):
120 | x = matmul(x,a,b)
121 | region.append(np.copy(x))
122 | if i < len(sizes)-2:
123 | x = x*(x>0)
124 | return region
125 |
126 | def cheat_get_inner_layers(x,A=A,B=B, as_list=False):
127 | if not CHEATING: raise NoCheatingError()
128 | return _cheat_get_inner_layers(x,A,B,as_list)
129 |
130 | def _cheat_get_polytope_id(x,A=__cheat_A,B=__cheat_B, as_list=False, flatten=True):
131 | """
132 | Cheat to get the polytope ID of the network.
133 | """
134 | if not CHEATING: raise NoCheatingError()
135 | region = []
136 | for i,(a,b) in enumerate(zip(A,B)):
137 | x = matmul(x,a,b)
138 | if i < len(sizes)-2:
139 | region.append(x<0)
140 | x = x*(x>0)
141 | if flatten:
142 | arr = np.array(np.concatenate(region,axis=1),dtype=np.int64)
143 | else:
144 | arr = region
145 | if as_list:
146 | return arr
147 | arr *= 1<<np.arange(arr.shape[1])
148 | return np.sum(arr,axis=1)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
223 | x = x*(x>0)
224 | return x
225 | def forward_at(self, point, d_matrix):
226 | if len(self.A) == 0:
227 | return d_matrix
228 |
229 | mask_vectors = [layer > 0 for layer in self.get_hidden_layers(point)]
230 |
231 | h_matrix = np.array(d_matrix)
232 | for i,(matrix,mask) in enumerate(zip(self.A, mask_vectors)):
233 | h_matrix = matmul(h_matrix, matrix, None) * mask
234 |
235 | return h_matrix
236 | def get_hidden_layers(self, x, flat=False, np=np):
237 | if len(self.A) == 0: return []
238 | region = []
239 | for i,(a,b) in enumerate(zip(self.A,self.B)):
240 | x = matmul(x,a,b,np=np)
241 | if np == jnp:
242 | region.append(x)
243 | else:
244 | region.append(np.copy(x))
245 | if i < len(self.A)-1:
246 | x = x*(x>0)
247 | if flat:
248 | region = np.concatenate(region,axis=0)
249 | return region
250 | def get_polytope(self, x):
251 | if len(self.A) == 0: return tuple()
252 | h = self.get_hidden_layers(x)
253 | h = np.concatenate(h, axis=0)
254 | return tuple(np.int32(np.sign(h)))
255 |
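# A usage sketch for the partial-network object whose methods appear above
# (the class name `KnownT` and its constructor signature are assumptions
# here; shapes are hypothetical):
#
#   T = KnownT(recovered_A, recovered_B)
#   T2 = T.extend_by(next_A, next_B)   # as used during layer recovery
#   h = T2.forward(x)                  # activations feeding the next layer
#   sig = T2.get_polytope(x[0])        # sign pattern = linear region of x[0]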
256 | def check_quality(layer_num, extracted_normal, extracted_bias, do_fix=False):
257 | """
258 | Check the quality of the solution.
259 |
260 | The first half is read-only, and just reports how good or bad things are.
261 | The second half, when in cheating mode, will align the two matrices.
262 | """
263 |
264 | print("\nCheck the solution of the last weight matrix.")
265 |
266 | reorder = [None]*(neuron_count[layer_num+1])
267 | for i in range(neuron_count[layer_num+1]):
268 | gaps = []
269 | ratios = []
270 | for j in range(neuron_count[layer_num+1]):
271 | if np.any(np.abs(extracted_normal[:,i]) < 1e-9):
272 | extracted_normal[:,i] += 1e-9
273 | ratio = __cheat_A[layer_num][:,j] / extracted_normal[:,i]
274 | ratio = np.median(ratio)
275 | error = __cheat_A[layer_num][:,j] - ratio * extracted_normal[:,i]
276 | error = np.sum(error**2)/np.sum(__cheat_A[layer_num][:,j]**2)
277 | gaps.append(error)
278 | ratios.append(ratio)
279 | print("Neuron", i, "maps on to neuron", np.argmin(gaps), "with error", np.min(gaps)**.5, 'ratio', ratios[np.argmin(gaps)])
280 | print("Bias check", (__cheat_B[layer_num][np.argmin(gaps)]-extracted_bias[i]*ratios[np.argmin(gaps)]))
281 |
282 | reorder[np.argmin(gaps)] = i
283 | if do_fix and CHEATING:
284 | extracted_normal[:,i] *= np.abs(ratios[np.argmin(gaps)])
285 | extracted_bias[i] *= np.abs(ratios[np.argmin(gaps)])
286 |
287 | if min(gaps) > 1e-2:
288 | print("ERROR LAYER EXTRACTED INCORRECTLY")
289 | print("\tGAPS:", " ".join("%.04f"%x for x in gaps))
290 | print("\t Got:", " ".join("%.04f"%x for x in extracted_normal[:,i]/extracted_normal[0,i]))
291 | print("\t Real:", " ".join("%.04f"%x for x in __cheat_A[layer_num][:,np.argmin(gaps)]/__cheat_A[layer_num][0,np.argmin(gaps)]))
292 |
293 |
294 | # Arbitrarily assign the unused neurons.
295 | used = [x for x in reorder if x is not None]
296 | missed = list(set(range(len(reorder))) - set(used))
297 | for i in range(len(reorder)):
298 | if reorder[i] is None:
299 | reorder[i] = missed.pop()
300 |
301 |
302 | if CHEATING:
303 | extracted_normal = extracted_normal[:,reorder]
304 | extracted_bias = extracted_bias[reorder]
305 |
306 | return extracted_normal,extracted_bias
307 |
308 |
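# A tiny worked example of the alignment test above: an extracted neuron is
# only determined up to scale, so the median coordinate-wise ratio against
# the true column estimates that scale, and the rescaled residual is the
# reported error.
#
#   true  = np.array([1.0, -2.0, 0.5])
#   found = np.array([2.0, -4.0, 1.0])      # same neuron, scaled by 2
#   ratio = np.median(true / found)         # 0.5
#   err = np.sum((true - ratio*found)**2) / np.sum(true**2)   # ~0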
--------------------------------------------------------------------------------
/train_models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | import os
17 | os.environ['CUDA_VISIBLE_DEVICES'] = ''
18 | import numpy as onp
19 | import jax
20 | import jax.experimental.optimizers
21 | import jax.numpy as jnp
22 |
23 | def matmul(a,b,c,np=jnp):
24 | if c is None:
25 | c = np.zeros(1)
26 |
27 | return np.dot(a,b)+c
28 |
29 | seed = int(sys.argv[2]) if len(sys.argv) > 2 else 42 # for luck
30 | onp.random.seed(seed)
31 |
32 | sizes = list(map(int,sys.argv[1].split("-")))
33 | dimensions = [tuple([x]) for x in sizes]
34 | neuron_count = sizes
35 | ops = [matmul]*(len(sizes)-1)
36 |
37 | # Let's not overcomplicate things.
38 | # Initialize with a standard Gaussian initialization.
39 | # Yes, someone with their Xavier/spectral/Kaiming initialization might do better,
40 | # but we're memorizing some random points. This works.
41 | A = []
42 | B = []
43 | for i,(op, a,b) in enumerate(zip(ops, sizes, sizes[1:])):
44 | A.append(onp.random.normal(size=(a,b))/(b**.5))
45 | B.append(onp.zeros((b,)))
46 |
47 | def run(x, A, B):
48 | """
49 | Run the neural network forward on the input x using the matrix A,B.
50 | """
51 |
52 | for i,(op,a,b) in enumerate(zip(ops,A,B)):
53 | # Compute the matrix product.
54 | # This is a right-matrix product which means that rows/columns are flipped
55 | # from the definitions in the paper.
56 | # This was the first method I wrote and it doesn't make sense.
57 | # Please forgive me.
58 | x = op(x,a,b)
59 | if i < len(sizes)-2:
60 | x = x*(x>0)
61 |
62 | return x
63 |
64 | def getinner(x, A, B):
65 | """
66 | Cheat to get the inner layers of the neural network.
67 | """
68 | region = []
69 | for i,(op,a,b) in enumerate(zip(ops,A,B)):
70 | x = op(x,a,b)
71 | region.append(onp.copy(x))
72 | if i < len(sizes)-2:
73 | x = x*(x>0)
74 | return region
75 |
76 |
77 | def loss(params, inputs, targets):
78 | logits = run(inputs, params[0], params[1])
79 | # L2 loss is best loss
80 | res = (targets-logits.flatten())**2
81 | return jnp.mean(res)
82 |
83 | # generate random training data
84 |
85 | params = [A,B]
86 |
87 |
88 | SAMPLES = 20
89 |
90 | # Again, let's not think. Just optimize with adam.
91 | # Your cosine cyclic learning rate schedule can have fun elsewhere.
92 | # We just pick 3e-4 because Karpathy said so.
93 | init, opt_update, get_params = jax.experimental.optimizers.adam(3e-4)
94 |
95 | X = onp.random.normal(size=(SAMPLES, sizes[0]))
96 | Y = onp.array(onp.random.normal(size=SAMPLES)>0,dtype=onp.float32)
97 |
98 | loss_grad = jax.grad(loss)
99 |
100 | @jax.jit
101 | def update(i, opt_state, batch_x, batch_y):
102 | params = get_params(opt_state)
103 | return opt_update(i, loss_grad(params, batch_x, batch_y), opt_state)
104 | opt_state = init(params)
105 |
106 |
107 | # Who are we kidding.
108 | # Not like we're running on a TPU pod and need that batch size of 16384
109 | BS = 4
110 |
111 | # Train loop.
112 |
113 | step = 0
114 | for i in range(100):
115 | if i%10 == 0:
116 | print('loss', loss(params, X, Y))
117 |
118 | for j in range(0,SAMPLES,BS):
119 | batch_x = X[j:j+BS]
120 | batch_y = Y[j:j+BS]
121 |
122 | # gradient descent!
123 | opt_state = update(step, opt_state, batch_x, batch_y)
124 | params = get_params(opt_state)
125 |
126 | step += 1
127 |
128 | # Save our amazing model.
129 | onp.save("models/" + str(seed) + "_" + "-".join(map(str,sizes)), params)
130 |
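# Usage sketch (run from the repository root so that models/ exists):
#
#   python train_models.py 10-15-15-1 7
#
# trains a 10-15-15-1 network with seed 7 and writes the weights to
# models/7_10-15-15-1.npy (numpy appends the .npy extension).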
--------------------------------------------------------------------------------