├── .editorconfig ├── .gitignore ├── LICENSE.txt ├── README.rst ├── robustml ├── __init__.py ├── attack.py ├── dataset.py ├── evaluate.py ├── model.py ├── provider.py └── threat_model.py ├── setup.py └── test.py /.editorconfig: -------------------------------------------------------------------------------- 1 | ; http://editorconfig.org 2 | root = true 3 | 4 | [*] 5 | indent_style = space 6 | indent_size = 4 7 | charset = utf-8 8 | end_of_line = lf 9 | insert_final_newline = true 10 | trim_trailing_whitespace = true 11 | 12 | [*.md] 13 | trim_trailing_whitespace = false 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | build/ 3 | *.pyc 4 | *.egg-info 5 | *.gz 6 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | robust-ml 2 | ========= 3 | 4 | Interfaces for defining Robust ML models and precisely specifying the threat 5 | models under which they claim to be secure. Also includes interfaces for 6 | specifying attacks and evaluating attacks against models. 7 | 8 | The motivation behind this project is to make it easy to make specific, 9 | testable claims about the robustness of machine learning models. Read more 10 | in the `FAQ `__. 11 | 12 | Installation 13 | ------------ 14 | 15 | You can install from PyPI: ``pip install robustml``. 16 | 17 | Usage 18 | ----- 19 | 20 | See `this repository `__ for a complete 21 | example of implementing a model, implementing an attack, and evaluating the 22 | attack against the model. 23 | 24 | If you're implementing a **defense**, you should implement 25 | ``robustml.model.Model``. See `here 26 | `__ for an 27 | example. 28 | 29 | If you're implementing an **attack** against a specific defense, you should 30 | implement ``robustml.attack.Attack``. See `here 31 | `__ for an example. 32 | 33 | To **evaluate** a specific attack against a specific defense, use 34 | ``robustml.evaluate.evaluate()``. See `here 35 | `__ for an example. 36 | 37 | Contributing 38 | ------------ 39 | 40 | Do you have ideas on how to improve the robustml package? Have a feature 41 | request (such as a specification of a new threat model) or bug report? Great! 42 | Please open an `issue `__ or 43 | submit a `pull request `__.

Before contributing a major change, it's recommended that you open an issue
first and get feedback on the idea before investing time in the
implementation.

Packaging
---------

1. Update version information.

2. Build the package using ``python setup.py sdist bdist_wheel``.

