Run attacks on logic-locked circuits.
26 |
27 |
28 | Expand source code
29 |
30 | """Run attacks on logic-locked circuits."""
31 | import code
32 | import random
33 | import time
34 |
35 | import circuitgraph as cg
36 |
37 |
38 | def _localtime():
39 | return time.asctime(time.localtime(time.time()))
40 |
41 |
def miter_attack(
    cl,
    key,
    timeout=None,
    key_cons=None,
    unroll_cyclic=True,
    verbose=True,
    code_on_error=False,
):
    """
    Launch a miter-based sat attack on a locked circuit.

    Each iteration asks the miter solver (with its 'sat' output assumed
    true) for a distinguishing input (di): a primary-input pattern on which
    the two miter copies, driven by two candidate key vectors (the ``c0_*``
    and ``c1_*`` variables), are satisfiable together. The oracle (the locked
    circuit with the correct key) is simulated on that input to obtain the
    distinguishing output (do). Two fresh copies of the circuit are then
    constrained to (di, do), and their key inputs are tied to the two
    candidate key vectors. The loop ends when no such input remains. A key
    still consistent with every recorded (di, do) pair is then extracted and
    checked against the correct key.

    Parameters
    ----------
    cl: circuitgraph.Circuit
        The locked circuit to attack
    key: dict of str:bool
        The correct key, used to construct the oracle
    timeout: int
        Timeout for the attack, in seconds. It is checked once per
        iteration, so a single long solver call is not interrupted. A falsy
        value (None or 0) disables the timeout.
    key_cons: circuitgraph.Circuit or iter of circuitgraph.Circuit
        Key conditions to satisfy during attack,
        must have output 'sat' and be a function of the key inputs
    unroll_cyclic: bool
        If True, convert cyclic circuits to acyclic versions
    verbose: bool
        If True, attack progress will be printed
    code_on_error: bool
        If True, drop into an interactive session on an error

    Returns
    -------
    dict
        A dictionary containing attack info and results, with keys:
        "Time" (total seconds, or None on timeout / no key found),
        "Iterations", "Timeout", "Equivalent", "Key Found",
        "dis" and "dos" (lists of bool tuples ordered like the circuit's
        non-key startpoints / endpoints), "iter_times" (elapsed seconds
        after each iteration), "iter_keys" (the pair of candidate key
        dicts seen each iteration) and, only when a key is found,
        "attack_key".

    Raises
    ------
    ValueError
        If the circuit is cyclic and `unroll_cyclic` is False. Also raised,
        when `code_on_error` is False, if the same distinguishing input is
        produced twice or the oracle simulation yields no model.

    """
    start_time = time.time()

    if cl.is_cyclic():
        if unroll_cyclic:
            cl = cg.tx.acyclic_unroll(cl)
        else:
            raise ValueError(
                "Circuit is cyclic. Set 'unroll_cyclic' to True to run sat on "
                "this circuit"
            )

    # setup vars: primary inputs are the startpoints that are not key inputs
    keys = tuple(key.keys())
    ins = tuple(cl.startpoints() - key.keys())
    outs = tuple(cl.endpoints())

    # create simulation solver (the oracle: circuit built with the correct key)
    s_sim, v_sim = cg.sat.construct_solver(cl, key)

    # create miter solver; only the primary inputs are passed as shared
    # startpoints, and the key inputs of the two copies appear as the
    # c0_*/c1_* variables used below
    m = cg.tx.miter(cl, startpoints=set(ins))
    s_miter, v_miter = cg.sat.construct_solver(m)

    # add key constraints
    if key_cons:
        if isinstance(key_cons, cg.Circuit):
            key_cons = [key_cons]
        for key_con in key_cons:
            if verbose:
                print(
                    f"[{_localtime()}] circuit: {cl.name}, "
                    f"adding constraints: {key_con.name}"
                )
            formula, v_cons = cg.sat.cnf(key_con)
            con_clauses = formula.clauses

            # add one copy of the constraint circuit per candidate key
            # vector; each copy is shifted past the solver's current
            # variable count so the copies do not share variables
            c0_offset = s_miter.nof_vars()
            c0 = cg.sat.remap(con_clauses, c0_offset)
            s_miter.append_formula(c0)
            c1_offset = s_miter.nof_vars()
            c1 = cg.sat.remap(con_clauses, c1_offset)
            s_miter.append_formula(c1)

            # force each copy's 'sat' output true, and tie each copy's key
            # inputs to the matching miter key vector (the two clauses per
            # bit encode equality)
            # NOTE(review): assumes every key input appears in key_con --
            # TODO confirm what v_cons.id(n) does for an absent node
            clauses = [[v_cons.id("sat") + c0_offset], [v_cons.id("sat") + c1_offset]]
            clauses += [
                [-v_miter.id(f"c0_{n}"), v_cons.id(n) + c0_offset] for n in keys
            ]
            clauses += [
                [v_miter.id(f"c0_{n}"), -v_cons.id(n) - c0_offset] for n in keys
            ]
            clauses += [
                [-v_miter.id(f"c1_{n}"), v_cons.id(n) + c1_offset] for n in keys
            ]
            clauses += [
                [v_miter.id(f"c1_{n}"), -v_cons.id(n) - c1_offset] for n in keys
            ]

            s_miter.append_formula(clauses)

    # get circuit clauses (copied twice per iteration to record a (di, do) pair)
    formula, v_c = cg.sat.cnf(cl)
    clauses = formula.clauses

    # solve
    dis = []
    dos = []
    iter_times = []
    iter_keys = []
    # set mirror of `dis` for O(1) duplicate detection; `dis` keeps the order
    seen_dis = set()
    while s_miter.solve(assumptions=[v_miter.id("sat")]):

        # get di: model entry for variable v sits at index v - 1, and a
        # positive literal means True
        model = s_miter.get_model()
        di = [model[v_miter.id(n) - 1] > 0 for n in ins]
        di_key = tuple(di)
        # the constraints added for an earlier di should exclude it, so a
        # repeat means the encoding did not take effect
        if di_key in seen_dis:
            if code_on_error:
                print("Error di")
                code.interact(local=dict(globals(), **locals()))
            else:
                raise ValueError("Saw same di twice")

        # get intermediate keys
        k0 = {n: model[v_miter.id(f"c0_{n}") - 1] > 0 for n in keys}
        k1 = {n: model[v_miter.id(f"c1_{n}") - 1] > 0 for n in keys}
        iter_keys.append((k0, k1))

        # get do: (2 * b - 1) maps True -> +var and False -> -var
        s_sim.solve(assumptions=[(2 * b - 1) * v_sim.id(n) for b, n in zip(di, ins)])
        model = s_sim.get_model()
        if model is None:
            if code_on_error:
                print("Error sim")
                code.interact(local=dict(globals(), **locals()))
            else:
                raise ValueError("Could not get simulation model")
        do = [model[v_sim.id(n) - 1] > 0 for n in outs]
        dis.append(di_key)
        seen_dis.add(di_key)
        dos.append(tuple(do))
        iter_times.append(time.time() - start_time)

        # add constraints circuits: two fresh, variable-shifted circuit copies
        c0_offset = s_miter.nof_vars()
        c0 = cg.sat.remap(clauses, c0_offset)
        s_miter.append_formula(c0)
        c1_offset = s_miter.nof_vars()
        c1 = cg.sat.remap(clauses, c1_offset)
        s_miter.append_formula(c1)

        # encode dis + dos as unit clauses on both copies
        dio_clauses = [
            [(2 * b - 1) * (v_c.id(n) + c0_offset)] for b, n in zip(di + do, ins + outs)
        ]
        dio_clauses += [
            [(2 * b - 1) * (v_c.id(n) + c1_offset)] for b, n in zip(di + do, ins + outs)
        ]
        s_miter.append_formula(dio_clauses)

        # encode keys connections: copy 0 follows c0_* keys, copy 1 follows c1_*
        key_clauses = [[-v_miter.id(f"c0_{n}"), v_c.id(n) + c0_offset] for n in keys]
        key_clauses += [[v_miter.id(f"c0_{n}"), -v_c.id(n) - c0_offset] for n in keys]
        key_clauses += [[-v_miter.id(f"c1_{n}"), v_c.id(n) + c1_offset] for n in keys]
        key_clauses += [[v_miter.id(f"c1_{n}"), -v_c.id(n) - c1_offset] for n in keys]
        s_miter.append_formula(key_clauses)

        # check timeout (only between iterations)
        if timeout and (time.time() - start_time) > timeout:
            if verbose:
                print(f"[{_localtime()}] circuit: {cl.name}, Timeout: True")
            return {
                "Time": None,
                "Iterations": len(dis),
                "Timeout": True,
                "Equivalent": False,
                "Key Found": False,
                "dis": dis,
                "dos": dos,
                "iter_times": iter_times,
                "iter_keys": iter_keys,
            }

        if verbose:
            print(
                f"[{_localtime()}] "
                f"circuit: {cl.name}, iter: {len(dis)}, "
                f"time: {time.time()-start_time}, "
                f"clauses: {s_miter.nof_clauses()}, "
                f"vars: {s_miter.nof_vars()}"
            )

    # check if a satisfying key remains (no 'sat' assumption this time)
    key_found = s_miter.solve()
    if verbose:
        print(f"[{_localtime()}] circuit: {cl.name}, key found: {key_found}")
    if not key_found:
        return {
            "Time": None,
            "Iterations": len(dis),
            "Timeout": False,
            "Equivalent": False,
            "Key Found": False,
            "dis": dis,
            "dos": dos,
            "iter_times": iter_times,
            "iter_keys": iter_keys,
        }

    # get key
    model = s_miter.get_model()
    attack_key = {n: model[v_miter.id(f"c1_{n}") - 1] > 0 for n in keys}

    # check key: drive one miter copy with the correct key and the other with
    # the recovered key; if 'sat' cannot be made true, they are equivalent
    assumptions = {
        **{f"c0_{k}": v for k, v in key.items()},
        **{f"c1_{k}": v for k, v in attack_key.items()},
        "sat": True,
    }
    equivalent = not cg.sat.solve(m, assumptions)
    if verbose:
        print(f"[{_localtime()}] circuit: {cl.name}, equivalent: {equivalent}")

    exec_time = time.time() - start_time
    if verbose:
        print(f"[{_localtime()}] circuit: {cl.name}, elapsed time: {exec_time}")

    return {
        "Time": exec_time,
        "Iterations": len(dis),
        "Timeout": False,
        "Equivalent": equivalent,
        "Key Found": True,
        "dis": dis,
        "dos": dos,
        "iter_times": iter_times,
        "iter_keys": iter_keys,
        "attack_key": attack_key,
    }
