├── .editorconfig
├── .gitignore
├── LICENSE.txt
├── README.rst
├── robustml
│   ├── __init__.py
│   ├── attack.py
│   ├── dataset.py
│   ├── evaluate.py
│   ├── model.py
│   ├── provider.py
│   └── threat_model.py
├── setup.py
└── test.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | ; http://editorconfig.org
2 | root = true
3 |
4 | [*]
5 | indent_style = space
6 | indent_size = 4
7 | charset = utf-8
8 | end_of_line = lf
9 | insert_final_newline = true
10 | trim_trailing_whitespace = true
11 |
12 | [*.md]
13 | trim_trailing_whitespace = false
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | dist/
2 | build/
3 | *.pyc
4 | *.egg-info
5 | *.gz
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | robust-ml
2 | =========
3 |
4 | Interfaces for defining Robust ML models and precisely specifying the threat
5 | models under which they claim to be secure. Also includes interfaces for
6 | specifying attacks and evaluating attacks against models.
7 |
8 | The motivation behind this project is to make it easy to make specific,
9 | testable claims about the robustness of machine learning models. Read more
10 | in the `FAQ `__.
11 |
12 | Installation
13 | ------------
14 |
15 | You can install from PyPI: ``pip install robustml``.
16 |
17 | Usage
18 | -----
19 |
20 | See `this repository `__ for a complete
21 | example of implementing a model, implementing an attack, and evaluating the
22 | attack against the model.
23 |
24 | If you're implementing a **defense**, you should implement
25 | ``robustml.model.Model``. See `here
26 | `__ for an
27 | example.
28 |
29 | If you're implementing an **attack** against a specific defense, you should
30 | implement ``robustml.attack.Attack``. See `here
31 | `__ for an example.
32 |
33 | To **evaluate** a specific attack against a specific defense, use
34 | ``robustml.evaluate.evaluate()``. See `here
35 | `__ for an example.
36 |
37 | Contributing
38 | ------------
39 |
40 | Do you have ideas on how to improve the robustml package? Have a feature
41 | request (such as a specification of a new threat model) or bug report? Great!
42 | Please open an `issue `__ or
43 | submit a `pull request `__.
44 |
45 | Before contributing a major change, it's recommended that you open an
46 | issue first and get feedback on the idea before investing time in the
47 | implementation.
48 |
49 | Packaging
50 | ---------
51 |
52 | 1. Update version information.
53 |
54 | 2. Build the package using ``python setup.py sdist bdist_wheel``.
55 |
56 | 3. Sign and upload the package using ``twine upload -s dist/*``.
57 |
58 | 4. Create a signed tag in the git repo with the version number that was
59 | uploaded to PyPI.
60 |
--------------------------------------------------------------------------------
/robustml/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.3'
2 |
3 | from . import threat_model
4 | from . import dataset
5 | from . import model
6 | from . import attack
7 | from . import provider
8 | from . import evaluate
9 |
--------------------------------------------------------------------------------
/robustml/attack.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
class Attack(metaclass=ABCMeta):
    '''
    Interface for an attack against a specific model.

    Implementations provide `run`, which crafts one adversarial example at a
    time.
    '''

    @abstractmethod
    def run(self, x, y, target):
        '''
        Produce an adversarial example.

        `x` is the original input and `y` its true label. When `target` is
        `None` the attack is untargeted; otherwise the returned example
        should be classified as `target`.
        '''
        raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/robustml/dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | '''
4 | Identifiers for known datasets.
5 | '''
6 |
class Dataset(metaclass=ABCMeta):
    '''
    Base class for dataset identifiers.

    You should not subclass this in your own code, you should only use the
    datasets defined here.

    The metaclass is `ABCMeta` so that the `@abstractmethod` decorators below
    are actually enforced; without it they are ignored, and `Dataset` itself
    or an incomplete subclass could be instantiated.
    '''

    @property
    @abstractmethod
    def shape(self):
        '''
        Shape of a single data point, as a tuple.
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def labels(self):
        '''
        Number of classes in the dataset.
        '''
        raise NotImplementedError
22 |
class MNIST(Dataset):
    '''
    Data points are 28x28 arrays with elements in [0, 1].
    '''

    # fixed properties of the dataset
    _SHAPE = (28, 28)
    _NUM_LABELS = 10

    @property
    def shape(self):
        return self._SHAPE

    @property
    def labels(self):
        return self._NUM_LABELS
35 |
class FMNIST(Dataset):
    '''
    Data points are 28x28 arrays with elements in [0, 1].
    '''

    # fixed properties of the dataset
    _SHAPE = (28, 28)
    _NUM_LABELS = 10

    @property
    def shape(self):
        return self._SHAPE

    @property
    def labels(self):
        return self._NUM_LABELS
48 |
class GTS(Dataset):
    '''
    Data points are 32x32x3 arrays with elements in [0, 1].
    '''

    # fixed properties of the dataset (43 traffic-sign classes)
    _SHAPE = (32, 32, 3)
    _NUM_LABELS = 43

    @property
    def shape(self):
        return self._SHAPE

    @property
    def labels(self):
        return self._NUM_LABELS
61 |
class CIFAR10(Dataset):
    '''
    Data points are 32x32x3 arrays with elements in [0, 1].
    '''

    # fixed properties of the dataset
    _SHAPE = (32, 32, 3)
    _NUM_LABELS = 10

    @property
    def shape(self):
        return self._SHAPE

    @property
    def labels(self):
        return self._NUM_LABELS
74 |
class ImageNet(Dataset):
    '''
    Data points are ?x?x3 arrays with elements in [0, 1].

    Dimensions are specified in the constructor.
    '''

    def __init__(self, shape=None):
        '''
        `shape` is a 3-tuple (height, width, channels) describing the shape of
        the input image to the model. Height and width must be positive
        integers and channels must be 3.

        Raises ValueError for any other value. In particular, the default
        `None` is rejected, so a shape must always be passed.
        '''
        if not self._is_valid_shape(shape):
            raise ValueError('bad shape: %s' % str(shape))
        self._shape = shape

    @staticmethod
    def _is_valid_shape(shape):
        # bool is a subclass of int, so exclude it explicitly; zero or
        # negative dimensions cannot describe an image.
        return (isinstance(shape, tuple)
                and len(shape) == 3
                and all(isinstance(i, int) and not isinstance(i, bool) and i > 0
                        for i in shape)
                and shape[-1] == 3)

    @property
    def shape(self):
        return self._shape

    @property
    def labels(self):
        return 1000
100 |
--------------------------------------------------------------------------------
/robustml/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 |
def evaluate(model, attack, provider, start=None, end=None, deterministic=False, debug=False):
    '''
    Evaluate an attack on a particular model and return attack success rate.

    An attack is allowed to be adaptive, so it's fine to design the attack
    based on the specific model it's supposed to break.

    `start` (inclusive) and `end` (exclusive) are indices to evaluate on. If
    unspecified, `start` defaults to 0 and `end` to `len(provider)`, i.e. the
    entire dataset is evaluated.

    `deterministic` specifies whether to seed the RNG with a constant value for
    a more deterministic test (so randomly selected target classes are chosen
    in a pseudorandom way).

    Progress is always printed to stderr. `debug` additionally prints, for
    each example, either that the threat-model check failed or the true and
    adversarial labels.

    An adversarial example rejected by the model's threat model counts as a
    failure. Under a targeted threat model, an example succeeds when it is
    classified as the randomly chosen target. Otherwise it succeeds when it
    is classified as anything other than the true label.

    Returns the success rate as a float in [0, 1].

    Raises ValueError if the provider does not serve the model's dataset, if
    `start` or `end` is out of range, or if the range [start, end) is empty.
    '''
    if not provider.provides(model.dataset):
        raise ValueError('provider does not provide correct dataset')
    num_examples = len(provider)
    if start is not None and not (0 <= start < num_examples):
        raise ValueError('start value out of range')
    if end is not None and not (0 <= end <= num_examples):
        raise ValueError('end value out of range')
    # Apply the documented defaults; passing None through to range() would
    # raise TypeError.
    if start is None:
        start = 0
    if end is None:
        end = num_examples
    # An empty range would otherwise end in a ZeroDivisionError below.
    if start >= end:
        raise ValueError('empty evaluation range [%d, %d)' % (start, end))

    threat_model = model.threat_model
    targeted = threat_model.targeted

    success = 0
    total = 0
    for i in range(start, end):
        print('evaluating %d of [%d, %d)' % (i, start, end), file=sys.stderr)
        total += 1
        x, y = provider[i]
        target = None
        if targeted:
            target = choose_target(i, y, model.dataset.labels, deterministic)
        # Each consumer gets its own copy, so the attack, the threat model and
        # the classifier cannot mutate the provider's data or each other's
        # inputs.
        x_adv = attack.run(np.copy(x), y, target)
        if not threat_model.check(np.copy(x), np.copy(x_adv)):
            if debug:
                print('check failed', file=sys.stderr)
            continue
        y_adv = model.classify(np.copy(x_adv))
        if debug:
            print('true = %d, adv = %d' % (y, y_adv), file=sys.stderr)
        succeeded = (y_adv == target) if targeted else (y_adv != y)
        if succeeded:
            success += 1

    return success / total
56 |
def choose_target(index, true_label, num_labels, deterministic=False):
    '''
    Return a random target label in [0, num_labels) that differs from
    `true_label`.

    If `deterministic` is true, the RNG is seeded with `index`, so a given
    example always gets the same target. Otherwise an unseeded RNG is used.

    Raises ValueError if `num_labels` is less than 2. In that case there may
    be no label other than `true_label`, and the sampling loop below would
    never terminate.
    '''
    if num_labels < 2:
        raise ValueError('need at least 2 labels to choose a target, got %s'
                         % (num_labels,))
    if deterministic:
        rng = np.random.RandomState(index)
    else:
        rng = np.random.RandomState()

    # Plain rejection sampling is kept on purpose: changing the draw scheme
    # would change which target each seed produces in deterministic mode.
    target = true_label
    while target == true_label:
        target = rng.randint(0, num_labels)

    return target
68 |
--------------------------------------------------------------------------------
/robustml/model.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
class Model(metaclass=ABCMeta):
    '''
    Interface for a model (classifier).

    Beyond the abstract members below, implementations should expose their
    internals reasonably well so that white-box attacks are easy to write.
    A TensorFlow model, for instance, might expose its input placeholder and
    the tensor holding the classifier's logits.
    '''

    @property
    @abstractmethod
    def dataset(self):
        '''
        The dataset this model classifies: a concrete instance of a subclass
        of `robustml.dataset.Dataset`.
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def threat_model(self):
        '''
        The threat model under which the model claims robustness: an
        instance of `robustml.threat_model.ThreatModel`, preferably one of
        the predefined concrete threat models.
        '''
        raise NotImplementedError

    @abstractmethod
    def classify(self, x):
        '''
        Classify input `x`, returning its label as a Python integer.
        '''
        raise NotImplementedError
37 |
--------------------------------------------------------------------------------
/robustml/provider.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | import os
3 | import numpy as np
4 | import pickle
5 | import gzip
6 | import PIL.Image
7 | import csv
8 |
9 | from . import dataset as d
10 |
class Provider(metaclass=ABCMeta):
    '''
    Interface for a source of labelled examples that can be indexed like a
    sequence.
    '''

    @abstractmethod
    def provides(self, dataset):
        '''
        Report whether this provider can serve data for `dataset`.
        '''
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        '''
        Number of examples available.
        '''
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, index):
        '''
        The example at position `index`, as a tuple (x, y) of data and
        label.
        '''
        raise NotImplementedError
31 |
class MNIST(Provider):
    '''
    Provider for the 10,000-example MNIST test set.

    Each item is a pair (x, y): x is a 28x28 float32 array scaled to [0, 1],
    and y is the label as a numpy uint8.
    '''

    def __init__(self, image_path, label_path):
        '''
        image_path is a path to the 't10k-images-idx3-ubyte.gz' from
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'

        label_path is a path to the 't10k-labels-idx1-ubyte.gz' from
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

        Raises ValueError if either file has an unexpected magic number or
        does not contain exactly 10,000 examples.
        '''
        # Explicit exceptions rather than `assert`: asserts are stripped
        # under `python -O`, which would let a wrong file load silently.
        with gzip.open(image_path, 'rb') as f:
            images = f.read()
        if images[:4] != b'\x00\x00\x08\x03':
            raise ValueError('bad magic number in image file: %s' % image_path)
        images = np.frombuffer(images[16:], dtype=np.uint8)  # skip 16-byte header
        if len(images) != 10000 * 28 * 28:
            raise ValueError('image file does not contain 10000 28x28 images: %s'
                             % image_path)
        images = images.reshape((10000, 28, 28)).astype(np.float32) / 255.0

        with gzip.open(label_path, 'rb') as f:
            labels = f.read()
        if labels[:4] != b'\x00\x00\x08\x01':
            raise ValueError('bad magic number in label file: %s' % label_path)
        labels = np.frombuffer(labels[8:], dtype=np.uint8)  # skip 8-byte header
        if len(labels) != 10000:
            raise ValueError('label file does not contain 10000 labels: %s'
                             % label_path)

        self.xs = images
        self.ys = labels

    def provides(self, dataset):
        return isinstance(dataset, d.MNIST)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]
66 |
class FMNIST(Provider):
    '''
    Provider for the 10,000-example Fashion-MNIST test set.

    Each item is a pair (x, y): x is a 28x28 float32 array scaled to [0, 1],
    and y is the label as a numpy uint8.
    '''

    def __init__(self, image_path, label_path):
        '''
        image_path is a path to the 't10k-images-idx3-ubyte.gz' from
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz'

        label_path is a path to the 't10k-labels-idx1-ubyte.gz' from
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz'

        Raises ValueError if either file has an unexpected magic number or
        does not contain exactly 10,000 examples.
        '''
        # Explicit exceptions rather than `assert`: asserts are stripped
        # under `python -O`, which would let a wrong file load silently.
        with gzip.open(image_path, 'rb') as f:
            images = f.read()
        if images[:4] != b'\x00\x00\x08\x03':
            raise ValueError('bad magic number in image file: %s' % image_path)
        images = np.frombuffer(images[16:], dtype=np.uint8)  # skip 16-byte header
        if len(images) != 10000 * 28 * 28:
            raise ValueError('image file does not contain 10000 28x28 images: %s'
                             % image_path)
        images = images.reshape((10000, 28, 28)).astype(np.float32) / 255.0

        with gzip.open(label_path, 'rb') as f:
            labels = f.read()
        if labels[:4] != b'\x00\x00\x08\x01':
            raise ValueError('bad magic number in label file: %s' % label_path)
        labels = np.frombuffer(labels[8:], dtype=np.uint8)  # skip 8-byte header
        if len(labels) != 10000:
            raise ValueError('label file does not contain 10000 labels: %s'
                             % label_path)

        self.xs = images
        self.ys = labels

    def provides(self, dataset):
        return isinstance(dataset, d.FMNIST)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]
101 |
class GTS(Provider):
    '''
    Provider for the German Traffic Sign Recognition Benchmark (GTSRB) test
    set.

    Each item is a pair (x, y): x is a 32x32x3 float32 array scaled to
    [0, 1], and y is the integer class id read from the annotation CSV.
    '''

    def __init__(self, test_data):
        '''
        test_data is a path to the test set of German Traffic Sign dataset (GTSRB/) from
        `http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_Images.zip`

        Note that the labels should be obtained separately from
        `http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_GT.zip`
        and placed in `test_data` as `GT-final_test.csv`.

        `test_data` may be given with or without a trailing path separator.
        '''
        # GTS images originally have various sizes (from 15x15 to 250x250);
        # resize all of them to a fixed 32x32.
        width, height = 32, 32
        images, labels = [], []

        annotation_path = os.path.join(test_data, 'GT-final_test.csv')
        image_dir = os.path.join(test_data, 'Final_Test', 'Images')
        # newline='' is what the csv module expects for files it parses.
        with open(annotation_path, newline='') as gt_file:
            reader = csv.reader(gt_file, delimiter=';')
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines
                with PIL.Image.open(os.path.join(image_dir, row[0])) as raw:
                    # PIL sizes are (width, height). LANCZOS is the filter
                    # that ANTIALIAS aliased; ANTIALIAS was removed in
                    # Pillow 10.
                    img = raw.convert('RGB').resize((width, height), PIL.Image.LANCZOS)
                images.append(np.array(img))
                labels.append(int(row[7]))  # ClassId column
        self.xs = np.array(images, dtype=np.float32) / 255.0
        self.ys = np.array(labels)

    def provides(self, dataset):
        return isinstance(dataset, d.GTS)

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        return self.xs[index], self.ys[index]
136 |
class CIFAR10(Provider):
    '''
    Provider for the CIFAR-10 test batch.

    Each item is a pair (x, y): x is a 32x32x3 float32 array scaled to
    [0, 1], and y is the corresponding entry of the batch's label list.
    '''

    def __init__(self, test_data):
        '''
        test_data is a path to the 'test_batch' file from
        'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
        '''
        # NOTE: unpickling can execute arbitrary code; only load a
        # test_batch file obtained from a trusted source.
        with open(test_data, 'rb') as batch_file:
            batch = pickle.load(batch_file, encoding='bytes')
        # Each row is stored channel-first; reshape to (N, 3, 32, 32), scale,
        # then move channels last to get (N, 32, 32, 3).
        channels_first = batch[b'data'].reshape((10000, 3, 32, 32))
        scaled = channels_first.astype(np.float32) / 255.0
        self.xs = scaled.transpose((0, 2, 3, 1))
        self.ys = batch[b'labels']

    def provides(self, dataset):
        return isinstance(dataset, d.CIFAR10)

    def __len__(self):
        return 10000

    def __getitem__(self, index):
        x, y = self.xs[index], self.ys[index]
        return x, y
157 |
class ImageNet(Provider):
    '''
    Provider for the ImageNet validation set.

    Each item is a pair (x, y): x is a float32 array of shape `shape` with
    values in [0, 1], obtained by taking the largest centered crop with the
    target aspect ratio and resizing it. y is the integer label listed for
    that file in `val.txt`.
    '''

    def __init__(self, path, shape):
        '''
        `path` is a directory containing a `val/` directory of the 50,000
        validation images and a `val.txt` file whose lines have the form
        "<image file name> <label>".

        `shape` is the (height, width, channels) tuple images are cropped and
        resized to. `provides` only accepts an ImageNet dataset with an equal
        shape.

        No files are read here. The directory listing and label file are
        loaded on first access and then cached.
        '''
        self._path = path
        self._shape = shape
        self._image_paths = None  # sorted list of image paths, built lazily
        self._labels = None  # image basename -> int label, built lazily

    def provides(self, dataset):
        return isinstance(dataset, d.ImageNet) and self._shape == dataset.shape

    def __len__(self):
        return 50000

    def __getitem__(self, index):
        image_paths, labels = self._index()
        path = image_paths[index]
        x = self._load_image(path)
        y = labels[os.path.basename(path)]
        return x, y

    def _index(self):
        '''
        Return (sorted image paths, basename -> label dict), building and
        caching both on the first call. Rebuilding them for every example
        would list 50,000 files and re-parse val.txt on each access.

        Raises ValueError if `val/` does not contain exactly 50,000 entries.
        '''
        if self._image_paths is None:
            data_path = os.path.join(self._path, 'val')
            image_paths = sorted(os.path.join(data_path, name)
                                 for name in os.listdir(data_path))
            if len(image_paths) != 50000:
                raise ValueError('expected 50000 images in %s, found %d'
                                 % (data_path, len(image_paths)))
            labels_path = os.path.join(self._path, 'val.txt')
            with open(labels_path) as labels_file:
                rows = [line.split() for line in labels_file.read().strip().split('\n')]
            # Set labels before paths so that a failure while parsing leaves
            # the cache empty and the next access retries from scratch.
            self._labels = {os.path.basename(row[0]): int(row[1]) for row in rows}
            self._image_paths = image_paths
        return self._image_paths, self._labels

    # get centered crop of self._shape
    def _load_image(self, path):
        h, w, _ = self._shape
        aspect = w / h
        with PIL.Image.open(path) as raw:
            # convert() handles greyscale, palette, RGBA and CMYK inputs
            # uniformly. Slicing channels by hand would misread CMYK and
            # palette images as colour data.
            image = raw.convert('RGB')

        if image.width / image.height > aspect:
            # image is wider than our aspect ratio: keep full height
            new_height = image.height
            new_width = int(aspect * new_height)
        else:
            # image is taller than our aspect ratio: keep full width
            new_width = image.width
            new_height = int(new_width / aspect)
        width_off = (image.width - new_width) // 2
        height_off = (image.height - new_height) // 2

        # box is (left, upper, right, lower)
        image = image.crop((
            width_off,
            height_off,
            width_off + new_width,
            height_off + new_height,
        ))

        image = image.resize((w, h))

        arr = np.asarray(image).astype(np.float32) / 255.0
        if arr.shape != self._shape:
            raise ValueError('loaded image has shape %s, expected %s'
                             % (arr.shape, self._shape))
        return arr
221 |
222 |
--------------------------------------------------------------------------------
/robustml/threat_model.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | import numpy as np
3 |
4 | from . import dataset
5 |
class ThreatModel(metaclass=ABCMeta):
    '''
    Interface describing which perturbations an adversary is allowed to
    make.
    '''

    @abstractmethod
    def check(self, original, perturbed):
        '''
        Decide whether `perturbed` is a permissible modification of
        `original` under this threat model.

        Both arguments are numpy arrays that share one dtype and shape.
        '''
        raise NotImplementedError

    @property
    @abstractmethod
    def targeted(self):
        '''
        Whether the threat model only includes targeted attacks, i.e. an
        attack must be able to synthesize adversarial examples classified as
        a chosen target label.
        '''
        raise NotImplementedError
27 |
class Or(ThreatModel):
    '''
    A union of threat models.

    A perturbation is allowed if any member allows it. The union is targeted
    only when every member is targeted, because a single untargeted member
    admits untargeted attacks.
    '''

    def __init__(self, *threat_models):
        self._threat_models = threat_models

    def check(self, original, perturbed):
        for member in self._threat_models:
            if member.check(original, perturbed):
                return True
        return False

    @property
    def targeted(self):
        return not any(not member.targeted for member in self._threat_models)
42 |
class And(ThreatModel):
    '''
    An intersection of threat models.

    A perturbation is allowed only if every member allows it. The
    intersection is targeted as soon as any member is targeted.
    '''

    def __init__(self, *threat_models):
        self._threat_models = threat_models

    def check(self, original, perturbed):
        for member in self._threat_models:
            if not member.check(original, perturbed):
                return False
        return True

    @property
    def targeted(self):
        return not all(not member.targeted for member in self._threat_models)
57 |
class Lp(ThreatModel):
    r'''
    Bounded L_p perturbation. Given a `p` and `epsilon`, x' is a valid
    perturbation of x if the following holds:

    || x - x' ||_p <= \epsilon

    In addition, x' must have the same shape as x, and every element of x'
    must lie in [0, 1]. Both the norm and range checks allow a slack of
    `_SLOP` for rounding errors. The norm is `numpy.linalg.norm(..., ord=p)`
    of the flattened difference, so `p=0` counts differing elements and
    `p=numpy.inf` takes the largest absolute difference.
    '''

    _SLOP = 0.0001 # to account for rounding errors

    def __init__(self, p, epsilon, targeted=False):
        self._p = p
        self._epsilon = epsilon
        self._targeted = targeted

    def check(self, original, perturbed):
        original = np.asarray(original)
        perturbed = np.asarray(perturbed)
        # Reject shape changes outright. Otherwise a perturbed array that
        # broadcasts against the original (e.g. a single element) could pass
        # the norm test without being a valid input to the model.
        if original.shape != perturbed.shape:
            return False
        # we want to treat the inputs as big vectors
        original = original.ravel()
        perturbed = perturbed.ravel()
        # ensure it's a valid image
        if np.min(perturbed) < -self._SLOP or np.max(perturbed) > 1 + self._SLOP:
            return False
        norm = np.linalg.norm(original - perturbed, ord=self._p)
        return norm <= self._epsilon + self._SLOP

    @property
    def targeted(self):
        return self._targeted

    @property
    def p(self):
        return self._p

    @property
    def epsilon(self):
        return self._epsilon
94 |
class L0(Lp):
    '''
    Bounded L_0 perturbation: `Lp` with p = 0, so `epsilon` bounds the number
    of elements of x' that differ from x.
    '''

    def __init__(self, epsilon, targeted=False):
        super(L0, self).__init__(0, epsilon, targeted)
98 |
class L1(Lp):
    '''
    Bounded L_1 perturbation: `Lp` with p = 1, so `epsilon` bounds the sum of
    absolute differences between x and x'.
    '''

    def __init__(self, epsilon, targeted=False):
        super(L1, self).__init__(1, epsilon, targeted)
102 |
class L2(Lp):
    '''
    Bounded L_2 perturbation: `Lp` with p = 2, so `epsilon` bounds the
    Euclidean distance between x and x'.
    '''

    def __init__(self, epsilon, targeted=False):
        super(L2, self).__init__(2, epsilon, targeted)
106 |
class Linf(Lp):
    r'''
    Bounded L_inf perturbation. Given an `epsilon`, x' is a valid
    perturbation of x if the following holds:

    || x - x' ||_\infty <= \epsilon

    (The docstring is raw so that `\infty` and `\epsilon` are not parsed as
    invalid escape sequences.)

    >>> model = Linf(0.1)
    >>> x = np.array([0.1, 0.2, 0.3])
    >>> model.check(x, x)
    True
    >>> model.targeted
    False
    >>> model = Linf(0.1, targeted=True)
    >>> model.targeted
    True
    >>> y = np.array([0.1, 0.25, 0.32])
    >>> model.check(x, y)
    True
    >>> z = np.array([0.3, 0.2, 0.3])
    >>> model.check(x, z)
    False
    '''

    def __init__(self, epsilon, targeted=False):
        super().__init__(p=np.inf, epsilon=epsilon, targeted=targeted)
133 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from codecs import open # For a consistent encoding
3 | from os import path
4 | import re
5 |
6 |
# Absolute directory containing this setup.py. The path is resolved once at
# startup, so later changes of the working directory (or invoking
# `python setup.py` with an empty dirname) cannot break the relative paths
# built from it.
here = path.abspath(path.dirname(__file__))


with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
    long_description = f.read()
12 |
13 |
def read(*names, **kwargs):
    '''
    Return the text of the file whose path parts `names` are joined onto
    `here`, decoded with the `encoding` keyword (default "utf8").
    '''
    file_path = path.join(here, *names)
    encoding = kwargs.get("encoding", "utf8")
    with open(file_path, encoding=encoding) as fp:
        return fp.read()
20 |
21 |
def find_version(*file_paths):
    '''
    Return the version string from the first line of the given file (path
    parts resolved by `read`) that starts with `__version__ = '...'`.

    Raises RuntimeError if no such line exists.
    '''
    contents = read(*file_paths)
    pattern = r"^__version__ = ['\"]([^'\"]*)['\"]"
    found = re.search(pattern, contents, re.M)
    if found is None:
        raise RuntimeError("Unable to find version string.")
    return found.group(1)
29 |
30 |
setup(
    name='robustml',

    version=find_version('robustml', '__init__.py'),

    description='Robust ML API',
    long_description=long_description,

    url='https://github.com/robust-ml/robust-ml',

    author='Anish Athalye',
    author_email='me@anishathalye.com',

    license='MIT',

    packages=['robustml'],

    # The package uses Python 3-only syntax (`metaclass=` class keywords,
    # zero-argument `super()`, `print(..., file=...)`). Declare it, so pip
    # refuses to install on Python 2 instead of failing at import time.
    python_requires='>=3',

    install_requires=[
        'numpy>=1,<2',
        'Pillow>=5,<6',
    ],
)
53 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import doctest
2 | import robustml
3 |
4 | # We need to explicitly list all modules here. This is not super pretty, but
5 | # the testing here shouldn't be too involved. If there's ever a need for
6 | # fancier testing, we can switch to a more complete testing framework.
7 |
TEST_MODULES = [
    robustml.threat_model,
]

if __name__ == '__main__':
    import sys

    # doctest.testmod only reports failures and returns a (failed, attempted)
    # result. Sum the failures so the script exits non-zero and CI or shell
    # callers can detect a broken doctest.
    failed = sum(doctest.testmod(module).failed for module in TEST_MODULES)
    sys.exit(1 if failed else 0)
15 |
--------------------------------------------------------------------------------