3. Sign and upload the package using ``twine upload -s dist/*``.

4. Create a signed tag in the git repo with the version number that was
   uploaded to PyPI.

--------------------------------------------------------------------------------
/robustml/__init__.py:
--------------------------------------------------------------------------------
__version__ = '0.0.3'

from . import threat_model
from . import dataset
from . import model
from . import attack
from . import provider
from . import evaluate

--------------------------------------------------------------------------------
/robustml/attack.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod

class Attack(metaclass=ABCMeta):
    '''
    Interface for an attack against a specific defense
    (a `robustml.model.Model`).
    '''

    @abstractmethod
    def run(self, x, y, target):
        '''
        Returns an adversarial example for original input `x` and true label
        `y`. If `target` is not `None`, then the adversarial example should be
        targeted to be classified as `target`.
        '''
        raise NotImplementedError

--------------------------------------------------------------------------------
/robustml/dataset.py:
--------------------------------------------------------------------------------
'''
Identifiers for known datasets.
'''

from abc import ABCMeta, abstractmethod

class Dataset(metaclass=ABCMeta):
    '''
    Base class for dataset identifiers.

    You should not subclass this in your own code, you should only use the
    datasets defined here.
    '''

    @property
    @abstractmethod
    def shape(self):
        '''
        Shape of a single data point (a tuple of ints).
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def labels(self):
        '''
        Number of classes.
        '''
        raise NotImplementedError

class MNIST(Dataset):
    '''
    Data points are 28x28 arrays with elements in [0, 1].
    '''

    @property
    def shape(self):
        return (28, 28)

    @property
    def labels(self):
        return 10

class FMNIST(Dataset):
    '''
    Data points are 28x28 arrays with elements in [0, 1].
    '''

    @property
    def shape(self):
        return (28, 28)

    @property
    def labels(self):
        return 10

class GTS(Dataset):
    '''
    Data points are 32x32x3 arrays with elements in [0, 1].
    '''

    @property
    def shape(self):
        return (32, 32, 3)

    @property
    def labels(self):
        return 43

class CIFAR10(Dataset):
    '''
    Data points are 32x32x3 arrays with elements in [0, 1].
    '''

    @property
    def shape(self):
        return (32, 32, 3)

    @property
    def labels(self):
        return 10

class ImageNet(Dataset):
    '''
    Data points are ?x?x3 arrays with elements in [0, 1].

    Dimensions are specified in the constructor.
    '''

    def __init__(self, shape=None):
        '''
        Shape is a 3-tuple (height, width, channels) describing the shape of
        the input image to the model.

        Raises `ValueError` unless `shape` is a tuple of three ints whose
        last element is 3.
        '''
        if not isinstance(shape, tuple) or len(shape) != 3 \
                or not all(isinstance(i, int) for i in shape) \
                or not shape[-1] == 3:
            raise ValueError('bad shape: %s' % str(shape))
        self._shape = shape

    @property
    def shape(self):
        return self._shape

    @property
    def labels(self):
        return 1000

--------------------------------------------------------------------------------
/robustml/evaluate.py:
--------------------------------------------------------------------------------
import numpy as np
import sys

def evaluate(model, attack, provider, start=None, end=None, deterministic=False, debug=False):
    '''
    Evaluate an attack on a particular model and return attack success rate.

    An attack is allowed to be adaptive, so it's fine to design the attack
    based on the specific model it's supposed to break.

    `start` (inclusive) and `end` (exclusive) are indices to evaluate on. If
    unspecified, `start` defaults to 0 and `end` defaults to `len(provider)`,
    so the entire dataset is evaluated.

    `deterministic` specifies whether to seed the RNG with a constant value for
    a more deterministic test (so randomly selected target classes are chosen
    in a pseudorandom way).

    `debug` prints per-example diagnostics to stderr.

    Returns the fraction of evaluated examples on which the attack produced
    a perturbation that passes the threat model check and changes the
    classification (to `target`, for targeted threat models).

    Raises `ValueError` if the provider does not serve the model's dataset,
    or if the index range is out of bounds or empty.
    '''

    if not provider.provides(model.dataset):
        raise ValueError('provider does not provide correct dataset')
    # unspecified bounds cover the entire dataset
    if start is None:
        start = 0
    if end is None:
        end = len(provider)
    if not (0 <= start < len(provider)):
        raise ValueError('start value out of range')
    if not (0 <= end <= len(provider)):
        raise ValueError('end value out of range')
    if start >= end:
        # an empty range would otherwise divide by zero below
        raise ValueError('start must be less than end')

    threat_model = model.threat_model
    targeted = threat_model.targeted

    success = 0
    total = 0
    for i in range(start, end):
        print('evaluating %d of [%d, %d)' % (i, start, end), file=sys.stderr)
        total += 1
        x, y = provider[i]
        target = None
        if targeted:
            target = choose_target(i, y, model.dataset.labels, deterministic)
        # pass copies so the attack, the threat model and the model cannot
        # mutate the provider's data or each other's inputs
        x_adv = attack.run(np.copy(x), y, target)
        if not threat_model.check(np.copy(x), np.copy(x_adv)):
            if debug:
                print('check failed', file=sys.stderr)
            continue
        y_adv = model.classify(np.copy(x_adv))
        if debug:
            print('true = %d, adv = %d' % (y, y_adv), file=sys.stderr)
        if targeted:
            if y_adv == target:
                success += 1
        else:
            if y_adv != y:
                success += 1

    success_rate = success / total

    return success_rate

def choose_target(index, true_label, num_labels, deterministic=False):
    '''
    Pick a label in [0, num_labels) other than `true_label`.

    With `deterministic=True` the RNG is seeded with `index`, so a given
    example always receives the same target.

    Raises `ValueError` if `num_labels < 2`, since no other label exists
    (the sampling loop below would never terminate).
    '''
    if num_labels < 2:
        raise ValueError('need at least two labels to choose a target')
    if deterministic:
        rng = np.random.RandomState(index)
    else:
        rng = np.random.RandomState()

    # rejection sampling; kept (rather than e.g. drawing an offset) so the
    # targets produced for a given seed stay the same
    target = true_label
    while target == true_label:
        target = rng.randint(0, num_labels)

    return target

--------------------------------------------------------------------------------
/robustml/model.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod

class Model(metaclass=ABCMeta):
    '''
    Interface for a model (classifier).

    Besides the required methods below, a model should do a reasonable job of
    providing easy access to internals to make white box attacks easier. For
    example, a model using TensorFlow might want to provide access to the input
    tensor placeholder and the tensor representing the logits output of the
    classifier.
    '''

    @property
    @abstractmethod
    def dataset(self):
        '''
        A concrete instance of a subclass of `robustml.dataset.Dataset`.
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def threat_model(self):
        '''
        An instance of `robustml.threat_model.ThreatModel`, ideally
        one of the pre-defined concrete threat models.
        '''
        raise NotImplementedError

    @abstractmethod
    def classify(self, x):
        '''
        Returns the label for the input x (as a Python integer).
        '''
        raise NotImplementedError

--------------------------------------------------------------------------------
/robustml/provider.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod
import os
import numpy as np
import pickle
import gzip
import PIL.Image
import csv

from . import dataset as d

class Provider(metaclass=ABCMeta):
    @abstractmethod
    def provides(self, dataset):
        '''
        Returns whether or not this provider can provide data for the given
        dataset.
        '''
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        '''
        Returns the number of examples available.
        '''
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, index):
        '''
        Returns a tuple (x, y) containing the data and label of the given
        index.
        '''
        raise NotImplementedError

def _load_idx_test_set(image_path, label_path):
    '''
    Load a 10000-example test set stored as gzipped IDX files (the format
    shared by MNIST and Fashion-MNIST).

    Returns `(xs, ys)`: `xs` is a float32 array of shape (10000, 28, 28)
    scaled to [0, 1], and `ys` is a uint8 array of shape (10000,).

    Raises `ValueError` if a file has an unexpected header or size.
    '''
    with gzip.open(image_path, 'rb') as f:
        images = f.read()
    # expected 4-byte IDX magic number of the image file
    if images[:4] != b'\x00\x00\x08\x03':
        raise ValueError('bad image file header: %s' % image_path)
    images = np.frombuffer(images[16:], dtype=np.uint8)  # skip 16-byte header
    if len(images) != 10000 * 28 * 28:
        raise ValueError('unexpected image file size: %s' % image_path)
    images = images.reshape((10000, 28, 28)).astype(np.float32) / 255.0

    with gzip.open(label_path, 'rb') as f:
        labels = f.read()
    # expected 4-byte IDX magic number of the label file
    if labels[:4] != b'\x00\x00\x08\x01':
        raise ValueError('bad label file header: %s' % label_path)
    labels = np.frombuffer(labels[8:], dtype=np.uint8)  # skip 8-byte header
    if len(labels) != 10000:
        raise ValueError('unexpected label file size: %s' % label_path)

    return images, labels

class MNIST(Provider):
    def __init__(self, image_path, label_path):
        '''
        image_path is a path to the 't10k-images-idx3-ubyte.gz' from
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'

        label_path is a path to the 't10k-labels-idx1-ubyte.gz' from
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

        Raises `ValueError` if either file is malformed.
        '''
        self.xs, self.ys = _load_idx_test_set(image_path, label_path)

    def provides(self, dataset):
        return isinstance(dataset, d.MNIST)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]

class FMNIST(Provider):
    def __init__(self, image_path, label_path):
        '''
        image_path is a path to the 't10k-images-idx3-ubyte.gz' from
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz'

        label_path is a path to the 't10k-labels-idx1-ubyte.gz' from
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz'

        Raises `ValueError` if either file is malformed.
        '''
        self.xs, self.ys = _load_idx_test_set(image_path, label_path)

    def provides(self, dataset):
        return isinstance(dataset, d.FMNIST)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]

class GTS(Provider):
    def __init__(self, test_data):
        '''
        test_data is a path to the test set of German Traffic Sign dataset (GTSRB/) from
        `http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_Images.zip`

        Note that the labels should be obtained separately from
        `http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_GT.zip`
        and placed in `test_data`.

        Images are resized to 32x32 and scaled to [0, 1].
        '''
        # the GTS images originally have various sizes (from 15x15 to 250x250)
        width, height = 32, 32
        images, labels = [], []

        # os.path.join works whether or not test_data ends with a separator
        annotations_path = os.path.join(test_data, 'GT-final_test.csv')
        images_dir = os.path.join(test_data, 'Final_Test', 'Images')
        with open(annotations_path, newline='') as annotations:
            rows = list(csv.reader(annotations, delimiter=';'))
        for row in rows[1:]:  # skip the header row
            with PIL.Image.open(os.path.join(images_dir, row[0])) as img:
                # PIL sizes are (width, height); LANCZOS is the filter that
                # older Pillow releases also exposed as ANTIALIAS
                resized = img.resize((width, height), PIL.Image.LANCZOS)
            images.append(np.array(resized))
            labels.append(int(row[7]))
        self.xs = np.array(images, dtype=np.float32) / 255.0
        self.ys = np.array(labels)

    def provides(self, dataset):
        return isinstance(dataset, d.GTS)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]

class CIFAR10(Provider):
    def __init__(self, test_data):
        '''
        test_data is a path to the 'test_batch' file from
        'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
        '''
        # NOTE: unpickling can execute arbitrary code; only load a trusted
        # copy of the official test batch
        with open(test_data, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
        xs = data[b'data'].reshape((10000, 3, 32, 32)).astype(np.float32) / 255.0
        self.xs = np.transpose(xs, (0, 2, 3, 1))  # (N,3,32,32) -> (N,32,32,3)
        self.ys = data[b'labels']

    def provides(self, dataset):
        return isinstance(dataset, d.CIFAR10)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]

class ImageNet(Provider):
    def __init__(self, path, shape):
        '''
        `path` is a directory containing a `val/` subdirectory with the
        50000 validation images and a `val.txt` file whose lines are
        `<image file name> <integer label>`.

        `shape` is the (height, width, channels) tuple images are cropped
        and resized to; `provides` requires it to equal the dataset's shape.
        '''
        self._path = path
        self._shape = shape
        # (sorted image paths, {file name: label}); built lazily on first
        # lookup so construction stays cheap
        self._index = None

    def provides(self, dataset):
        return isinstance(dataset, d.ImageNet) and self._shape == dataset.shape

    def __len__(self):
        return 50000

    def _get_index(self):
        '''
        Return `(image_paths, labels)`, reading the directory listing and
        `val.txt` once per provider rather than on every lookup.

        Raises `ValueError` if `val/` does not hold exactly 50000 files.
        '''
        if self._index is None:
            data_path = os.path.join(self._path, 'val')
            image_paths = sorted(os.path.join(data_path, i) for i in os.listdir(data_path))
            if len(image_paths) != 50000:
                raise ValueError('expected 50000 images in %s, found %d'
                                 % (data_path, len(image_paths)))
            labels_path = os.path.join(self._path, 'val.txt')
            with open(labels_path) as labels_file:
                entries = [line.split() for line in labels_file.read().strip().split('\n')]
            labels = {os.path.basename(e[0]): int(e[1]) for e in entries}
            self._index = (image_paths, labels)
        return self._index

    def __getitem__(self, index):
        image_paths, labels = self._get_index()
        path = image_paths[index]
        x = self._load_image(path)
        y = labels[os.path.basename(path)]
        return x, y

    def _load_image(self, path):
        '''
        Load the image at `path` as a float32 array of shape `self._shape`
        with values in [0, 1]: take the largest centered crop with the target
        aspect ratio, then resize it to the target size.
        '''
        h, w, _ = self._shape
        aspect = w / h
        with PIL.Image.open(path) as original:
            # normalize greyscale, palette, alpha and CMYK images to 3-channel
            # RGB; convert() returns a converted copy, so the file can close
            image = original.convert('RGB')
        image_aspect = image.width / image.height

        if image_aspect > aspect:
            # image is wider than our aspect ratio
            new_height = image.height
            height_off = 0
            new_width = int(aspect * new_height)
            width_off = (image.width - new_width) // 2
        else:
            # image is taller than our aspect ratio
            new_width = image.width
            width_off = 0
            new_height = int(new_width / aspect)
            height_off = (image.height - new_height) // 2

        # box is (left, upper, right, lower)
        image = image.crop((
            width_off,
            height_off,
            width_off + new_width,
            height_off + new_height
        ))

        # PIL sizes are (width, height)
        image = image.resize((w, h))

        arr = np.asarray(image).astype(np.float32) / 255.0
        assert arr.shape == self._shape
        return arr

--------------------------------------------------------------------------------
/robustml/threat_model.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod
import numpy as np

from . import dataset

class ThreatModel(metaclass=ABCMeta):
    @abstractmethod
    def check(self, original, perturbed):
        '''
        Returns whether the perturbed image is a valid perturbation of the
        original under the threat model.

        `original` and `perturbed` are numpy arrays of the same dtype and
        shape.
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def targeted(self):
        '''
        Returns whether the threat model only includes targeted attacks
        (requiring the attack to be capable of synthesizing targeted
        adversarial examples).
        '''
        raise NotImplementedError

class Or(ThreatModel):
    '''
    A union of threat models: a perturbation is valid if any member accepts
    it, and the union is targeted only if every member is targeted.
    '''

    def __init__(self, *threat_models):
        self._threat_models = threat_models

    def check(self, original, perturbed):
        return any(i.check(original, perturbed) for i in self._threat_models)

    @property
    def targeted(self):
        return all(i.targeted for i in self._threat_models)

class And(ThreatModel):
    '''
    An intersection of threat models: a perturbation is valid only if every
    member accepts it, and the intersection is targeted if any member is.
    '''

    def __init__(self, *threat_models):
        self._threat_models = threat_models

    def check(self, original, perturbed):
        return all(i.check(original, perturbed) for i in self._threat_models)

    @property
    def targeted(self):
        return any(i.targeted for i in self._threat_models)

class Lp(ThreatModel):
    r'''
    Bounded L_p perturbation. Given a `p` and `epsilon`, x' is a valid
    perturbation of x if the following holds:

        || x - x' ||_p <= \epsilon

    Every element of x' must also lie in [0, 1]. Both checks allow a small
    tolerance for rounding errors.
    '''

    _SLOP = 0.0001 # to account for rounding errors

    def __init__(self, p, epsilon, targeted=False):
        self._p = p
        self._epsilon = epsilon
        self._targeted = targeted

    def check(self, original, perturbed):
        # we want to treat the inputs as big vectors; asarray also accepts
        # array-likes, which ndarray.flatten would reject
        original = np.asarray(original).flatten()
        perturbed = np.asarray(perturbed).flatten()
        # ensure it's a valid image
        if np.min(perturbed) < -self._SLOP or np.max(perturbed) > 1+self._SLOP:
            return False
        norm = np.linalg.norm(original - perturbed, ord=self._p)
        # plain bool rather than numpy.bool_ (whose repr differs across
        # NumPy versions and would break the doctests)
        return bool(norm <= self._epsilon + self._SLOP)

    @property
    def targeted(self):
        return self._targeted

    @property
    def p(self):
        return self._p

    @property
    def epsilon(self):
        return self._epsilon

class L0(Lp):
    def __init__(self, epsilon, targeted=False):
        super().__init__(p=0, epsilon=epsilon, targeted=targeted)

class L1(Lp):
    def __init__(self, epsilon, targeted=False):
        super().__init__(p=1, epsilon=epsilon, targeted=targeted)

class L2(Lp):
    def __init__(self, epsilon, targeted=False):
        super().__init__(p=2, epsilon=epsilon, targeted=targeted)

class Linf(Lp):
    r'''
    Bounded L_inf perturbation. Given an `epsilon`, x' is a valid
    perturbation of x if the following holds:

        || x - x' ||_\infty <= \epsilon

    >>> model = Linf(0.1)
    >>> x = np.array([0.1, 0.2, 0.3])
    >>> model.check(x, x)
    True
    >>> model.targeted
    False
    >>> model = Linf(0.1, targeted=True)
    >>> model.targeted
    True
    >>> y = np.array([0.1, 0.25, 0.32])
    >>> model.check(x, y)
    True
    >>> z = np.array([0.3, 0.2, 0.3])
    >>> model.check(x, z)
    False
    '''

    def __init__(self, epsilon, targeted=False):
        super().__init__(p=np.inf, epsilon=epsilon, targeted=targeted)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from codecs import open  # For a consistent encoding
from os import path
import re


# absolute, so paths stay valid even if the working directory changes
here = path.abspath(path.dirname(__file__))


with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
    long_description = f.read()


def read(*names, **kwargs):
    with open(
        path.join(here, *names),
        encoding=kwargs.get("encoding", "utf8")
    ) as fp:
        return fp.read()


def find_version(*file_paths):
    version_file = read(*file_paths)
    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
                              version_file, re.M)
    if version_match:
        return version_match.group(1)
    raise RuntimeError("Unable to find version string.")


setup(
    name='robustml',

    version=find_version('robustml', '__init__.py'),

    description='Robust ML API',
    long_description=long_description,

    url='https://github.com/robust-ml/robust-ml',

    author='Anish Athalye',
    author_email='me@anishathalye.com',

    license='MIT',

    packages=['robustml'],

    # the package uses Python 3-only syntax (`metaclass=` class keywords,
    # `print(..., file=...)`)
    python_requires='>=3',

    install_requires=[
        'numpy>=1,<2',
        'Pillow>=5,<6',
    ],
)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import doctest
import sys
import robustml

# We need to explicitly list all modules here. This is not super pretty, but
# the testing here shouldn't be too involved. If there's ever a need for
# fancier testing, we can switch to a more complete testing framework.

TEST_MODULES = [
    robustml.threat_model,
]

if __name__ == '__main__':
    failures = 0
    for module in TEST_MODULES:
        failures += doctest.testmod(module).failed
    # exit non-zero so CI notices failing doctests
    sys.exit(1 if failures else 0)

--------------------------------------------------------------------------------