274 |
275 |
def decision_tree_attack(c_or_cl, nsamples, key=None, verbose=True):
    """
    Launch a decision tree attack on a locked circuit.

    Attempts to capture the functionality of the oracle circuit using one
    decision tree per output. The circuit passed in is not modified; when a
    key is given it is applied through solver assumptions.

    Parameters
    ----------
    c_or_cl: circuitgraph.Circuit
        The circuit to reverse engineer. Can either be
        the oracle or the locked circuit. If the locked
        circuit, must pass in the correct key using
        the `key` parameter
    nsamples: int
        The number of samples to train the decision tree on. The same number
        of fresh random samples is used to measure test accuracy.
    key: dict of str:bool
        The correct key, used to construct the oracle if
        the locked circuit is given.
    verbose: bool
        If True, attack progress will be printed

    Returns
    -------
    dict of str:sklearn.tree.DecisionTreeClassifier
        The trained classifier for each output.

    Raises
    ------
    ValueError
        If `nsamples` is less than 1.

    """
    from sklearn.tree import DecisionTreeClassifier

    if nsamples < 1:
        raise ValueError("nsamples must be at least 1")

    c = c_or_cl
    # Pin the key with solver assumptions instead of rewriting the key nodes
    # to constants, so the caller's circuit is left untouched. Key inputs are
    # excluded from the sampled inputs, as they would be once made constant.
    key_assumptions = dict(key) if key else {}
    ins = tuple(c.startpoints() - key_assumptions.keys())
    outs = tuple(c.endpoints())

    def _sample():
        # one uniformly random assignment of the (non-key) inputs
        return [random.choice((True, False)) for _ in ins]

    def _simulate(sample):
        # evaluate the oracle on `sample`; returns a node -> value mapping
        return cg.sat.solve(c, {**dict(zip(ins, sample)), **key_assumptions})

    # generate training samples
    x = []
    y = {o: [] for o in outs}
    if verbose:
        print(f"[{_localtime()}] Generating samples")
    for _ in range(nsamples):
        sample = _sample()
        result = _simulate(sample)
        x.append(sample)
        for o in outs:
            y[o].append(result[o])

    if verbose:
        print(f"[{_localtime()}] Training decision trees")
    estimators = {o: DecisionTreeClassifier() for o in outs}
    for o in outs:
        estimators[o].fit(x, y[o])

    if verbose:
        print(f"[{_localtime()}] Testing decision trees")
    test_x = [_sample() for _ in range(nsamples)]
    test_results = [_simulate(sample) for sample in test_x]
    # predict each output once for the whole test batch
    predictions = {o: estimators[o].predict(test_x) for o in outs}
    # a test sample counts as correct only if every output is predicted right
    ncorrect = sum(
        all(result[o] == predictions[o][j] for o in outs)
        for j, result in enumerate(test_results)
    )

    if verbose:
        print(f"[{_localtime()}] Test accuracy: {ncorrect / nsamples}")
    return estimators
346 |
347 |