├── .gitignore
├── README.rst
├── examples.py
├── lr_attack.py
├── main.py
├── qpsolver.py
├── setup.cfg
├── staxmod.py
├── utils.py
└── weights
    └── convnet.pickle

/.gitignore:
--------------------------------------------------------------------------------
1 | cifar-10-python.tar.gz
2 | cifar-10-batches-py/
3 | 
4 | # Created by https://www.gitignore.io/api/python
5 | # Edit at https://www.gitignore.io/?templates=python
6 | 
7 | ### Python ###
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 | 
13 | # C extensions
14 | *.so
15 | 
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 | 
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 | 
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 | 
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 | 
59 | # Translations
60 | *.mo
61 | *.pot
62 | 
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | 
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 | 
72 | # Scrapy stuff:
73 | .scrapy
74 | 
75 | # Sphinx documentation
76 | docs/_build/
77 | 
78 | # PyBuilder
79 | target/
80 | 
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 | 
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 | 
88 | # pyenv
89 | .python-version
90 | 
91 | # celery beat schedule file
92 | celerybeat-schedule
93 | 
94 | # SageMath parsed files
95 | *.sage.py
96 | 
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 | 
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 | 
110 | # Rope project settings
111 | .ropeproject
112 | 
113 | # mkdocs documentation
114 | /site
115 | 
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 | 
121 | # Pyre type checker
122 | .pyre/
123 | 
124 | ### Python Patch ###
125 | .venv/
126 | 
127 | # End of https://www.gitignore.io/api/python
128 | 
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ====================
2 | Linear Region Attack
3 | ====================
4 | 
5 | The Linear Region attack is a powerful white-box adversarial attack that
6 | exploits knowledge about the geometry of neural networks to find minimal
7 | adversarial perturbations without doing gradient descent.
8 | 
9 | This repository provides an efficient GPU implementation of the Linear Region
10 | attack. If you find our attack useful or use this code, please cite our paper
11 | `Scaling up the randomized gradient-free adversarial attack reveals
12 | overestimation of robustness using established attacks <https://doi.org/10.1007/s11263-019-01213-0>`_.
13 | 
14 | BibTeX
15 | ------
16 | 
17 | .. code-block::
18 | 
19 |    @article{croce2019scaling,
20 |    author="Croce, Francesco
21 |    and Rauber, Jonas
22 |    and Hein, Matthias",
23 |    title="Scaling up the Randomized Gradient-Free Adversarial Attack Reveals Overestimation of Robustness Using Established Attacks",
24 |    journal="International Journal of Computer Vision",
25 |    year="2019",
26 |    month="Oct",
27 |    day="03",
28 |    issn="1573-1405",
29 |    doi="10.1007/s11263-019-01213-0",
30 |    url="https://doi.org/10.1007/s11263-019-01213-0"
31 |    }
32 | 
33 | Requirements
34 | ------------
35 | 
36 | This implementation requires Python 3.6 or newer, NumPy, and JAX.
37 | Before installing JAX, you need to install jaxlib with GPU support:
38 | 
39 | .. code-block:: bash
40 | 
41 |    PYTHON_VERSION=cp36
42 |    CUDA_VERSION=cuda100
43 |    PLATFORM=linux_x86_64
44 |    BASE_URL='https://storage.googleapis.com/jax-wheels'
45 |    python3 -m pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.11-$PYTHON_VERSION-none-$PLATFORM.whl
46 | 
47 |    python3 -m pip install --upgrade jax
48 | 
49 | For details regarding the installation of JAX, please check the `JAX readme <https://github.com/google/jax>`_.
50 | 
51 | We have successfully used ``Python 3.6``, ``NumPy 1.16``, ``JAX 0.1.21`` and ``jaxlib 0.1.11``.
52 | 
53 | Usage
54 | -----
55 | 
56 | To run the attack on a 10-layer convnet trained on CIFAR-10 for the first image in the CIFAR-10 test set, just run this:
57 | 
58 | .. code-block:: bash
59 | 
60 |    ./main.py cifar_convnet --regions 40  # just for illustration; we recommend more regions, e.g. 400
61 | 
62 | Note: To run the example, you need CIFAR-10:
63 | 
64 | .. code-block:: bash
65 | 
66 |    wget http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
67 |    tar -zxvf cifar-10-python.tar.gz
--------------------------------------------------------------------------------
/examples.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import sys
3 | import os
4 | import numpy as onp
5 | import logging
6 | import jax
7 | from functools import partial
8 | 
9 | from staxmod import Conv, Dense, Flatten, Relu
10 | from staxmod import serial
11 | 
12 | from utils import is_device_array
13 | 
14 | 
15 | train_list = [
16 |     ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
17 |     ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
18 |     ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
19 |     ['data_batch_4', '634d18415352ddfa80567beed471001a'],
20 |     ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
21 | ]
22 | 
23 | test_list = [
24 |     ['test_batch', '40351d587109b95175f43aff81a1287e'],
25 | ]
26 | 
27 | 
28 | def load_cifar10(*, train):
29 |     if train:
30 |         filenames = train_list
31 |     else:
32 |         filenames = test_list
33 | 
34 |     images = []
35 |     labels = []
36 | 
37 |     try:
38 |         for filename, checksum in filenames:
39 |             path = os.path.join('cifar-10-batches-py', filename)
40 |             path = os.path.expanduser(path)
41 |             with open(path, 'rb') as f:
42 |                 entry = pickle.load(f, encoding='latin1')
43 |                 images.append(entry['data'])
44 |                 labels.extend(entry['labels'])
45 |     except FileNotFoundError:
46 |         logging.error('Could not load CIFAR. Run the following commands to download it:\n'
47 |                       '\n'
48 |                       'wget http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n'
49 |                       'tar -zxvf cifar-10-python.tar.gz')
50 |         sys.exit(1)
51 | 
52 |     images = onp.vstack(images).reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))
53 |     labels = onp.asarray(labels, dtype=onp.int32)
54 | 
55 |     assert images.dtype == onp.uint8
56 |     images = images.astype(onp.float32) / 255.
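    # the [0, 1] scaling matters: the QP solved in qpsolver.solve enforces the
    # box constraint 0 <= x <= 1, which assumes exactly this image range
    assert images.shape[1:] == (32, 32, 3)  # sanity check: NHWC layout, as the convnet expects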
57 | return images, labels 58 | 59 | 60 | def ConvNet(): 61 | return serial( 62 | Conv(96, (3, 3), padding='SAME'), Relu, 63 | Conv(96, (3, 3), padding='SAME'), Relu, 64 | Conv(192, (3, 3), padding='SAME', strides=(2, 2)), Relu, 65 | Conv(192, (3, 3), padding='SAME'), Relu, 66 | Conv(192, (3, 3), padding='SAME'), Relu, 67 | Conv(192, (3, 3), padding='SAME', strides=(2, 2)), Relu, 68 | Conv(192, (3, 3), padding='SAME'), Relu, 69 | Conv(384, (2, 2), padding='SAME', strides=(2, 2)), Relu, 70 | Flatten, 71 | Dense(1200), Relu, 72 | Dense(10)) 73 | 74 | 75 | def load_params(path): 76 | with open(path, 'rb') as f: 77 | params = pickle.load(f) 78 | params = jax.tree_map(jax.device_put, params) 79 | return params 80 | 81 | 82 | def find_starting_point(images, labels, args, x, label, logits, predict_class): 83 | strategy = args.nth_likely_class_starting_point 84 | if strategy is None: 85 | return find_starting_point_simple_strategy(images, labels, x, label, predict_class) 86 | return find_starting_point_likely_class_strategy(images, labels, x, label, logits, predict_class, nth=strategy) 87 | 88 | 89 | def find_starting_point_simple_strategy(images, labels, x, label, predict_class): 90 | """returns the image in images that is closest to x that has a 91 | different label and predicted class than the provided label of x""" 92 | 93 | assert x.shape[0] == 1 94 | 95 | assert not is_device_array(x) 96 | assert not is_device_array(label) 97 | 98 | assert not is_device_array(images) 99 | assert not is_device_array(labels) 100 | assert not is_device_array(x) 101 | assert not is_device_array(label) 102 | 103 | # filter those with the same label 104 | images = images[labels != label] 105 | 106 | # get closest images from other classes 107 | diff = images - x 108 | diff = diff.reshape((diff.shape[0], -1)) 109 | diff = onp.square(diff).sum(axis=-1) 110 | diff = onp.argsort(diff) 111 | assert diff.ndim == 1 112 | 113 | for j, index in enumerate(diff): 114 | logging.info(f'trying {j + 1}. candidate ({index})') 115 | candidate = images[index][onp.newaxis] 116 | class_ = jax.device_get(predict_class(candidate).squeeze(axis=0)) 117 | logging.info(f'label = {label}, candidate class = {class_}') 118 | if class_ != label: 119 | return candidate, class_ 120 | 121 | 122 | def find_starting_point_likely_class_strategy(images, labels, x, label, logits, predict_class, *, nth): 123 | assert x.shape[0] == 1 124 | 125 | assert not is_device_array(x) 126 | assert not is_device_array(label) 127 | 128 | assert not is_device_array(images) 129 | assert not is_device_array(labels) 130 | assert not is_device_array(x) 131 | assert not is_device_array(label) 132 | 133 | # determine nth likely class 134 | logits = logits.squeeze(axis=0) 135 | ordered_classes = onp.argsort(logits) 136 | assert ordered_classes[-1] == label 137 | 138 | assert 2 <= nth <= len(logits) 139 | nth_class = ordered_classes[-nth] 140 | 141 | # select those from the nth most likely class 142 | images = images[labels == nth_class] 143 | 144 | # get closest images from other classes 145 | diff = images - x 146 | diff = diff.reshape((diff.shape[0], -1)) 147 | diff = onp.square(diff).sum(axis=-1) 148 | diff = onp.argsort(diff) 149 | assert diff.ndim == 1 150 | 151 | for j, index in enumerate(diff): 152 | logging.info(f'trying {j + 1}. 
candidate ({index})')
153 |         candidate = images[index][onp.newaxis]
154 |         class_ = jax.device_get(predict_class(candidate).squeeze(axis=0))
155 |         logging.info(f'label = {label}, candidate class = {class_}')
156 |         if class_ != label:
157 |             return candidate, class_
158 | 
159 | 
160 | def _cifar_example(architecture, weights=None):
161 |     init, predict = architecture()
162 |     output_shape, params = init((-1, 32, 32, 3))
163 |     if weights is not None:
164 |         params = load_params(weights)
165 |     n_classes = output_shape[-1]
166 |     images, labels = load_cifar10(train=False)
167 |     train_images, train_labels = load_cifar10(train=True)
168 |     assert not is_device_array(train_images) and not is_device_array(train_labels)
169 |     find_starting_point_2 = partial(find_starting_point, train_images, train_labels)
170 |     return n_classes, predict, params, images, labels, find_starting_point_2
171 | 
172 | 
173 | def get_cifar_example(load_weights=True):
174 |     weights = 'weights/convnet.pickle' if load_weights else None
175 |     return _cifar_example(ConvNet, weights)
176 | 
177 | 
178 | def get_example(name):
179 |     return {
180 |         'cifar_convnet': lambda: get_cifar_example(load_weights=True),
181 |     }[name]()
182 | 
--------------------------------------------------------------------------------
/lr_attack.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import random
4 | import jax
5 | import jax.numpy as np
6 | import numpy as onp
7 | import tqdm
8 | from functools import partial
9 | from functools import wraps
10 | 
11 | from qpsolver import solve
12 | from utils import is_device_array, scatter, onehot
13 | 
14 | 
15 | def accuracy(predict_class, images, labels, batch_size=100):
16 |     total = len(images)
17 |     correct = 0
18 |     for i in tqdm.trange(0, len(images), batch_size):
19 |         j = i + batch_size
20 |         predicted_class = predict_class(images[i:j])
21 |         correct += onp.sum(predicted_class == labels[i:j])
22 |     return correct / total
23 | 
24 | 
25 | def l2_distance(x0, x):
26 |     assert x0.shape == x.shape
27 |     return onp.linalg.norm(jax.device_get(x0) - jax.device_get(x))
28 | 
29 | 
30 | def misclassification_polytope(a, c, ls):
31 |     """creates misclassification constraints"""
32 |     assert a.ndim == 2
33 |     assert a.shape[0] == 1  # only batch size 1 is supported
34 |     n_classes = a.shape[1]
35 | 
36 |     u = a[:, ls] - a[:, c]
37 | 
38 |     c = np.atleast_1d(np.asarray([c]).squeeze())
39 |     ls = np.atleast_1d(np.asarray([ls]).squeeze())
40 | 
41 |     Av = lambda Vv: Vv[:, c] - Vv[:, ls]  # noqa: E731
42 |     vA = lambda v: (scatter(c, np.sum(np.atleast_2d(v), axis=-1, keepdims=True), n_classes) +  # noqa: E731
43 |                     scatter(ls, -np.atleast_2d(v), n_classes))
44 | 
45 |     return Av, vA, u
46 | 
47 | 
48 | def relu_polytope(a, f):
49 |     """creates polytope constraints"""
50 |     sf = np.sign(f)
51 |     nsf = -sf
52 |     u = sf * a
53 | 
54 |     Av = lambda Vv: nsf * Vv  # noqa: E731
55 |     vA = lambda v: nsf * v  # noqa: E731
56 | 
57 |     # Some of these constraints are always fulfilled and can be removed.
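    # (derivation: x stays in the linear region of xr iff every ReLU input keeps
    # its sign, i.e. sf * (Jx + a) >= 0, which rearranges to -sf * Jx <= sf * a;
    # rows with sf == 0 impose no constraint)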
58 | non_trivial = sf != 0 59 | # create a vector with 0s for all non-trivial ones, 60 | # and inf for trivial ones 61 | trivial_inf = 1 / non_trivial.astype(np.float32) - 1 62 | # make the upper bound of trivial ones infinity to make them least violated 63 | u = u + trivial_inf 64 | return Av, vA, u 65 | 66 | 67 | def get_other_classes(*, exclude, total, first=None): 68 | rest = [x for x in range(total) if x != exclude and x != first] 69 | random.shuffle(rest) 70 | first = [] if first is None else [first] 71 | return first + rest 72 | 73 | 74 | def flatten(x): 75 | if isinstance(x, list): 76 | for y in x: 77 | yield from flatten(y) 78 | else: 79 | assert isinstance(x, tuple) 80 | for y in x: 81 | yield y 82 | 83 | 84 | def flatten_dims(x): 85 | return np.reshape(x, (x.shape[0], -1)) 86 | 87 | 88 | def layer_size(x): 89 | _, n = x.shape 90 | return n 91 | 92 | 93 | def flatten_predict(predict): 94 | @wraps(predict) 95 | def flat_predict(x): 96 | output, additional_outputs = predict(x) 97 | additional_outputs = list(flatten(additional_outputs)) 98 | additional_outputs = list(map(flatten_dims, additional_outputs)) 99 | additional_outputs = np.concatenate(additional_outputs, axis=-1) 100 | return output, additional_outputs 101 | 102 | return flat_predict 103 | 104 | 105 | def return_classes_logits_layer_sizes(f, *args, **kwargs): 106 | logging.info(f'compiling return_classes_logits_layer_sizes') 107 | logits, additional_outputs = f(*args, **kwargs) 108 | additional_outputs = list(flatten(additional_outputs)) 109 | additional_outputs = list(map(flatten_dims, additional_outputs)) 110 | rows_per_layer = list(map(layer_size, additional_outputs)) 111 | return np.argmax(logits, axis=-1), logits, rows_per_layer 112 | 113 | 114 | def generic_get_A(predict, label, other_classes, xr, normalizer): 115 | # linearize net around xr 116 | fxr, vjp_fun = jax.vjp(predict, xr) 117 | jvp_fun = partial(jax.jvp, predict, (xr,)) 118 | _, jxrp = jvp_fun((xr,)) 119 | offset = tuple(fx - Jx for fx, Jx in zip(fxr, jxrp)) 120 | 121 | Av_misc, vA_misc, u_misc = misclassification_polytope(offset[0], label, other_classes) 122 | Av_relu, vA_relu, u_relu = relu_polytope(offset[1], fxr[1]) 123 | 124 | assert u_misc.ndim == u_relu.ndim == 2 125 | assert u_misc.shape[0] == u_relu.shape[0] # batch dimension 126 | 127 | n_constraints = u_misc.shape[1] + u_relu.shape[1] 128 | 129 | if normalizer is not None: 130 | assert normalizer.shape == (n_constraints,) 131 | assert normalizer.dtype == np.float32 132 | 133 | _, N = u_misc.shape 134 | if normalizer is not None: 135 | u_misc = u_misc * normalizer[:N] 136 | u_relu = u_relu * normalizer[N:] 137 | 138 | def Adot(v): 139 | v = v.reshape((1,) + xr.shape[1:]) 140 | _, Vv = jvp_fun((v,)) 141 | Vv_misc, Vv_relu = Vv 142 | r_misc = Av_misc(Vv_misc) 143 | r_relu = Av_relu(Vv_relu) 144 | assert r_misc.ndim == r_relu.ndim == 2 145 | assert r_misc.shape[0] == r_relu.shape[0] # batch dimension 146 | r = np.concatenate((r_misc, r_relu), axis=1) 147 | r = r.squeeze(axis=0) 148 | assert r.shape == (n_constraints,) 149 | if normalizer is not None: 150 | r = normalizer * r 151 | return r 152 | 153 | def ATdot(v): 154 | assert v.shape == (n_constraints,) 155 | v = v[onp.newaxis] 156 | if normalizer is not None: 157 | v = normalizer * v 158 | _, N = u_misc.shape 159 | assert v.ndim == 2 160 | v_misc, v_relu = v[:, :N], v[:, N:] 161 | v_misc = vA_misc(v_misc) 162 | v_relu = vA_relu(v_relu) 163 | v = (v_misc, v_relu) 164 | r, = vjp_fun(v) 165 | r = r.reshape((-1,)) 166 | return r 167 | 168 | assert 
u_misc.shape[1] == len(other_classes) 169 | return Adot, ATdot, n_constraints, u_misc, u_relu 170 | 171 | 172 | @partial(jax.jit, static_argnums=(0,)) 173 | def operator_norm_lower_bound(get_A, xr, normalizer): 174 | logging.info('compiling operator_norm_lower_bound') 175 | Adot, ATdot, _, _, _ = get_A(xr, normalizer) 176 | 177 | def body_fun(i, state): 178 | z, n = state 179 | u = ATdot(Adot(z)) 180 | n = np.linalg.norm(u) 181 | return u / n, n 182 | 183 | # z = np.ones_like(xr.reshape(-1)) 184 | z = xr.reshape(-1) # a constant vector of e.g. ones fails if mean is subtracted 185 | _, n = jax.lax.fori_loop(0, 20, body_fun, (z, 0.)) 186 | return n 187 | 188 | 189 | @partial(jax.jit, static_argnums=(0,)) 190 | def init_region(get_A, xr, normalizer, v): 191 | logging.info('compiling init_region') 192 | Adot, _, _, u_misc, u_relu = get_A(xr, normalizer) 193 | Av = Adot(v) 194 | return Av, u_misc, u_relu 195 | 196 | 197 | @partial(jax.jit, static_argnums=(0, 2, 3)) 198 | def calculate_normalizer(get_A, xr, n_constraints, rows_per_layer, *, k, normalizer=None, misc_factor=1.): 199 | logging.info('compiling calculate_normalizer') 200 | if normalizer is None: 201 | normalizer = np.ones((n_constraints,), dtype=np.float32) 202 | _, ATdot, _, u_misc, u_relu = get_A(xr, normalizer) 203 | _, n_misc = u_misc.shape 204 | misc_norms, layer_norms = estimate_layer_norms(ATdot, n_misc, rows_per_layer, k=k) 205 | assert misc_norms.shape == (n_misc,) 206 | normalizer = [misc_factor / misc_norms] 207 | assert len(rows_per_layer) == len(layer_norms) 208 | for n, norm in zip(rows_per_layer, layer_norms): 209 | normalizer.append(np.ones((n,)) / norm) 210 | normalizer = np.concatenate(normalizer) 211 | return normalizer, misc_norms, layer_norms 212 | 213 | 214 | def estimate_layer_norms(ATdot, n_misc, rows_per_layer, *, k): 215 | """for each layer, samples k of the rows of A corresponding to that 216 | layer as well as all rows corresponding to the n logits and then 217 | estimates the norm of rows of A corresponding to each layer""" 218 | 219 | # TODO: consider using jax.random and thus drawing new samples every time; 220 | # right now we do the whole onehot vector creation statically 221 | indices = list(range(n_misc)) 222 | offset = n_misc 223 | for layer_size in rows_per_layer: 224 | indices.extend(random.sample(range(offset, offset + layer_size), k)) 225 | offset += layer_size 226 | 227 | logging.info(f'{len(indices)} randomly selected rows of A: {indices}') 228 | 229 | n_constraints = offset 230 | 231 | vs = onp.zeros((len(indices), n_constraints), dtype=onp.float32) 232 | for row, col in enumerate(indices): 233 | vs[row, col] = 1. 
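    # multiplying A^T with a one-hot vector extracts a single row of A, so
    # vmapping ATdot over vs below materializes exactly the sampled rows and
    # lets us measure their norms without ever forming A explicitly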
234 | 235 | ATdot = jax.vmap(ATdot) 236 | rows = ATdot(vs) 237 | assert rows.ndim == 2 238 | 239 | norms = np.linalg.norm(rows, axis=-1) 240 | 241 | # TODO: use median once supported by jax.numpy: 242 | # https://github.com/google/jax/issues/70 243 | 244 | layer_norms = [] 245 | for i in range(n_misc, len(norms), k): 246 | assert i + k <= len(norms) 247 | m = np.mean(norms[i:i + k]) 248 | layer_norms.append(m) 249 | 250 | return norms[:n_misc], layer_norms 251 | 252 | 253 | def line_search(predict_class, x0, label, x, minimum=0., maximum=1., num=100, s=None): 254 | x = jax.device_get(x) 255 | 256 | assert not is_device_array(x0) 257 | assert not is_device_array(label) 258 | 259 | assert x0.shape == x.shape 260 | assert x0.shape[0] == 1 # batch dimension 261 | 262 | if s is None: 263 | s = onp.linspace(minimum, maximum, num=num + 1)[1:] 264 | 265 | p = x - x0 266 | ps = s.reshape((-1,) + (1,) * (p.ndim - 1)) * p 267 | xs = x0 + ps 268 | 269 | assert xs.shape[1:] == x0.shape[1:] 270 | 271 | classes = jax.device_get(predict_class(xs)) 272 | assert classes.ndim == 1 273 | indices = onp.flatnonzero(classes != label) 274 | assert indices.ndim == 1 275 | try: 276 | best = indices[0] 277 | except IndexError: 278 | raise ValueError 279 | logging.info(f'best: {best} -> {s[best]}') 280 | return xs[best][onp.newaxis], classes[best], best 281 | 282 | 283 | def get_region(k, x0, best_adv, *, gamma): 284 | x0 = jax.device_get(x0) 285 | best_adv = jax.device_get(best_adv) 286 | 287 | # TODO: maybe check region around original input 288 | # if k == 0: 289 | # # try the region around the original input 290 | # return x0 291 | 292 | delta = biased_direction(x0, best_adv, prob=0.8) 293 | 294 | u = onp.linalg.norm(best_adv - x0) 295 | r = onp.random.uniform() 296 | logging.debug(f'sampled r = {r}') 297 | x = best_adv + delta / onp.linalg.norm(delta) * u * r**gamma 298 | x = jax.device_put(x) 299 | return x 300 | 301 | 302 | def biased_direction(x0, best_adv, *, prob): 303 | dx = x0 - best_adv 304 | dx = dx / onp.linalg.norm(dx.reshape(-1)) 305 | 306 | delta = onp.random.normal(size=x0.shape) 307 | delta = delta - onp.dot(delta.reshape(-1), dx.reshape(-1)) * dx 308 | delta = delta / onp.linalg.norm(delta.reshape(-1)) 309 | 310 | alpha = onp.random.uniform(0., onp.pi) 311 | 312 | if onp.random.uniform() > prob: 313 | # with probability 1 - prob, sample from the half space further away from x0 314 | alpha = -alpha 315 | 316 | return onp.sin(alpha) * dx + onp.cos(alpha) * delta 317 | 318 | 319 | def run(n_classes, predict, params, images, labels, find_starting_point, args): 320 | t0 = time.time() 321 | 322 | random.seed(22) 323 | onp.random.seed(22) 324 | 325 | logging.info(f'number of samples: {len(images)}') 326 | logging.info(f'n_classes: {n_classes}') 327 | 328 | predict = partial(predict, params) 329 | 330 | predict_class_logits_layer_sizes = partial(return_classes_logits_layer_sizes, predict) 331 | predict_class_logits_layer_sizes = jax.jit(predict_class_logits_layer_sizes) 332 | 333 | predict = flatten_predict(predict) 334 | 335 | def predict_class(x): 336 | return predict_class_logits_layer_sizes(x)[0] 337 | 338 | if args.accuracy: 339 | logging.info(f'accuracy: {accuracy(predict_class, images, labels)}') 340 | 341 | x0_host = images[args.image][onp.newaxis] 342 | label_host = labels[args.image][onp.newaxis] 343 | logging.info(f'label: {label_host}') 344 | 345 | x0 = jax.device_put(x0_host) 346 | x0_flat = x0.reshape((-1,)) 347 | 348 | l2 = partial(l2_distance, x0_host) 349 | 350 | x0_class, x0_logits, 
rows_per_layer = jax.device_get(predict_class_logits_layer_sizes(x0)) 351 | logging.info(f'predicted class: {x0_class}, logits: {x0_logits}') 352 | 353 | logging.info(f'rows per layer: {rows_per_layer}') 354 | 355 | if x0_class != label_host: 356 | logging.warning(f'unperturbed input is misclassified by the model as {x0_class}') 357 | result = { 358 | 'is_adv': True, 359 | 'x0': x0_host, 360 | 'label': label_host.item(), 361 | 'adv': x0_host, 362 | 'adv_class': x0_class, 363 | 'l2': 0., 364 | 'duration': time.time() - t0, 365 | } 366 | return result 367 | 368 | label = jax.device_put(label_host) 369 | 370 | best_adv, best_adv_class = find_starting_point(args, x0_host, label_host, x0_logits, predict_class) 371 | best_adv_l2 = l2(best_adv) 372 | best_adv_l2_hist = [(time.time() - t0, best_adv_l2)] 373 | 374 | if not args.no_line_search: 375 | logging.info('running line search to determine better starting point') 376 | best_adv, best_adv_class, _ = line_search(predict_class, x0_host, label_host, best_adv) 377 | best_adv_l2 = l2(best_adv) 378 | 379 | best_adv_l2_hist.append((time.time() - t0, best_adv_l2)) 380 | 381 | logging.info(f'starting point class: {best_adv_class}') 382 | 383 | best_adv_l2_hist_hist = [best_adv_l2_hist] 384 | 385 | other_classes = get_other_classes(exclude=label.squeeze(), total=n_classes, first=best_adv_class) 386 | if args.max_other_classes: 387 | other_classes = other_classes[:args.max_other_classes] 388 | logging.info(f'other classes: {other_classes}') 389 | 390 | n_constraints = len(other_classes) + sum(rows_per_layer) 391 | logging.info(f'n_constraints: {n_constraints}') 392 | 393 | total_solver_iterations = 0 394 | 395 | get_A = partial(generic_get_A, predict, label, other_classes) 396 | 397 | # ------------------------------------------------------------------------ 398 | # Loop over region 399 | # ------------------------------------------------------------------------ 400 | for region in range(args.regions): 401 | logging.info('-' * 70) 402 | logging.info(f'{region + 1}. 
REGION') 403 | logging.info('-' * 70) 404 | 405 | xr = get_region(region, x0, best_adv, gamma=args.gamma) 406 | 407 | if not args.no_normalization: 408 | normalizer, misc_norms, layer_norms = calculate_normalizer( 409 | get_A, xr, n_constraints, rows_per_layer, k=10, misc_factor=args.misc_factor) 410 | logging.info(f'misc norms: {misc_norms}') 411 | logging.info(f'layer norms: {layer_norms}') 412 | else: 413 | normalizer = None 414 | 415 | Ax0, u_misc, u_relu = init_region(get_A, xr, normalizer, x0) 416 | 417 | L = operator_norm_lower_bound(get_A, xr, normalizer) 418 | logging.info(f'L = {L}') 419 | 420 | best_adv_l2_hist = [(time.time() - t0, best_adv_l2)] 421 | 422 | # ------------------------------------------------------------------------ 423 | # Loop over other classes 424 | # ------------------------------------------------------------------------ 425 | for active in range(len(other_classes)): 426 | # update upper bounds 427 | mask = onehot(active, len(other_classes), dtype=np.float32) 428 | infs = 1 / mask - 1 429 | u_misc_active = u_misc + infs 430 | u = np.concatenate((u_misc_active, u_relu), axis=1) 431 | u = u.squeeze(axis=0) 432 | 433 | assert best_adv.shape[0] == x0.shape[0] == 1 434 | bound = 0.5 * best_adv_l2 ** 2 435 | logging.info(f'bound: {bound}') 436 | 437 | potential_adv, best_dual, counter = solve( 438 | x0_flat, Ax0, get_A, xr, normalizer, u, L, 439 | bound=bound, maxiter=args.iterations) 440 | 441 | total_solver_iterations += jax.device_get(counter).item() 442 | 443 | potential_adv = potential_adv.reshape(x0.shape) 444 | potential_adv_l2 = l2(potential_adv) 445 | closer = potential_adv_l2 < best_adv_l2 446 | logging.info(f'closer = {closer}') 447 | 448 | if closer: 449 | try: 450 | ratio = best_adv_l2 / potential_adv_l2 451 | if ratio > 1.1: 452 | s = onp.linspace(0.9, 1.1, num=101, endpoint=True) 453 | else: 454 | s = onp.linspace(0.9, ratio, num=101, endpoint=False) 455 | 456 | logging.info(f'running line search with factors between {s.min()} and {s.max()}') 457 | best_adv, best_adv_class, index = line_search( 458 | predict_class, x0_host, label_host, potential_adv, s=s) 459 | new_l2 = l2(best_adv) 460 | logging.info(f'-> new best adv with l2 = {new_l2} (before: {best_adv_l2})') 461 | assert new_l2 < best_adv_l2 462 | best_adv_l2 = new_l2 463 | except ValueError: 464 | logging.info(f'-> result not adversarial (tried with line search)') 465 | else: # the first line search succeeded 466 | if index == 0: # the range of our line search can be extended to even smaller values 467 | logging.info(f'running another line search with factors between 0 and 0.9') 468 | # this line search should not fail because 0.9 works for sure 469 | best_adv, best_adv_class, _ = line_search(predict_class, x0_host, label_host, potential_adv, 470 | minimum=0., maximum=0.90, num=100) 471 | new_l2 = l2(best_adv) 472 | logging.info(f'-> new best adv with l2 = {new_l2} (before: {best_adv_l2})') 473 | assert new_l2 <= best_adv_l2 474 | best_adv_l2 = new_l2 475 | 476 | best_adv_l2_hist.append((time.time() - t0, best_adv_l2)) 477 | logging.info('-' * 70) 478 | 479 | logging.info([l for _, l in best_adv_l2_hist]) 480 | best_adv_l2_hist_hist.append(best_adv_l2_hist) 481 | 482 | logging.info([[round(l, 2) for _, l in h] for h in best_adv_l2_hist_hist]) 483 | 484 | best_adv_l2 = l2(best_adv) 485 | logging.info(f'final adversarial has l2 = {best_adv_l2}') 486 | logging.info(f'total number of iterations in QP solver: {total_solver_iterations}') 487 | 488 | result = { 489 | 'x0': x0_host, 490 | 'label': 
label_host.item(),
491 |         'adv': best_adv,
492 |         'adv_class': best_adv_class,
493 |         'l2': best_adv_l2,
494 |         'duration': time.time() - t0,
495 |         'history': best_adv_l2_hist_hist,
496 |         'other_classes': onp.asarray(other_classes).tolist(),
497 |         'total_solver_iterations': total_solver_iterations,
498 |     }
499 |     return result
500 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import pickle
4 | import argparse
5 | import logging
6 | 
7 | from examples import get_example
8 | from lr_attack import run
9 | 
10 | 
11 | def main():
12 |     parser = argparse.ArgumentParser()
13 | 
14 |     parser.add_argument('model', help='name of the model to attack')
15 |     parser.add_argument('--image', type=int, default=0)
16 |     parser.add_argument('--accuracy', action='store_true', help='first determines the accuracy of the model')
17 |     parser.add_argument('--save', type=str, default=None, help='filename to save result to')
18 | 
19 |     # hyperparameters
20 |     parser.add_argument('--regions', type=int, default=400)
21 |     parser.add_argument('--iterations', type=int, default=500)
22 |     parser.add_argument('--gamma', type=int, default=6, help='hyperparam of region selection')
23 |     parser.add_argument('--misc-factor', type=float, default=75.)
24 | 
25 |     # advanced control over certain aspects (only if you know what you are doing)
26 |     parser.add_argument('--nth-likely-class-starting-point', type=int, default=None)
27 |     parser.add_argument('--no-line-search', action='store_true')
28 |     parser.add_argument('--max-other-classes', type=int, default=None)
29 |     parser.add_argument('--no-normalization', action='store_true')
30 | 
31 |     args = parser.parse_args()
32 | 
33 |     logging.getLogger().setLevel(logging.INFO)
34 | 
35 |     if args.save is not None:
36 |         if os.path.exists(args.save):
37 |             logging.warning(f'not running because results already exist: {args.save}')
38 |             return
39 | 
40 |     result = run(*get_example(args.model), args=args)
41 | 
42 |     if args.save is not None:
43 |         directory = os.path.dirname(args.save)
44 |         if len(directory) > 0 and not os.path.exists(directory):
45 |             os.makedirs(directory)
46 |         with open(args.save, 'wb') as f:
47 |             pickle.dump(result, f)
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     main()
52 | 
--------------------------------------------------------------------------------
/qpsolver.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import jax.numpy as np
4 | import jax
5 | from functools import partial
6 | 
7 | 
8 | def step(x0, get_A, xr, normalizer, b_finite, vec, L, counter, xp, mustar, mu):
9 |     # ---------------------------------------------------------------------
10 | 
11 |     # fixed param: x0, b_finite, vec
12 |     # variable param: L, counter
13 |     # taken and returned: xp, mustar, mu
14 |     # returned: maxfeasible, dual_objective, primal_dual_gap
15 | 
16 |     # ---------------------------------------------------------------------
17 | 
18 |     logging.info('compiling step')
19 | 
20 |     Adot, ATdot, _, _, _ = get_A(xr, normalizer)
21 | 
22 |     # compute gradient of the dual objective
23 |     Axp = Adot(xp)
24 |     gradq = Axp - vec
25 | 
26 |     # compute step of accelerated projected gradient descent
27 |     mustarold = mustar
28 |     mustar = np.maximum(0, mu - gradq / L)
29 |     mu = mustar + (counter / (counter + 3)) * (mustar - mustarold)
30 |     mu = np.maximum(0, mu)
31 | 
32 |     # update ATmu; alpha, beta are set to their optimal values in the dual
33 |     # NOTICE: this improves the dual value, but it is a HACK as we optimize
34 |     # over mu (and jump around w.r.t. alpha, beta)
35 |     ATmu = ATdot(mu)
36 |     alpha = np.maximum(0, x0 - ATmu - 1)
37 |     beta = np.maximum(0, ATmu - x0)
38 |     xp = alpha + ATmu - beta
39 | 
40 |     # compute primal objective
41 |     # x = x0 - xp is the primal variable (at the dual optimal), need not be feasible
42 |     primal_objective = 0.5 * np.linalg.norm(xp)**2
43 |     dual_objective = -0.5 * np.linalg.norm(xp)**2 + x0.T.dot(xp) - b_finite.T.dot(mu) - np.sum(alpha)
44 |     primal_dual_gap = primal_objective - dual_objective
45 | 
46 |     feasible = vec - Axp
47 |     maxfeasible = np.amax(feasible)
48 | 
49 |     return xp, mustar, mu, dual_objective, primal_dual_gap, maxfeasible
50 | 
51 | 
52 | def cond_fun(maxiter, bound, feasStop, state):
53 |     logging.info('compiling cond_fun')
54 |     counter = state[1]
55 |     dual_objective, primal_dual_gap, maxfeasible = state[7:10]
56 | 
57 |     cond1 = counter <= maxiter
58 |     cond2 = dual_objective < bound
59 |     cond3 = np.logical_or(
60 |         np.absolute(primal_dual_gap) > 1e-6,
61 |         np.logical_or(
62 |             maxfeasible > feasStop,
63 |             np.logical_and(counter < 200, maxfeasible >= 0)
64 |         )
65 |     )
66 |     return np.logical_and(cond1, np.logical_and(cond2, cond3))
67 | 
68 | 
69 | def state_update_fun(get_A, xr, normalizer, b_finite, vec, x0, state):
70 |     logging.info('compiling state_update_fun')
71 | 
72 |     L, counter = state[0:2]
73 |     xp, mustar, mu = state[2:5]
74 |     xp, mustar, mu, dual_objective, primal_dual_gap, maxfeasible = step(
75 |         x0, get_A, xr, normalizer, b_finite, vec,
76 |         L, counter, xp, mustar, mu)
77 | 
78 |     # TODO: maybe use lax.cond once available https://github.com/google/jax/issues/325
79 |     best_dual_objective, best_x = state[5:7]
80 |     update_best = jax.lax.gt(dual_objective, best_dual_objective)
81 |     best_x = update_best * (x0 - xp) + (1 - update_best) * best_x
82 |     best_dual_objective = update_best * dual_objective + (1 - update_best) * best_dual_objective
83 | 
84 |     counter = counter + 1
85 | 
86 |     # TODO: update L if dual smaller than -100
87 |     # if dual_objective < -100:
88 |     #     logging.warning('divergence due to hack with Lipschitz constant')
89 |     #     if L < LMAX:
90 |     #         logging.warning('increasing L and restarting from scratch')
91 |     #         L = min(10 * L, LMAX)
92 |     #         mu = np.zeros_like(mu)
93 |     #         mustar = mu
94 |     #         counter = 1
95 | 
96 |     return (
97 |         L, counter,
98 |         xp, mustar, mu,
99 |         best_dual_objective, best_x,
100 |         dual_objective, primal_dual_gap, maxfeasible,
101 |     )
102 | 
103 | 
104 | # kwargs are treated as static, but bound changes, so don't use kwargs here
105 | @partial(jax.jit, static_argnums=(0, 2))
106 | def solve_jit(x0, Ax0, get_A, xr, normalizer, b, L, bound, maxiter, feasStop):
107 |     logging.info('compiling solve_jit')
108 | 
109 |     # constants
110 |     vec = Ax0 - b
111 | 
112 |     # initialization
113 |     mu = np.zeros_like(b)
114 |     mustar = mu
115 |     xp = np.zeros_like(x0)
116 |     best_dual_objective = 0.
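    # starting value 0 corresponds to the dual-feasible point mu = 0: for
    # x0 in [0, 1] it gives alpha = beta = 0 and xp = 0, hence a dual value of 0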
117 | best_x = x0 118 | counter = 1 119 | 120 | b_finite = np.where(np.isposinf(b), np.array(np.finfo(b.dtype).max), b) 121 | 122 | init_state = (L, counter, xp, mustar, mu, best_dual_objective, best_x, 0., np.inf, np.inf) 123 | 124 | _cond_fun = partial(cond_fun, maxiter, bound, feasStop) 125 | _state_update_fun = partial(state_update_fun, get_A, xr, normalizer, b_finite, vec, x0) 126 | 127 | final_state = jax.lax.while_loop(_cond_fun, _state_update_fun, init_state) 128 | 129 | counter = final_state[1] 130 | best_dual_objective = final_state[5] 131 | best_x = final_state[6] 132 | 133 | return counter, best_dual_objective, best_x 134 | 135 | 136 | def solve(x0, Ax0, get_A, xr, normalizer, b, L, *, bound=np.inf, maxiter=4000, feasStop=1e-8): 137 | """Solves the following quadratic programming (QP) problem: 138 | 139 | min_x 1/2 (x - x0)' * (x - x0) 140 | s.t. A * x ≤ b and 0 ≤ x ≤ 1 141 | """ 142 | 143 | t0 = time.time() 144 | counter, best_dual_objective, best_x = solve_jit(x0, Ax0, get_A, xr, normalizer, b, L, bound, maxiter, feasStop) 145 | t0 = time.time() - t0 146 | 147 | logging.info(f'took {t0:.1f} secs for {counter} it -> {counter / t0:.1f} it/sec') 148 | return best_x, best_dual_objective, counter 149 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | -------------------------------------------------------------------------------- /staxmod.py: -------------------------------------------------------------------------------- 1 | """A modified version of stax that supports tracking of intermediate 2 | activations, in particular the inputs to non-affine layers.""" 3 | from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum, 4 | FanOut, Flatten, GeneralConv, Identity, 5 | Relu) 6 | from jax import random 7 | import jax.numpy as np 8 | from jax.experimental import stax 9 | 10 | 11 | def affine(layer_fun): 12 | """Decorator that turns a layer into one that's compatible with tracking 13 | of additional outputs.""" 14 | # @functools.wraps(layer_fun) 15 | def wrapper(*args, **kwargs): 16 | init_fun, apply_fun = layer_fun(*args, **kwargs) 17 | 18 | def new_apply_fun(*args, **kwargs): 19 | return apply_fun(*args, **kwargs), () 20 | return init_fun, new_apply_fun 21 | return wrapper 22 | 23 | 24 | def affine_no_params(layer): 25 | """Decorator that turns a layer into one that's compatible with tracking 26 | of additional outputs.""" 27 | init_fun, apply_fun = layer 28 | 29 | def new_apply_fun(*args, **kwargs): 30 | return apply_fun(*args, **kwargs), () 31 | return init_fun, new_apply_fun 32 | 33 | 34 | def track_input_no_params(layer): 35 | init_fun, apply_fun = layer 36 | 37 | def new_apply_fun(params, inputs, rng=None): 38 | return apply_fun(params, inputs, rng=rng), (inputs,) 39 | return init_fun, new_apply_fun 40 | 41 | 42 | def serial(*layers): 43 | """Like stax.serial but separately tracks additional outputs 44 | for each layer.""" 45 | nlayers = len(layers) 46 | init_funs, apply_funs = zip(*layers) 47 | 48 | def init_fun(input_shape): 49 | params = [] 50 | for init_fun in init_funs: 51 | input_shape, param = init_fun(input_shape) 52 | params.append(param) 53 | return input_shape, params 54 | 55 | def apply_fun(params, inputs, rng=None): 56 | rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers 57 | additional_outputs = [] 58 | for fun, param, rng in zip(apply_funs, params, rngs): 
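            # every wrapped layer returns (output, additional_outputs): affine
            # layers return an empty tuple, while ReLU-like layers return their
            # inputs so the attack can recover the pre-activation signs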
59 | inputs, additional_output = fun(param, inputs, rng=rng) 60 | additional_outputs.append(additional_output) 61 | return inputs, additional_outputs 62 | return init_fun, apply_fun 63 | 64 | 65 | def parallel(*layers): 66 | """Like stax.parallel but separately tracks additional outputs 67 | for each layer.""" 68 | nlayers = len(layers) 69 | init_funs, apply_funs = zip(*layers) 70 | 71 | def init_fun(input_shape): 72 | return zip(*[init(shape) for init, shape in zip(init_funs, input_shape)]) 73 | 74 | def apply_fun(params, inputs, rng=None): 75 | rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers 76 | outputs = [] 77 | additional_outputs = [] 78 | for f, p, x, r in zip(apply_funs, params, inputs, rngs): 79 | output, additional_output = f(p, x, rng=r) 80 | outputs.append(output) 81 | additional_outputs.append(additional_output) 82 | return outputs, additional_outputs 83 | return init_fun, apply_fun 84 | 85 | 86 | AvgPool = affine(AvgPool) 87 | BatchNorm = affine(BatchNorm) 88 | Conv = affine(Conv) 89 | Dense = affine(Dense) 90 | FanInSum = affine_no_params(FanInSum) 91 | FanOut = affine(FanOut) 92 | Flatten = affine_no_params(Flatten) 93 | GeneralConv = affine(GeneralConv) 94 | Identity = affine_no_params(Identity) 95 | 96 | Relu = track_input_no_params(Relu) 97 | 98 | 99 | def leaky_relu(x, leakiness=0.01): 100 | return np.where(x >= 0, x, leakiness * x) 101 | 102 | 103 | LeakyRelu = stax._elemwise_no_params(leaky_relu) 104 | LeakyRelu = track_input_no_params(LeakyRelu) 105 | 106 | 107 | # TODO: MaxPool constraints 108 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | import jax 3 | from jax.interpreters import xla 4 | 5 | 6 | def is_device_array(x): 7 | return isinstance(x, xla.DeviceArray) 8 | 9 | 10 | def scatter(indices, values, length): 11 | assert indices.ndim == 1 12 | assert values.shape[1:] == indices.shape 13 | batch_size = values.shape[0] 14 | 15 | def f(a): 16 | return a[:, indices] 17 | 18 | base = np.zeros((batch_size, length), np.float32) 19 | _, grad = jax.vjp(f, base) 20 | (out,) = grad(values) 21 | return base + out 22 | 23 | 24 | # def scatter(indices, values, length): 25 | # dnums = jax.lax.ScatterDimensionNumbers( 26 | # update_window_dims=(), 27 | # inserted_window_dims=(0,), 28 | # scatter_dims_to_operand_dims=(0,), 29 | # index_vector_dim=1) 30 | # indices = np.atleast_1d(np.asarray([indices]).squeeze()) 31 | # values = np.atleast_1d(np.asarray([values]).squeeze()) 32 | # assert indices.ndim == values.ndim == 1 33 | # assert indices.shape == values.shape 34 | # return jax.lax.scatter_add(np.zeros(length, values.dtype), indices, values, dnums) 35 | 36 | 37 | def onehot(index, length, dtype=np.int32): 38 | assert isinstance(index, int) 39 | onehot = np.arange(length) == index 40 | onehot = onehot.astype(dtype) 41 | return onehot 42 | -------------------------------------------------------------------------------- /weights/convnet.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonasrauber/linear-region-attack/01f51aa2b79258d5040fcba68b24cd942c0aba51/weights/convnet.pickle --------------------------------------------------------------------------------