├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── adv_lib
├── __init__.py
├── attacks
│ ├── __init__.py
│ ├── augmented_lagrangian.py
│ ├── auto_pgd.py
│ ├── carlini_wagner
│ │ ├── __init__.py
│ │ ├── l2.py
│ │ └── linf.py
│ ├── decoupled_direction_norm.py
│ ├── deepfool.py
│ ├── fast_adaptive_boundary
│ │ ├── __init__.py
│ │ ├── fast_adaptive_boundary.py
│ │ └── projections.py
│ ├── fast_minimum_norm.py
│ ├── perceptual_color_attacks
│ │ ├── __init__.py
│ │ ├── differential_color_functions.py
│ │ └── perceptual_color_distance_al.py
│ ├── primal_dual_gradient_descent.py
│ ├── projected_gradient_descent.py
│ ├── segmentation
│ │ ├── __init__.py
│ │ ├── alma_prox.py
│ │ ├── asma.py
│ │ ├── dense_adversary.py
│ │ └── primal_dual_gradient_descent.py
│ ├── sigma_zero.py
│ ├── stochastic_sparse_attacks.py
│ ├── structured_adversarial_attack.py
│ ├── superdeepfool.py
│ └── trust_region.py
├── distances
│ ├── __init__.py
│ ├── color_difference.py
│ ├── lp_norms.py
│ ├── lpips.py
│ └── structural_similarity.py
└── utils
│ ├── __init__.py
│ ├── attack_utils.py
│ ├── color_conversions.py
│ ├── image_selection.py
│ ├── lagrangian_penalties
│ ├── __init__.py
│ ├── all_penalties.py
│ ├── penalty_functions.py
│ ├── scripts
│ │ ├── plot_penalties.py
│ │ └── plot_univariates.py
│ └── univariate_functions.py
│ ├── losses.py
│ ├── projections.py
│ ├── utils.py
│ └── visdom_logger.py
├── pyproject.toml
└── tests
├── distances
├── test_color_difference.py
└── test_structural_similarity.py
└── utils
└── lagrangian_penalties
├── test_penalty_functions.py
└── test_univariate_functions.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,linux,pycharm
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,pycharm
4 |
5 | ### Linux ###
6 | *~
7 |
8 | # temporary files which can be created if a process still has a handle open of a deleted file
9 | .fuse_hidden*
10 |
11 | # KDE directory preferences
12 | .directory
13 |
14 | # Linux trash folder which might appear on any partition or disk
15 | .Trash-*
16 |
17 | # .nfs files are created when an open file is removed but is still being accessed
18 | .nfs*
19 |
20 | ### PyCharm ###
21 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
22 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
23 |
24 | # User-specific stuff
25 | .idea/**/workspace.xml
26 | .idea/**/tasks.xml
27 | .idea/**/usage.statistics.xml
28 | .idea/**/dictionaries
29 | .idea/**/shelf
30 |
31 | # Generated files
32 | .idea/**/contentModel.xml
33 |
34 | # Sensitive or high-churn files
35 | .idea/**/dataSources/
36 | .idea/**/dataSources.ids
37 | .idea/**/dataSources.local.xml
38 | .idea/**/sqlDataSources.xml
39 | .idea/**/dynamic.xml
40 | .idea/**/uiDesigner.xml
41 | .idea/**/dbnavigator.xml
42 |
43 | # Gradle
44 | .idea/**/gradle.xml
45 | .idea/**/libraries
46 |
47 | # Gradle and Maven with auto-import
48 | # When using Gradle or Maven with auto-import, you should exclude module files,
49 | # since they will be recreated, and may cause churn. Uncomment if using
50 | # auto-import.
51 | # .idea/artifacts
52 | # .idea/compiler.xml
53 | # .idea/jarRepositories.xml
54 | # .idea/modules.xml
55 | # .idea/*.iml
56 | # .idea/modules
57 | # *.iml
58 | # *.ipr
59 |
60 | # CMake
61 | cmake-build-*/
62 |
63 | # Mongo Explorer plugin
64 | .idea/**/mongoSettings.xml
65 |
66 | # File-based project format
67 | *.iws
68 |
69 | # IntelliJ
70 | out/
71 |
72 | # mpeltonen/sbt-idea plugin
73 | .idea_modules/
74 |
75 | # JIRA plugin
76 | atlassian-ide-plugin.xml
77 |
78 | # Cursive Clojure plugin
79 | .idea/replstate.xml
80 |
81 | # Crashlytics plugin (for Android Studio and IntelliJ)
82 | com_crashlytics_export_strings.xml
83 | crashlytics.properties
84 | crashlytics-build.properties
85 | fabric.properties
86 |
87 | # Editor-based Rest Client
88 | .idea/httpRequests
89 |
90 | # Android studio 3.1+ serialized cache file
91 | .idea/caches/build_file_checksums.ser
92 |
93 | ### PyCharm Patch ###
94 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
95 |
96 | # *.iml
97 | # modules.xml
98 | # .idea/misc.xml
99 | # *.ipr
100 |
101 | # Sonarlint plugin
102 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
103 | .idea/**/sonarlint/
104 |
105 | # SonarQube Plugin
106 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
107 | .idea/**/sonarIssues.xml
108 |
109 | # Markdown Navigator plugin
110 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
111 | .idea/**/markdown-navigator.xml
112 | .idea/**/markdown-navigator-enh.xml
113 | .idea/**/markdown-navigator/
114 |
115 | # Cache file creation bug
116 | # See https://youtrack.jetbrains.com/issue/JBR-2257
117 | .idea/$CACHE_FILE$
118 |
119 | # CodeStream plugin
120 | # https://plugins.jetbrains.com/plugin/12206-codestream
121 | .idea/codestream.xml
122 |
123 | ### Python ###
124 | # Byte-compiled / optimized / DLL files
125 | __pycache__/
126 | *.py[cod]
127 | *$py.class
128 |
129 | # C extensions
130 | *.so
131 |
132 | # Distribution / packaging
133 | .Python
134 | build/
135 | develop-eggs/
136 | dist/
137 | downloads/
138 | eggs/
139 | .eggs/
140 | lib/
141 | lib64/
142 | parts/
143 | sdist/
144 | var/
145 | wheels/
146 | pip-wheel-metadata/
147 | share/python-wheels/
148 | *.egg-info/
149 | .installed.cfg
150 | *.egg
151 | MANIFEST
152 |
153 | # PyInstaller
154 | # Usually these files are written by a python script from a template
155 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
156 | *.manifest
157 | *.spec
158 |
159 | # Installer logs
160 | pip-log.txt
161 | pip-delete-this-directory.txt
162 |
163 | # Unit tests / coverage reports
164 | htmlcov/
165 | .tox/
166 | .nox/
167 | .coverage
168 | .coverage.*
169 | .cache
170 | nosetests.xml
171 | coverage.xml
172 | *.cover
173 | *.py,cover
174 | .hypothesis/
175 | .pytest_cache/
176 | pytestdebug.log
177 |
178 | # Translations
179 | *.mo
180 | *.pot
181 |
182 | # Django stuff:
183 | *.log
184 | local_settings.py
185 | db.sqlite3
186 | db.sqlite3-journal
187 |
188 | # Flask stuff:
189 | instance/
190 | .webassets-cache
191 |
192 | # Scrapy stuff:
193 | .scrapy
194 |
195 | # Sphinx documentation
196 | docs/_build/
197 | doc/_build/
198 |
199 | # PyBuilder
200 | target/
201 |
202 | # Jupyter Notebook
203 | .ipynb_checkpoints
204 |
205 | # IPython
206 | profile_default/
207 | ipython_config.py
208 |
209 | # pyenv
210 | .python-version
211 |
212 | # pipenv
213 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
214 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
215 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
216 | # install all needed dependencies.
217 | #Pipfile.lock
218 |
219 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
220 | __pypackages__/
221 |
222 | # Celery stuff
223 | celerybeat-schedule
224 | celerybeat.pid
225 |
226 | # SageMath parsed files
227 | *.sage.py
228 |
229 | # Environments
230 | .env
231 | .venv
232 | env/
233 | venv/
234 | ENV/
235 | env.bak/
236 | venv.bak/
237 |
238 | # Spyder project settings
239 | .spyderproject
240 | .spyproject
241 |
242 | # Rope project settings
243 | .ropeproject
244 |
245 | # mkdocs documentation
246 | /site
247 |
248 | # mypy
249 | .mypy_cache/
250 | .dmypy.json
251 | dmypy.json
252 |
253 | # Pyre type checker
254 | .pyre/
255 |
256 | # pytype static type analyzer
257 | .pytype/
258 |
259 | # End of https://www.toptal.com/developers/gitignore/api/python,linux,pycharm
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: Adversarial Library
6 | message: >-
7 | If you use this software, please cite it using the
8 | metadata from this file.
9 | type: software
10 | authors:
11 | - given-names: Jérôme
12 | family-names: Rony
13 | affiliation: ÉTS Montréal
14 | orcid: 'https://orcid.org/0000-0002-6359-6142'
15 | - given-names: Ismail
16 | family-names: Ben Ayed
17 | affiliation: ÉTS Montréal
18 | orcid: 'https://orcid.org/0000-0002-9668-8027'
19 | doi: 10.5281/zenodo.5815063
20 | repository-code: 'https://github.com/jeromerony/adversarial-library'
21 | keywords:
22 | - machine learning
23 | - adversarial examples
24 | - pytorch
25 | license: BSD-3-Clause
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Jérôme Rony
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [![DOI](https://zenodo.org/badge/315504148.svg)](https://zenodo.org/badge/latestdoi/315504148)
3 |
4 | # Adversarial Library
5 |
6 | This library contains various resources related to adversarial attacks implemented in PyTorch. It is aimed towards researchers looking for implementations of state-of-the-art attacks.
7 |
8 | The code was written to maximize efficiency (_e.g._ by preferring low level functions from PyTorch) while retaining simplicity (_e.g._ by avoiding abstractions). As a consequence, most of the library, and especially the attacks, is implemented using **pure functions** (whenever possible).
9 |
10 | While focused on attacks, this library also provides several utilities related to adversarial attacks: distances (SSIM, CIEDE2000, LPIPS), visdom callback, projections, losses and helper functions. Most notably the function `run_attack` from `utils/attack_utils.py` performs an attack on a model given the inputs and labels, with fixed batch size, and reports complexity related metrics (run-time and forward/backward propagations).
11 |
12 | ### Dependencies
13 |
14 | The goal of this library is to be up-to-date with newer versions of PyTorch so the dependencies are expected to be updated regularly (possibly resulting in breaking changes).
15 |
16 | - pytorch>=1.8.0
17 | - torchvision>=0.9.0
18 | - tqdm>=4.48.0
19 | - visdom>=0.1.8
20 |
21 | ### Installation
22 |
23 | You can either install using:
24 |
25 | ```pip install git+https://github.com/jeromerony/adversarial-library```
26 |
27 | Or you can clone the repo and run:
28 |
29 | ```pip install .```
30 |
31 | Alternatively, you can install (after cloning) the library in editable mode:
32 |
33 | ```pip install -e .```
34 |
35 | ### Usage
36 | Attacks are implemented as functions, so they can be called directly by providing the model, samples and labels (possibly with optional arguments):
37 | ```python
38 | from adv_lib.attacks import ddn
39 | adv_samples = ddn(model=model, inputs=inputs, labels=labels, steps=300)
40 | ```
41 |
42 | Classification attacks all expect the following arguments:
43 | - `model`: the model that produces logits (pre-softmax activations) with inputs in $[0, 1]$
44 | - `inputs`: the samples to attack in $[0, 1]$
45 | - `labels`: either the ground-truth labels for the samples or the targets
46 | - `targeted`: flag indicating whether the attack should be targeted or not -- defaults to `False`
47 |
48 | Additionally, many attacks have an optional `callback` argument which accepts an `adv_lib.utils.visdom_logger.VisdomLogger` to plot data to a visdom server for monitoring purposes.
49 |
50 | For a more detailed example on how to use this library, you can look at this repo: https://github.com/jeromerony/augmented_lagrangian_adversarial_attacks
51 |
52 | ## Contents
53 |
54 | ### Attacks
55 |
56 | #### Classification
57 |
58 | Currently the following classification attacks are implemented in the `adv_lib.attacks` module:
59 |
60 | | Name | Knowledge | Type | Distance(s) | ArXiv Link |
61 | |-----------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------------------------------------------------------------|
62 | | DeepFool (DF) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1511.04599](https://arxiv.org/abs/1511.04599) |
63 | | Carlini and Wagner (C&W) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1608.04644](https://arxiv.org/abs/1608.04644) |
64 | | Projected Gradient Descent (PGD) | White-box | Budget | $\ell_\infty$ | [1706.06083](https://arxiv.org/abs/1706.06083) |
65 | | Structured Adversarial Attack (StrAttack) | White-box | Minimal | $\ell_2$ + group-sparsity | [1808.01664](https://arxiv.org/abs/1808.01664) |
66 | | **Decoupled Direction and Norm (DDN)** | White-box | Minimal | $\ell_2$ | [1811.09600](https://arxiv.org/abs/1811.09600) |
67 | | Trust Region (TR) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1812.06371](https://arxiv.org/abs/1812.06371) |
68 | | Fast Adaptive Boundary (FAB) | White-box | Minimal | $\ell_1$, $\ell_2$, $\ell_\infty$ | [1907.02044](https://arxiv.org/abs/1907.02044) |
69 | | Perceptual Color distance Alternating Loss (PerC-AL) | White-box | Minimal | CIEDE2000 | [1911.02466](https://arxiv.org/abs/1911.02466) |
70 | | Auto-PGD (APGD) | White-box | Budget | $\ell_1$, $\ell_2$, $\ell_\infty$ | [2003.01690](https://arxiv.org/abs/2003.01690)<br>[2103.01208](https://arxiv.org/abs/2103.01208) |
71 | | **Augmented Lagrangian Method for Adversarial (ALMA)** | White-box | Minimal | $\ell_1$, $\ell_2$, SSIM, CIEDE2000, LPIPS, ... | [2011.11857](https://arxiv.org/abs/2011.11857) |
72 | | Folded Gaussian Attack (FGA)<br>Voting Folded Gaussian Attack (VFGA) | White-box | Minimal | $\ell_0$ | [2011.12423](https://arxiv.org/abs/2011.12423) |
73 | | Fast Minimum-Norm (FMN) | White-box | Minimal | $\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2102.12827](https://arxiv.org/abs/2102.12827) |
74 | | Primal-Dual Gradient Descent (PDGD)<br>Primal-Dual Proximal Gradient Descent (PDPGD) | White-box | Minimal | $\ell_2$<br>$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) |
75 | | SuperDeepFool (SDF) | White-box | Minimal | $\ell_2$ | [2303.12481](https://arxiv.org/abs/2303.12481) |
76 | | σ-zero | White-box | Minimal | $\ell_0$ | [2402.01879](https://arxiv.org/abs/2402.01879) |
77 |
78 | **Bold** means that this repository contains the official implementation.
79 |
80 | _Type_ refers to the goal of the attack:
81 | - _Minimal_ attacks aim to find the smallest adversarial perturbation w.r.t. a given distance;
82 | - _Budget_ attacks aim to find an adversarial perturbation within a distance budget (and often to maximize a loss as well).
83 |
84 | #### Segmentation
85 |
86 | The library now includes segmentation attacks in the `adv_lib.attacks.segmentation` module. These require the following arguments:
87 | - `model`: the model that produces logits (pre-softmax activations) with inputs in $[0, 1]$
88 | - `inputs`: the images to attack in $[0, 1]$. Shape: $b\times c\times h\times w$ with $b$ the batch size, $c$ the number of color channels and $h$ and $w$ the height and width of the images.
89 | - `labels`: either the ground-truth labels for the samples or the targets. Shape: $b\times h\times w$.
90 | - `masks`: binary mask indicating which pixels to attack, to account for unlabeled pixels (e.g. void in Pascal VOC). Shape: $b\times h\times w$
91 | - `targeted`: flag indicating whether the attack should be targeted or not -- defaults to `False`
92 | - `adv_threshold`: fraction of the pixels that must be adversarial for the attack to be considered successful -- defaults to `0.99`
93 |
94 | The following segmentation attacks are implemented:
95 |
96 | | Name | Knowledge | Type | Distance(s) | ArXiv Link |
97 | |-------------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------|
98 | | Dense Adversary Generation (DAG) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1703.08603](https://arxiv.org/abs/1703.08603) |
99 | | Adaptive Segmentation Mask Attack (ASMA) | White-box | Minimal | $\ell_2$ | [1907.13124](https://arxiv.org/abs/1907.13124) |
100 | | _Primal-Dual Gradient Descent (PDGD)<br>Primal-Dual Proximal Gradient Descent (PDPGD)_ | White-box | Minimal | $\ell_2$<br>$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) |
101 | | **ALMA prox** | White-box | Minimal | $\ell_\infty$ | [2206.07179](https://arxiv.org/abs/2206.07179) |
102 |
103 | _Italic_ indicates that the attack is unofficially adapted from the classification variant.
104 |
105 | ### Distances
106 |
107 | The following distances are available in the `adv_lib.distances` module:
108 | - Lp-norms
109 | - SSIM https://ece.uwaterloo.ca/~z70wang/research/ssim/
110 | - MS-SSIM https://ece.uwaterloo.ca/~z70wang/publications/msssim.html
111 | - CIEDE2000 color difference http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf
112 | - LPIPS https://arxiv.org/abs/1801.03924
113 |
114 | ## Contributions
115 |
116 | Suggestions and contributions are welcome :)
117 |
118 | ## Citation
119 |
120 | If this library has been useful for your research, you can cite it using the "Cite this repository" button in the "About" section.
121 |
--------------------------------------------------------------------------------
/adv_lib/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.3"
2 |
--------------------------------------------------------------------------------
/adv_lib/attacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .augmented_lagrangian import alma
2 | from .auto_pgd import apgd, apgd_targeted
3 | from .carlini_wagner import carlini_wagner_l2, carlini_wagner_linf
4 | from .decoupled_direction_norm import ddn
5 | from .deepfool import df
6 | from .fast_adaptive_boundary import fab
7 | from .fast_minimum_norm import fmn
8 | from .perceptual_color_attacks import perc_al
9 | from .primal_dual_gradient_descent import pdgd, pdpgd
10 | from .projected_gradient_descent import pgd_linf
11 | from .sigma_zero import sigma_zero
12 | from .stochastic_sparse_attacks import fga, vfga
13 | from .structured_adversarial_attack import str_attack
14 | from .superdeepfool import sdf
15 | from .trust_region import tr
16 |
--------------------------------------------------------------------------------
/adv_lib/attacks/augmented_lagrangian.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable, Optional
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.autograd import grad
7 |
8 | from adv_lib.distances.color_difference import ciede2000_loss
9 | from adv_lib.distances.lp_norms import l1_distances, l2_distances
10 | from adv_lib.distances.lpips import LPIPS
11 | from adv_lib.distances.structural_similarity import ms_ssim_loss, ssim_loss
12 | from adv_lib.utils.lagrangian_penalties import all_penalties
13 | from adv_lib.utils.losses import difference_of_logits_ratio
14 | from adv_lib.utils.visdom_logger import VisdomLogger
15 |
16 |
def init_lr_finder(inputs: Tensor, grad: Tensor, distance_function: Callable, target_distance: float,
                   max_doublings: int = 64, binary_search_steps: int = 20) -> Tensor:
    """
    Performs a line search and a binary search to find the learning rate η for each sample such that:
        distance_function(inputs - η * grad) = target_distance.

    Parameters
    ----------
    inputs : Tensor
        Reference to compute the distance from. The first dimension is the batch dimension.
    grad : Tensor
        Direction to step in. Must be broadcastable with inputs once η is reshaped to (batch_size, 1, ..., 1).
    distance_function : Callable
        Function taking the perturbed inputs (clamped to [0, 1]) and returning one distance per sample.
    target_distance : float
        Target distance that inputs - η * grad should reach.
    max_doublings : int
        Maximum number of times η is doubled during the line search. A sample whose distance never exceeds
        target_distance (e.g. because clamping to [0, 1] cancels the step) stops at η = 2 ** max_doublings
        instead of making the search loop forever.
    binary_search_steps : int
        Number of bisection steps used to refine η between the bounds found by the line search.

    Returns
    -------
    η : Tensor
        Learning rate for each sample.

    """
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))

    def exceeds_target(step: Tensor) -> Tensor:
        # The distance is measured on the step clamped back to the valid [0, 1] input domain.
        stepped = (inputs - batch_view(step) * grad).clamp_(min=0, max=1)
        return distance_function(stepped) > target_distance

    lr = torch.ones(batch_size, device=inputs.device)
    lower = torch.zeros_like(lr)

    # Line search: double η for samples that have not yet exceeded the target, remembering the last η that did not
    # exceed it as the lower bound. Bounded to avoid an infinite loop (and η overflowing to inf).
    found_upper = exceeds_target(lr)
    for _ in range(max_doublings):
        if found_upper.all():
            break
        lower = torch.where(found_upper, lower, lr)
        lr = torch.where(found_upper, lr, lr * 2)
        found_upper = exceeds_target(lr)

    # Binary search between lower (does not exceed the target) and lr (exceeds it, when an upper bound was found).
    for _ in range(binary_search_steps):
        new_lr = (lower + lr) / 2
        larger = exceeds_target(new_lr)
        lower, lr = torch.where(larger, lower, new_lr), torch.where(larger, new_lr, lr)

    return (lr + lower) / 2
55 |
56 |
# Registry of the distance names accepted by `alma`. 'lpips' maps to a class that alma instantiates with
# `target=inputs`; every other entry is a callable that alma binds as f(inputs, adv_inputs).
_distances = dict(
    ssim=ssim_loss,
    msssim=ms_ssim_loss,
    ciede2000=partial(ciede2000_loss, ε=1e-12),
    lpips=LPIPS,
    l2=l2_distances,
    l1=l1_distances,
)
65 |
66 |
def alma(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         penalty: Callable = all_penalties['P2'],
         targeted: bool = False,
         num_steps: int = 1000,
         lr_init: float = 0.1,
         lr_reduction: float = 0.01,
         distance: str = 'l2',
         init_lr_distance: Optional[float] = None,
         μ_init: float = 1,
         ρ_init: float = 1,
         check_steps: int = 10,
         τ: float = 0.95,
         γ: float = 1.2,
         α: float = 0.9,
         α_rms: Optional[float] = None,
         momentum: Optional[float] = None,
         logit_tolerance: float = 1e-4,
         levels: Optional[int] = None,
         callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Augmented Lagrangian Method for Adversarial (ALMA) attack from https://arxiv.org/abs/2011.11857.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    penalty : Callable
        Penalty-Lagrangian function to use. A good default choice is P2 (see the original article).
    targeted : bool
        Whether to perform a targeted attack or not.
    num_steps : int
        Number of optimization steps. Corresponds to the number of forward and backward propagations.
    lr_init : float
        Initial learning rate.
    lr_reduction : float
        Reduction factor for the learning rate. The final learning rate is lr_init * lr_reduction
    distance : str
        Distance to use. One of the keys of the module-level `_distances` registry.
    init_lr_distance : float
        If a float is given, the initial learning rate will be calculated such that the first step results in an
        increase of init_lr_distance of the distance to minimize. This corresponds to ε in the original article.
    μ_init : float
        Initial value of the penalty multiplier.
    ρ_init : float
        Initial value of the penalty parameter.
    check_steps : int
        Number of steps between checks for the improvement of the constraint. This corresponds to M in the original
        article.
    τ : float
        Constraint improvement rate.
    γ : float
        Penalty parameter increase rate.
    α : float
        Weight for the exponential moving average.
    α_rms : float
        Smoothing constant for RMSProp. If none is provided, defaults to α.
    momentum : float
        Momentum for the RMSProp. If none is provided, defaults to α.
    logit_tolerance : float
        Small quantity added to the difference of logits to avoid solutions where the difference of logits is 0, which
        can result in inconsistent class prediction (using argmax) on GPU. This can also be used as a confidence
        parameter κ as in https://arxiv.org/abs/1608.04644, however, a confidence parameter on logits is not robust to
        scaling of the logits.
    levels : int
        Number of levels for quantization. The attack will perform quantization only if the number of levels is
        provided.
    callback : VisdomLogger
        Callback to visualize the progress of the algorithm. Values are accumulated at every step and the plots are
        refreshed every max(num_steps // 20, 1) steps and at the last step.

    Returns
    -------
    best_adv : Tensor
        Perturbed inputs (inputs + perturbation) that are adversarial and have smallest distance with the original
        inputs.

    """
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    multiplier = -1 if targeted else 1

    # Setup variables
    δ = torch.zeros_like(inputs, requires_grad=True)
    square_avg = torch.ones_like(inputs)
    momentum_buffer = torch.zeros_like(inputs)
    lr = torch.full((batch_size,), lr_init, device=device, dtype=torch.float)
    α_rms, momentum = α if α_rms is None else α_rms, α if momentum is None else momentum

    # Init rho and mu
    μ = torch.full((batch_size,), μ_init, device=device, dtype=torch.float)
    ρ = torch.full((batch_size,), ρ_init, device=device, dtype=torch.float)

    # Init similarity metric
    if distance in ['lpips']:
        dist_func = _distances[distance](target=inputs)
    else:
        dist_func = partial(_distances[distance], inputs)

    # Init trackers
    best_dist = torch.full((batch_size,), float('inf'), device=device)
    best_adv = inputs.clone()
    adv_found = torch.zeros_like(best_dist, dtype=torch.bool)
    # num_steps + 1 marks samples that are not adversarial yet: for them the lr decay exponent below is 0.
    step_found = torch.full_like(best_dist, num_steps + 1)

    # Refresh the callback plots about 20 times during the attack. Clamped to 1 so that num_steps < 20 does not
    # cause a modulo by zero.
    update_every = max(num_steps // 20, 1)

    for i in range(num_steps):

        adv_inputs = inputs + δ
        logits = model(adv_inputs)
        dist = dist_func(adv_inputs)

        if i == 0:
            labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))
            dlr_func = partial(difference_of_logits_ratio, labels=labels, labels_infhot=labels_infhot,
                               targeted=targeted, ε=logit_tolerance)

        dlr = multiplier * dlr_func(logits)

        if i == 0:
            prev_dlr = dlr.detach()
        elif (i + 1) % check_steps == 0:
            # Increase ρ for samples that are neither adversarial nor improved their constraint by a factor τ.
            improved_dlr = (dlr.detach() < τ * prev_dlr)
            ρ = torch.where(~(adv_found | improved_dlr), γ * ρ, ρ)
            prev_dlr = dlr.detach()

        if i:
            # Multiplier update: EMA towards the derivative of the penalty w.r.t. the constraint.
            new_μ = grad(penalty(dlr, ρ, μ).sum(), dlr, only_inputs=True)[0]
            μ.lerp_(new_μ, weight=1 - α).clamp_(min=1e-6, max=1e12)

        is_adv = dlr < 0
        is_smaller = dist < best_dist
        is_both = is_adv & is_smaller
        step_found.masked_fill_((~adv_found) & is_adv, i)
        adv_found.logical_or_(is_adv)
        best_dist = torch.where(is_both, dist.detach(), best_dist)
        best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

        if i == 0:
            loss = penalty(dlr, ρ, μ)
        else:
            loss = dist + penalty(dlr, ρ, μ)
        δ_grad = grad(loss.sum(), δ, only_inputs=True)[0]

        grad_norm = δ_grad.flatten(1).norm(p=2, dim=1)
        if init_lr_distance is not None and i == 0:
            # Replace (near-)zero gradients by random unit directions so that the lr search has a direction to follow.
            randn_grad = torch.randn_like(δ_grad).renorm(dim=0, p=2, maxnorm=1)
            δ_grad = torch.where(batch_view(grad_norm <= 1e-6), randn_grad, δ_grad)
            lr = init_lr_finder(inputs, δ_grad, dist_func, target_distance=init_lr_distance)

        # Exponential lr decay from lr to lr * lr_reduction, starting at the step where each sample became adversarial.
        exp_decay = lr_reduction ** ((i - step_found).clamp_(min=0) / (num_steps - step_found))
        step_lr = lr * exp_decay
        square_avg.mul_(α_rms).addcmul_(δ_grad, δ_grad, value=1 - α_rms)
        momentum_buffer.mul_(momentum).addcdiv_(δ_grad, square_avg.sqrt().add_(1e-8))
        δ.data.addcmul_(momentum_buffer, batch_view(step_lr), value=-1)

        # Project inputs + δ onto [0, 1] (and optionally onto the quantization grid), then go back to δ.
        δ.data.add_(inputs).clamp_(min=0, max=1)
        if levels is not None:
            δ.data.mul_(levels - 1).round_().div_(levels - 1)
        δ.data.sub_(inputs)

        if callback:
            cb_best_dist = best_dist.masked_select(adv_found).mean()
            callback.accumulate_line([distance, 'dlr'], i, [dist.mean(), dlr.mean()])
            callback.accumulate_line(['μ_c', 'ρ_c'], i, [μ.mean(), ρ.mean()])
            callback.accumulate_line('grad_norm', i, grad_norm.mean())
            callback.accumulate_line(['best_{}'.format(distance), 'success', 'lr'], i,
                                     [cb_best_dist, adv_found.float().mean(), step_lr.mean()])

            if (i + 1) % update_every == 0 or (i + 1) == num_steps:
                callback.update_lines()

    return best_adv
244 |
--------------------------------------------------------------------------------
/adv_lib/attacks/carlini_wagner/__init__.py:
--------------------------------------------------------------------------------
1 | from .l2 import carlini_wagner_l2
2 | from .linf import carlini_wagner_linf
--------------------------------------------------------------------------------
/adv_lib/attacks/carlini_wagner/l2.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/carlini/nn_robust_attacks
2 |
3 | from typing import Tuple, Optional
4 |
5 | import torch
6 | from torch import nn, optim, Tensor
7 | from torch.autograd import grad
8 |
9 | from adv_lib.utils.losses import difference_of_logits
10 | from adv_lib.utils.visdom_logger import VisdomLogger
11 |
12 |
def carlini_wagner_l2(model: nn.Module,
                      inputs: Tensor,
                      labels: Tensor,
                      targeted: bool = False,
                      confidence: float = 0,
                      learning_rate: float = 0.01,
                      initial_const: float = 0.001,
                      binary_search_steps: int = 9,
                      max_iterations: int = 10000,
                      abort_early: bool = True,
                      callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Carlini and Wagner L2 attack from https://arxiv.org/abs/1608.04644.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    confidence : float
        Confidence of adversarial examples: higher produces examples that are farther away, but more strongly classified
        as adversarial.
    learning_rate: float
        The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge.
    initial_const : float
        The initial tradeoff-constant to use to tune the relative importance of distance and confidence. If
        binary_search_steps is large, the initial constant is not important.
    binary_search_steps : int
        The number of times we perform binary search to find the optimal tradeoff-constant between distance and
        confidence.
    max_iterations : int
        The maximum number of iterations. Larger values are more accurate; setting too small will require a large
        learning rate and will produce poor results.
    abort_early : bool
        If true, allows early aborts if gradient descent gets stuck.
    callback : Optional
        Logger used to plot the progress of the attack. Nothing is logged if None.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model. Samples for which no adversarial example was found are returned
        unchanged.

    """
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    # tanh change of variable: adv_inputs = (tanh(t_inputs + modifier) + 1) / 2 always lies in [0, 1].
    # The (1 - 1e-6) factor keeps atanh finite for inputs equal to 0 or 1.
    t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
    multiplier = -1 if targeted else 1

    # Periods (in iterations) of the early-abort check and of the plot refresh. Clamped to 1 so that fewer than
    # 10 (resp. 20) iterations do not cause a modulo by zero.
    abort_period = max(1, max_iterations // 10)
    update_period = max(1, max_iterations // 20)

    # set the lower and upper bounds accordingly
    c = torch.full((batch_size,), initial_const, device=device)
    lower_bound = torch.zeros_like(c)
    upper_bound = torch.full_like(c, 1e10)

    # overall best results, across all binary search steps
    o_best_l2 = torch.full_like(c, float('inf'))
    o_best_adv = inputs.clone()
    o_adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

    i_total = 0
    for outer_step in range(binary_search_steps):

        # setup the modifier variable and the optimizer
        modifier = torch.zeros_like(inputs, requires_grad=True)
        optimizer = optim.Adam([modifier], lr=learning_rate)
        # best results for the current value of c only
        best_l2 = torch.full_like(c, float('inf'))
        adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

        # The last iteration (if we run many steps) repeat the search once with the upper bound as constant.
        # Cloned so that the in-place updates of c at the end of the step do not also modify upper_bound.
        if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1):
            c = upper_bound.clone()

        prev = float('inf')
        for i in range(max_iterations):

            adv_inputs = (torch.tanh(t_inputs + modifier) + 1) / 2
            l2_squared = (adv_inputs - inputs).flatten(1).square().sum(1)
            l2 = l2_squared.detach().sqrt()
            logits = model(adv_inputs)

            if outer_step == 0 and i == 0:
                # setup the target variable, we need it to be in one-hot form for the loss function
                labels_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
                labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))

            # adjust the best result found so far; a sample only counts as adversarial with a `confidence` margin
            predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \
                (logits + labels_onehot * confidence).argmax(1)

            is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels)
            is_smaller = l2 < best_l2
            o_is_smaller = l2 < o_best_l2
            is_both = is_adv & is_smaller
            o_is_both = is_adv & o_is_smaller

            best_l2 = torch.where(is_both, l2, best_l2)
            adv_found.logical_or_(is_both)
            o_best_l2 = torch.where(o_is_both, l2, o_best_l2)
            o_adv_found.logical_or_(o_is_both)
            o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv)

            logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)
            loss = l2_squared + c * (logit_dists + confidence).clamp_(min=0)

            # check if we should abort search if we're getting nowhere.
            if abort_early and i % abort_period == 0:
                if (loss > prev * 0.9999).all():
                    break
                prev = loss.detach()

            optimizer.zero_grad(set_to_none=True)
            modifier.grad = grad(loss.sum(), modifier, only_inputs=True)[0]
            optimizer.step()

            if callback:
                i_total += 1
                callback.accumulate_line('logit_dist', i_total, logit_dists.mean())
                callback.accumulate_line('l2_norm', i_total, l2.mean())
                if i_total % update_period == 0:
                    callback.update_lines()

        if callback:
            mean_best_l2 = o_best_l2.masked_select(o_adv_found).mean()
            callback.line(['success', 'best_l2', 'c'], outer_step, [o_adv_found.float().mean(), mean_best_l2, c.mean()])

        # adjust the constant as needed: bisect between the bounds once an upper bound is known, else multiply by 10
        upper_bound[adv_found] = torch.min(upper_bound[adv_found], c[adv_found])
        adv_not_found = ~adv_found
        lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], c[adv_not_found])
        is_smaller = upper_bound < 1e9
        c[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2
        c[(~is_smaller) & adv_not_found] *= 10

    # return the best solution found
    return o_best_adv
150 | # return the best solution found
151 | return o_best_adv
152 |
--------------------------------------------------------------------------------
/adv_lib/attacks/carlini_wagner/linf.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/carlini/nn_robust_attacks
2 |
3 | from typing import Tuple, Optional
4 |
5 | import torch
6 | from torch import nn, optim, Tensor
7 |
8 | from adv_lib.utils.losses import difference_of_logits
9 | from adv_lib.utils.visdom_logger import VisdomLogger
10 |
11 |
def carlini_wagner_linf(model: nn.Module,
                        inputs: Tensor,
                        labels: Tensor,
                        targeted: bool = False,
                        learning_rate: float = 0.01,
                        max_iterations: int = 1000,
                        initial_const: float = 1e-5,
                        largest_const: float = 2e+1,
                        const_factor: float = 2,
                        reduce_const: bool = False,
                        decrease_factor: float = 0.9,
                        abort_early: bool = True,
                        callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Carlini and Wagner Linf attack from https://arxiv.org/abs/1608.04644.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    learning_rate: float
        The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge.
    max_iterations : int
        The maximum number of iterations. Larger values are more accurate; setting too small will require a large
        learning rate and will produce poor results.
    initial_const : float
        The initial tradeoff-constant to use to tune the relative importance of distance and classification objective.
    largest_const : float
        The maximum tradeoff-constant to use to tune the relative importance of distance and classification objective.
        A sample stops being optimized once its constant reaches this value.
    const_factor : float
        The multiplicative factor by which the constant is increased if the search failed.
    reduce_const : bool
        If true, after each successful attack, make the constant smaller (halved).
    decrease_factor : float
        Multiplicative factor applied to τ after a successful attack. Larger produces better quality results.
    abort_early : bool
        If true, allows early aborts if gradient descent gets stuck.
    callback : Optional
        Logger used to plot the progress of the attack. Nothing is logged if None.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model. Samples for which no adversarial example was found are returned
        unchanged.

    """
    device = inputs.device
    batch_size = len(inputs)
    # tanh change of variable; the (1 - 1e-6) factor keeps atanh finite for inputs equal to 0 or 1
    t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_()
    multiplier = -1 if targeted else 1
    # plot refresh period; clamped to 1 to avoid a modulo by zero when max_iterations < 10
    update_period = max(1, max_iterations // 10)

    # set modifier and the parameters used in the optimization
    modifier = torch.zeros_like(inputs)
    c = torch.full((batch_size,), initial_const, device=device, dtype=torch.float)
    τ = torch.ones(batch_size, device=device)

    o_adv_found = torch.zeros_like(c, dtype=torch.bool)
    o_best_linf = torch.ones_like(c)
    o_best_adv = inputs.clone()

    outer_loops = 0
    total_iters = 0
    # keep optimizing the samples whose threshold τ can still be lowered and whose constant is below the limit
    while (to_optimize := (τ > 1 / 255) & (c < largest_const)).any():

        inputs_, t_inputs_, labels_ = inputs[to_optimize], t_inputs[to_optimize], labels[to_optimize]
        batch_view = lambda tensor: tensor.view(len(inputs_), *[1] * (inputs_.ndim - 1))

        # setup the optimizer; the modifier is warm-started from the previous outer loop
        modifier_ = modifier[to_optimize].requires_grad_(True)
        optimizer = optim.Adam([modifier_], lr=learning_rate)
        c_, τ_ = c[to_optimize], τ[to_optimize]

        adv_found = torch.zeros(len(modifier_), device=device, dtype=torch.bool)
        best_linf = o_best_linf[to_optimize]
        best_adv = inputs_.clone()

        if callback:
            callback.line(['const', 'τ'], outer_loops, [c_.mean(), τ_.mean()])
            callback.line(['success', 'best_linf'], outer_loops, [o_adv_found.float().mean(), o_best_linf.mean()])

        for i in range(max_iterations):

            adv_inputs = (torch.tanh(t_inputs_ + modifier_) + 1) / 2
            linf = (adv_inputs.detach() - inputs_).flatten(1).norm(p=float('inf'), dim=1)
            logits = model(adv_inputs)

            if i == 0:
                labels_infhot = torch.zeros_like(logits).scatter_(1, labels[to_optimize].unsqueeze(1), float('inf'))

            # adjust the best result found so far
            predicted_classes = logits.argmax(1)

            is_adv = (predicted_classes == labels_) if targeted else (predicted_classes != labels_)
            is_smaller = linf < best_linf
            is_both = is_adv & is_smaller
            adv_found.logical_or_(is_both)
            best_linf = torch.where(is_both, linf, best_linf)
            best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

            # penalty on the coordinates whose perturbation exceeds τ, plus the weighted classification term
            logit_dists = multiplier * difference_of_logits(logits, labels_, labels_infhot=labels_infhot)
            linf_loss = (adv_inputs - inputs_).abs_().sub_(batch_view(τ_)).clamp_(min=0).flatten(1).sum(1)
            loss = linf_loss + c_ * logit_dists.clamp_(min=0)

            # check if we should abort search
            if abort_early and (loss < 0.0001 * c_).all():
                break

            optimizer.zero_grad()
            loss.sum().backward()
            optimizer.step()

            if callback:
                callback.accumulate_line('logit_dist', total_iters, logit_dists.mean())
                callback.accumulate_line('linf_norm', total_iters, linf.mean())

                if (i + 1) % update_period == 0 or (i + 1) == max_iterations:
                    callback.update_lines()

            total_iters += 1

        o_adv_found[to_optimize] = adv_found | o_adv_found[to_optimize]
        o_best_linf[to_optimize] = torch.where(adv_found, best_linf, o_best_linf[to_optimize])
        o_best_adv[to_optimize] = torch.where(batch_view(adv_found), best_adv, o_best_adv[to_optimize])
        modifier[to_optimize] = modifier_.detach()

        # successful samples: tighten τ to the best L∞ norm found (if smaller), then shrink it by decrease_factor
        smaller_τ_ = adv_found & (best_linf < τ_)
        τ_ = torch.where(smaller_τ_, best_linf, τ_)
        τ[to_optimize] = torch.where(adv_found, decrease_factor * τ_, τ_)
        # failed samples: increase the constant
        c[to_optimize] = torch.where(~adv_found, const_factor * c_, c_)
        if reduce_const:
            c[to_optimize] = torch.where(adv_found, c[to_optimize] / 2, c[to_optimize])

        outer_loops += 1

    # return the best solution found
    return o_best_adv
159 |
--------------------------------------------------------------------------------
/adv_lib/attacks/decoupled_direction_norm.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.autograd import grad
7 | from torch.nn import functional as F
8 |
9 | from adv_lib.utils.visdom_logger import VisdomLogger
10 |
11 |
def ddn(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        targeted: bool = False,
        steps: int = 100,
        γ: float = 0.05,
        init_norm: float = 1.,
        levels: Optional[int] = 256,
        callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    Decoupled Direction and Norm attack from https://arxiv.org/abs/1811.09600.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    steps : int
        Number of optimization steps.
    γ : float
        Factor by which the norm will be modified. new_norm = norm * (1 + or - γ).
    init_norm : float
        Initial value for the norm of the attack.
    levels : int
        If not None, the returned adversarials will have quantized values to the specified number of levels.
    callback : Optional
        Logger used to plot the progress of the attack. Nothing is logged if None.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model. Samples for which no adversarial example was found are returned
        unchanged.

    Raises
    ------
    ValueError
        If the inputs are not in the [0, 1] range.

    """
    if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    # plot refresh period; clamped to 1 to avoid a modulo by zero when steps < 20
    update_period = max(1, steps // 20)

    # Init variables
    multiplier = -1 if targeted else 1
    δ = torch.zeros_like(inputs, requires_grad=True)
    ε = torch.full((batch_size,), init_norm, device=device, dtype=torch.float)
    # L2 norm of the farthest point of the [0, 1] box from each input: upper bound for ε
    worst_norm = torch.max(inputs, 1 - inputs).flatten(1).norm(p=2, dim=1)

    # Init trackers
    best_l2 = worst_norm.clone()
    best_adv = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for i in range(steps):
        l2 = δ.data.flatten(1).norm(p=2, dim=1)
        adv_inputs = inputs + δ
        logits = model(adv_inputs)
        pred_labels = logits.argmax(1)
        ce_loss = F.cross_entropy(logits, labels, reduction='none')
        loss = multiplier * ce_loss

        is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
        is_smaller = l2 < best_l2
        is_both = is_adv & is_smaller
        adv_found.logical_or_(is_adv)
        best_l2 = torch.where(is_both, l2, best_l2)
        best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

        δ_grad = grad(loss.sum(), δ, only_inputs=True)[0]
        # renorming gradient
        grad_norms = δ_grad.flatten(1).norm(p=2, dim=1)
        δ_grad.div_(batch_view(grad_norms))
        # avoid nan or inf if gradient is 0
        if (zero_grad := (grad_norms < 1e-12)).any():
            δ_grad[zero_grad] = torch.randn_like(δ_grad[zero_grad])

        # cosine annealing of the step size, from 1 (i = 0) towards 0.01
        α = 0.01 + (1 - 0.01) * (1 + math.cos(math.pi * i / steps)) / 2

        if callback is not None:
            cosine = F.cosine_similarity(δ_grad.flatten(1), δ.data.flatten(1), dim=1).mean()
            callback.accumulate_line('ce', i, ce_loss.mean())
            callback_best = best_l2.masked_select(adv_found).mean()
            callback.accumulate_line(['ε', 'l2', 'best_l2'], i, [ε.mean(), l2.mean(), callback_best])
            callback.accumulate_line(['cosine', 'α', 'success'], i,
                                     [cosine, torch.tensor(α, device=device), adv_found.float().mean()])

            if (i + 1) % update_period == 0 or (i + 1) == steps:
                callback.update_lines()

        δ.data.add_(δ_grad, alpha=α)

        # shrink the norm of the currently adversarial samples, grow it for the others
        ε = torch.where(is_adv, (1 - γ) * ε, (1 + γ) * ε)
        ε = torch.minimum(ε, worst_norm)

        # rescale δ to norm ε, then clip to the [0, 1] box and optionally quantize
        δ.data.mul_(batch_view(ε / δ.data.flatten(1).norm(p=2, dim=1)))
        δ.data.add_(inputs).clamp_(min=0, max=1)
        if levels is not None:
            δ.data.mul_(levels - 1).round_().div_(levels - 1)
        δ.data.sub_(inputs)

    return best_adv
114 |
--------------------------------------------------------------------------------
/adv_lib/attacks/deepfool.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | from torch import Tensor, nn
5 | from torch.autograd import grad
6 |
7 | from adv_lib.utils.attack_utils import get_all_targets
8 |
9 |
def df(model: nn.Module,
       inputs: Tensor,
       labels: Tensor,
       targeted: bool = False,
       steps: int = 100,
       overshoot: float = 0.02,
       norm: float = 2,
       return_unsuccessful: bool = False,
       return_targets: bool = False) -> Tensor:
    """
    DeepFool attack from https://arxiv.org/abs/1511.04599. Implements parallel sample-wise early-stopping.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not. DeepFool is untargeted only: if True, a warning is raised and the
        inputs are returned.
    steps : int
        Maximum number of attack steps.
    overshoot : float
        Ratio by which to overshoot the boundary estimated from linear model.
    norm : float
        Norm to minimize in {2, float('inf')}.
    return_unsuccessful : bool
        Whether to return unsuccessful adversarial inputs ; used by SuperDeepFool.
    return_targets : bool
        Whether to return last target labels ; used by SuperDeepFool.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.
    targets : Tensor
        Only returned (as the second element of a tuple) if `return_targets` is True and `targeted` is False: class
        selected at the last step in which each sample was updated; samples misclassified from the start keep their
        label.

    Raises
    ------
    ValueError
        If `norm` is not 2 or inf, or if the inputs are not in the [0, 1] range.

    """
    if targeted:
        warnings.warn('DeepFool attack is untargeted only. Returning inputs.')
        return inputs

    # validate upfront: an unsupported norm would otherwise leave w_prime_norms undefined in the middle of the attack
    if norm not in (2, float('inf')):
        raise ValueError(f'DeepFool only supports norm in {{2, inf}}, got {norm}.')
    if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))

    # Setup variables
    adv_inputs = inputs.clone()
    adv_inputs.requires_grad_(True)

    adv_out = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)
    if return_targets:
        targets = labels.clone()

    arange = torch.arange(batch_size, device=device)
    for i in range(steps):

        logits = model(adv_inputs)

        if i == 0:
            other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1])

        pred_labels = logits.argmax(dim=1)
        is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)

        # store newly found adversarials and drop them from the batch still being attacked
        if is_adv.any():
            adv_not_found = ~adv_found
            adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs.detach(), adv_out[adv_not_found])
            adv_found.masked_scatter_(adv_not_found, is_adv)
            if is_adv.all():
                break

            not_adv = ~is_adv
            logits, labels, other_labels = logits[not_adv], labels[not_adv], other_labels[not_adv]
            arange = torch.arange(not_adv.sum(), device=device)

        # logit differences to every other class and their gradients (one backward per class)
        f_prime = logits.gather(dim=1, index=other_labels) - logits.gather(dim=1, index=labels.unsqueeze(1))
        w_prime = []
        for j, f_prime_k in enumerate(f_prime.unbind(dim=1)):
            w_prime_k = grad(f_prime_k.sum(), inputs=adv_inputs, retain_graph=(j + 1) < f_prime.shape[1],
                             only_inputs=True)[0]
            w_prime.append(w_prime_k)
        w_prime = torch.stack(w_prime, dim=1)  # batch_size × num_classes × ...

        if is_adv.any():
            not_adv = ~is_adv
            adv_inputs, w_prime = adv_inputs[not_adv], w_prime[not_adv]

        # dual norm of the gradients: L2 for the L2 attack, L1 for the Linf attack
        if norm == 2:
            w_prime_norms = w_prime.flatten(2).norm(p=2, dim=2).clamp_(min=1e-6)
        else:
            w_prime_norms = w_prime.flatten(2).norm(p=1, dim=2).clamp_(min=1e-6)

        # distance to each linearized boundary; the closest one is selected
        distance = f_prime.detach().abs_().div_(w_prime_norms).add_(1e-4)
        l_hat = distance.argmin(dim=1)

        if return_targets:
            targets[~adv_found] = torch.where(l_hat >= labels, l_hat + 1, l_hat)

        if norm == 2:
            # 1e-4 added in original implementation
            scale = distance[arange, l_hat] / w_prime_norms[arange, l_hat]
            adv_inputs.data.addcmul_(batch_view(scale), w_prime[arange, l_hat], value=1 + overshoot)
        else:
            adv_inputs.data.addcmul_(batch_view(distance[arange, l_hat]), w_prime[arange, l_hat].sign(),
                                     value=1 + overshoot)
        adv_inputs.data.clamp_(min=0, max=1)

    if return_unsuccessful and not adv_found.all():
        adv_out[~adv_found] = adv_inputs.detach()

    if return_targets:
        return adv_out, targets

    return adv_out
128 |
--------------------------------------------------------------------------------
/adv_lib/attacks/fast_adaptive_boundary/__init__.py:
--------------------------------------------------------------------------------
1 | from .fast_adaptive_boundary import fab
2 |
--------------------------------------------------------------------------------
/adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/fra31/auto-attack
2 |
3 | import warnings
4 | from functools import partial
5 | from typing import Optional, Tuple
6 |
7 | import torch
8 | from torch import Tensor, nn
9 | from torch.autograd import grad
10 |
11 | from adv_lib.utils.attack_utils import get_all_targets
12 | from .projections import projection_l1, projection_l2, projection_linf
13 |
14 |
def fab(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        norm: float,
        n_iter: int = 100,
        ε: Optional[float] = None,
        α_max: float = 0.1,
        β: float = 0.9,
        η: float = 1.05,
        restarts: Optional[int] = None,
        targeted_restarts: bool = False,
        targeted: bool = False) -> Tensor:
    """
    Fast Adaptive Boundary (FAB) attack from https://arxiv.org/abs/1907.02044

    The single-run attack is repeated (random restarts, or restarts targeting the most likely classes) and the
    adversarial example of smallest norm is kept for every sample.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    norm : float
        Norm to minimize in {1, 2 ,float('inf')}.
    n_iter : int
        Number of optimization steps per run. This is not the number of forward / backward propagations: by default,
        each step performs 2 forwards and K - 1 backwards (K being the number of classes). If `targeted_restarts` is
        `True`, each run performs 2 forwards and 1 backward per step. See section 4 of
        https://arxiv.org/abs/1907.02044 and https://arxiv.org/abs/2011.11857 for a discussion on complexity.
    ε : float
        Maximum norm of the random initialization for restarts.
    α_max : float
        Maximum weight for the biased gradient step. α = 0 corresponds to taking the projection of the `adv_inputs` on
        the decision hyperplane, while α = 1 corresponds to taking the projection of the `inputs` on the decision
        hyperplane.
    β : float
        Weight for the biased backward step, i.e. a linear interpolation between `inputs` and `adv_inputs` at step i.
        β = 0 corresponds to taking the original `inputs` and β = 1 corresponds to taking the `adv_inputs`.
    η : float
        Extrapolation for the optimization step. η = 1 corresponds to projecting the `adv_inputs` on the decision
        hyperplane. η > 1 corresponds to overshooting to increase the probability of crossing the decision hyperplane.
    restarts : int
        In default mode, number of runs: the first one starts from the inputs, the following ones from randomly
        perturbed inputs. If `targeted_restarts` is `True`, number of classes to attack.
    targeted_restarts : bool
        If `True`, each run targets one of the most likely classes for the unperturbed `inputs`. Without `restarts`,
        every class except the original one is attacked. With `restarts`, the `restarts` most likely classes are
        attacked; if `restarts` exceeds K - 1, the most likely classes are attacked again from random starts.
    targeted : bool
        Placeholder argument for library. FAB is only for untargeted attacks, so setting this to True will raise a
        warning and return the inputs.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    if targeted:
        warnings.warn('FAB attack is untargeted only. Returning inputs.')
        return inputs

    attack = partial(_fab, model=model, norm=norm, n_iter=n_iter, ε=ε, α_max=α_max, β=β, η=η)
    adv_out = inputs.clone()
    out_norm = torch.full_like(labels, float('inf'), dtype=torch.float)

    if targeted_restarts:
        # rank the candidate classes by their clean logits, the true class being pushed to -inf
        clean_logits = model(inputs)
        num_targets = clean_logits.size(1) - 1
        true_class_mask = torch.zeros_like(clean_logits).scatter_(1, labels.unsqueeze(1), float('inf'))
        n_ranked = min(restarts or num_targets, num_targets)
        ranked_targets = (clean_logits - true_class_mask).topk(k=n_ranked, dim=1).indices
        num_runs = restarts or num_targets
    else:
        num_runs = restarts or 1

    for run in range(num_runs):
        if targeted_restarts:
            run_kwargs = dict(random_start=run >= num_targets, targets=ranked_targets[:, run % num_targets])
        else:
            run_kwargs = dict(random_start=run != 0)
        run_adv, run_found, run_norm = attack(inputs=inputs, labels=labels, u=out_norm, **run_kwargs)

        # keep the result of this run only where it is adversarial and strictly smaller
        improved = run_found & (run_norm < out_norm)
        out_norm[improved] = run_norm[improved]
        adv_out[improved] = run_adv[improved]

    return adv_out
108 |
109 |
def get_best_diff_logits_grads(model: nn.Module,
                               inputs: Tensor,
                               labels: Tensor,
                               other_labels: Tensor,
                               q: float) -> Tuple[Tensor, Tensor]:
    """
    For each sample, select the candidate class whose linearized decision boundary is the closest.

    For every candidate class o in `other_labels`, the logit difference f_o - f_y and its gradient with respect to the
    inputs are computed, and the candidate minimizing |f_o - f_y| / ||∇(f_o - f_y)||_q is kept.

    Parameters
    ----------
    model : nn.Module
        Model returning logits.
    inputs : Tensor
        Batch of inputs. The caller's tensor is not modified: gradients are taken w.r.t. a detached alias.
    labels : Tensor
        Reference class of each sample, shape (batch_size,).
    other_labels : Tensor
        Candidate classes, shape (batch_size, n_candidates).
    q : float
        Order of the norm applied to the gradients.

    Returns
    -------
    best_logit_diff : Tensor
        Detached logit difference for the selected class, shape (batch_size,). Zeros if there is no candidate.
    best_grad_diff : Tensor
        Detached gradient of that difference, same shape as `inputs`. Zeros if there is no candidate.

    """
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
    min_ratio = torch.full_like(labels, float('inf'), dtype=torch.float)
    best_logit_diff, best_grad_diff = torch.zeros_like(labels, dtype=torch.float), torch.zeros_like(inputs)

    # differentiate w.r.t. a detached alias instead of toggling requires_grad on the caller's tensor
    x = inputs.detach().requires_grad_(True)
    logits = model(x)
    class_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1)

    n_other_labels = other_labels.size(1)
    for i, o_labels in enumerate(other_labels.transpose(0, 1)):
        other_logits = logits.gather(1, o_labels.unsqueeze(1)).squeeze(1)
        logits_diff = other_logits - class_logits
        # the graph is reused for every candidate class but the last one
        grad_diff = grad(logits_diff.sum(), x, only_inputs=True, retain_graph=i + 1 != n_other_labels)[0]
        ratio = logits_diff.detach().abs().div_(grad_diff.flatten(1).norm(p=q, dim=1).clamp_(min=1e-12))

        smaller_ratio = ratio < min_ratio
        min_ratio = torch.min(ratio, min_ratio)
        best_logit_diff = torch.where(smaller_ratio, logits_diff.detach(), best_logit_diff)
        best_grad_diff = torch.where(batch_view(smaller_ratio), grad_diff.detach(), best_grad_diff)

    return best_logit_diff, best_grad_diff
137 |
138 |
def _fab(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         norm: float,
         n_iter: int = 100,
         ε: Optional[float] = None,
         α_max: float = 0.1,
         β: float = 0.9,
         η: float = 1.05,
         random_start: bool = False,
         u: Optional[Tensor] = None,
         targets: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Single run of the FAB attack. `model`, `inputs`, `labels`, `norm`, `n_iter`, `ε`, `α_max`, `β` and `η` have the
    same meaning as in `fab`.

    Parameters
    ----------
    random_start : bool
        If True, the starting point is the inputs plus random noise of `norm`-norm min(best_norm, ε) / 2 (before
        clipping to [0, 1]), where best_norm is `u` (or inf) and 0 for the samples already misclassified.
    u : Tensor, optional
        Per-sample norm upper bound: only adversarial examples with a strictly smaller norm are recorded. The tensor is
        not modified.
    targets : Tensor, optional
        If given, only this class is considered for each sample when linearizing the decision boundary, instead of all
        the other classes.

    Returns
    -------
    best_adv : Tensor
        Smallest adversarial examples recorded during the run; the inputs where nothing was recorded.
    adv_found : Tensor
        Boolean mask of the samples misclassified at the start or at any iterate of the run.
    best_norm : Tensor
        Norm of `best_adv` - `inputs` for recorded samples; 0 for samples misclassified from the start; `u` (or inf)
        otherwise.

    """
    # norm -> (projection on the linearized boundary, dual norm used for the gradients, default ε for random starts)
    _projection_dual_default_ε = {
        1: (projection_l1, float('inf'), 5),
        2: (projection_l2, 2, 1),
        float('inf'): (projection_linf, 1, 0.3)
    }

    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
    projection, dual_norm, default_ε = _projection_dual_default_ε[norm]
    ε = default_ε if ε is None else ε

    logits = model(inputs)
    if targets is not None:
        other_labels = targets.unsqueeze(1)
    else:
        other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1])

    get_df_dg = partial(get_best_diff_logits_grads, model=model, labels=labels, other_labels=other_labels, q=dual_norm)

    adv_inputs = inputs.clone()
    adv_found = logits.argmax(dim=1) != labels
    # copy u: the in-place assignment below must not leak into the caller's tensor
    if u is None:
        best_norm = torch.full((batch_size,), float('inf'), device=device, dtype=torch.float)
    else:
        best_norm = u.clone()
    best_norm[adv_found] = 0
    best_adv = inputs.clone()

    if random_start:
        if norm == float('inf'):
            t = torch.rand_like(inputs).mul_(2).sub_(1)
        elif norm in [1, 2]:
            t = torch.randn_like(inputs)

        adv_inputs.add_(t.mul_(batch_view(best_norm.clamp(max=ε) / t.flatten(1).norm(p=norm, dim=1).mul_(2))))
        adv_inputs.clamp_(min=0.0, max=1.0)

    for _ in range(n_iter):
        # linearized closest decision boundary at adv_inputs: hyperplane {x : w·x = b}
        df, dg = get_df_dg(inputs=adv_inputs)
        b = (dg * adv_inputs).flatten(1).sum(dim=1).sub_(df)
        w = dg.flatten(1)

        # project both the current iterate (d1) and the original inputs (d2) on the hyperplane in a single call
        d3 = projection(torch.cat((adv_inputs.flatten(1), inputs.flatten(1)), 0), w.repeat(2, 1), b.repeat(2))
        d1, d2 = map(lambda t: t.view_as(adv_inputs), torch.chunk(d3, 2, dim=0))

        a0 = batch_view(d3.flatten(1).norm(p=norm, dim=1).clamp_(min=1e-8))
        a1, a2 = torch.chunk(a0, 2, dim=0)

        # biased step: convex combination of both extrapolated projections, α = ||d1|| / (||d1|| + ||d2||) <= α_max
        α = a1.div_(a2.add_(a1)).clamp_(min=0, max=α_max)
        adv_inputs.add_(d1, alpha=η).mul_(1 - α).add_(inputs.add(d2, alpha=η).mul_(α)).clamp_(min=0, max=1)

        is_adv = model(adv_inputs).argmax(1) != labels
        adv_found.logical_or_(is_adv)
        adv_norm = (adv_inputs - inputs).flatten(1).norm(p=norm, dim=1)
        is_smaller = adv_norm < best_norm
        is_both = is_adv & is_smaller
        best_norm = torch.where(is_both, adv_norm, best_norm)
        best_adv = torch.where(batch_view(is_both), adv_inputs, best_adv)

        # backward step: pull the adversarial iterates back towards the inputs
        adv_inputs = torch.where(batch_view(is_adv), inputs + (adv_inputs - inputs) * β, adv_inputs)

    return best_adv, adv_found, best_norm
211 |
--------------------------------------------------------------------------------
/adv_lib/attacks/fast_adaptive_boundary/projections.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import functional as F
6 |
7 |
def projection_l1(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor:
    """
    Compute, for each row, a perturbation d that moves the point t towards the hyperplane w·x = b under the L1 norm,
    with every coordinate bounded by the [0, 1] box.

    NOTE(review): presumably the minimal-L1 box-constrained projection used by FAB (auto-attack); confirm.

    Parameters
    ----------
    points_to_project : Tensor
        Points t, 2-D tensor of shape (n, D).
    w_hyperplane : Tensor
        Hyperplane normals w, shape (n, D). Cloned, so the caller's tensor is not modified.
    b_hyperplane : Tensor
        Hyperplane offsets b, shape (n,).

    Returns
    -------
    d : Tensor
        Perturbations of shape (n, D); coordinates where |w| <= 1e-8 are set to 0.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane

    # signed offset c = w·t - b; w and c are flipped so that c is non-negative
    c = (w * t).sum(1).sub_(b)
    ind2 = (c >= 0).float().mul_(2).sub_(1)
    w.mul_(ind2.unsqueeze(1))
    c.mul_(ind2)

    # coordinates sorted by increasing 1 / |w| (i.e. decreasing |w|); indr_rev is the inverse permutation (ranks)
    w_abs = w.abs()
    r = (1 / w_abs).clamp_(max=1e12)
    indr = torch.argsort(r, dim=1)
    indr_rev = torch.argsort(indr)

    # d: move of each coordinate to the box bound selected by the sign of w (1 if w < 0, 0 if w > 0)
    d = (w < 0).float().sub_(t).mul_(w != 0)
    # ds: (non-positive) change of w·x produced by each such move, in sorted order; s: cumulative sums starting at c
    ds = torch.min(-w * t, (1 - t).mul_(w)).gather(1, indr)
    ds2 = torch.cat((c.unsqueeze(-1), ds), 1)
    s = torch.cumsum(ds2, dim=1)

    # rows for which moving all coordinates to their bound makes the offset negative
    c2 = s[:, -1] < 0

    lb = torch.zeros(c2.sum(), device=device)
    ub = torch.full_like(lb, s.shape[1])
    nitermax = math.ceil(math.log2(w.shape[1]))

    # binary search over the columns of s (rows in c2 only): lb moves up while s[lb] > 0, ub moves down otherwise
    s_ = s[c2]
    for counter in range(nitermax):
        counter4 = (lb + ub).mul_(0.5).floor_()
        counter2 = counter4.long().unsqueeze(1)
        c3 = s_.gather(1, counter2).squeeze(1) > 0
        lb = torch.where(c3, counter4, lb)
        ub = torch.where(c3, ub, counter4)

    lb2 = lb.long()

    if c2.any():
        # coordinate at sorted position lb, and the partial move alpha on it that cancels the remaining offset
        indr = indr[c2].gather(1, lb2.unsqueeze(1)).squeeze(1)
        u = torch.arange(0, w.shape[0], device=device).unsqueeze(1)
        u2 = torch.arange(0, w.shape[1], device=device, dtype=torch.float).unsqueeze(0)
        alpha = s[c2, lb2].neg().div_(w[c2, indr])
        # u3[i, j]: whether coordinate j is ranked before lb; only those keep their full move
        c5 = u2 < lb.unsqueeze(-1)
        u3 = c5[u[:c5.shape[0]], indr_rev[c2]]
        d[c2] *= u3
        d[c2, indr] = alpha
    # rows outside c2 keep d unchanged (every coordinate moved to its bound)

    return d.mul_(w_abs > 1e-8)
54 |
55 |
def projection_l2(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor:
    """
    Compute, for each row, a perturbation d that moves the point t towards the hyperplane w·x = b under the L2 norm,
    with every coordinate clipped to the [0, 1] box.

    NOTE(review): presumably the minimal-L2 box-constrained projection used by FAB (auto-attack); confirm.

    Parameters
    ----------
    points_to_project : Tensor
        Points t, 2-D tensor of shape (n, D).
    w_hyperplane : Tensor
        Hyperplane normals w, shape (n, D). Cloned, so the caller's tensor is not modified.
    b_hyperplane : Tensor
        Hyperplane offsets b, shape (n,).

    Returns
    -------
    d : Tensor
        Perturbations of shape (n, D); coordinates where |w| <= 1e-8 are set to 0.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane

    # signed offset c = w·t - b; w and c are flipped so that c is non-negative
    c = (w * t).sum(1).sub_(b)
    ind2 = (c >= 0).float().mul_(2).sub_(1)
    w.mul_(ind2.unsqueeze(1))
    w_nonzero = w.abs() > 1e-8
    c.mul_(ind2)

    # r: per coordinate, step λ along -w at which the coordinate reaches its box bound (0 or 1);
    # zero-weight coordinates get the sentinel 1e12, and -1e12 values are mapped to +1e12
    r = torch.maximum(t / w, (t - 1).div_(w)).clamp_(min=-1e12, max=1e12)
    r.masked_fill_(~w_nonzero, 1e12)
    r[r == -1e12] *= -1
    # rs: sorted breakpoints; rs2: next breakpoint (shifted by one, 0-padded); sentinels are replaced by 0
    rs, indr = torch.sort(r, dim=1)
    rs2 = F.pad(rs[:, 1:], (0, 1))
    rs.masked_fill_(rs == 1e12, 0)
    rs2.masked_fill_(rs2 == 1e12, 0)

    # ws: squared norm of w over the coordinates not yet clipped after each sorted breakpoint
    w3s = w.square().gather(1, indr)
    w5 = w3s.sum(dim=1, keepdim=True)
    ws = w5 - torch.cumsum(w3s, dim=1)
    # d: move of each coordinate all the way to its box bound
    d = (r * w).neg_()
    d.mul_(w_nonzero)
    # s: presumably the change of w·x along the clipped path evaluated at the sorted breakpoints — TODO confirm
    s = torch.cat((-w5 * rs[:, 0:1], torch.cumsum((rs - rs2).mul_(ws), dim=1).sub_(w5 * rs[:, 0:1])), 1)

    # c4: the hyperplane is crossed before the first breakpoint (unclipped projection suffices);
    # c3: even the fully clipped move d leaves a positive offset; c2: remaining rows, handled by binary search
    c4 = s[:, 0] + c < 0
    c3 = (d * w).sum(dim=1).add_(c) > 0
    c2 = ~(c4 | c3)

    lb = torch.zeros(c2.sum(), device=device)
    ub = torch.full_like(lb, w.shape[1] - 1)
    nitermax = math.ceil(math.log2(w.shape[1]))

    # binary search over the breakpoints (rows in c2 only): lb moves up while s[lb] + c > 0, ub moves down otherwise
    s_, c_ = s[c2], c[c2]
    for counter in range(nitermax):
        counter4 = (lb + ub).mul_(0.5).floor_()
        counter2 = counter4.long().unsqueeze(1)
        c3 = s_.gather(1, counter2).squeeze(1).add_(c_) > 0
        lb = torch.where(c3, counter4, lb)
        ub = torch.where(c3, ub, counter4)

    lb = lb.long()

    if c4.any():
        # unclipped projection: d = -(c / ||w||²) w
        alpha = c[c4] / w5[c4].squeeze(-1)
        d[c4] = -alpha.unsqueeze(-1) * w[c4]

    if c2.any():
        # step alpha found between breakpoints; coordinates whose breakpoint is below alpha stay at their bound,
        # the others move by -alpha * w
        alpha = (s[c2, lb] + c[c2]).div_(ws[c2, lb]).add_(rs[c2, lb])
        alpha[ws[c2, lb] == 0] = 0
        c5 = alpha.unsqueeze(-1) > r[c2]
        d[c2] = (d[c2] * c5).sub_((~c5).float().mul_(alpha.unsqueeze(-1)).mul_(w[c2]))
    # rows in c3 keep d unchanged (every coordinate moved to its bound)

    return d.mul_(w_nonzero)
110 |
111 |
def projection_linf(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor:
    """
    Box-constrained ℓ∞ step of every row of ``points_to_project`` towards the hyperplane ``w·x = b``.

    Parameters
    ----------
    points_to_project : Tensor
        Points ``t``, indexed as (batch, features).
    w_hyperplane : Tensor
        Hyperplane normals ``w``, same shape as ``points_to_project`` (cloned, not modified).
    b_hyperplane : Tensor
        Hyperplane offsets ``b``, one per row (cloned, not modified).

    Returns
    -------
    Tensor
        Perturbation ``d`` with the shape of ``points_to_project``: the move towards the box corner that decreases
        ``w·x``, clipped to a per-row ℓ∞ radius. Coordinates with ``w == 0`` are never moved.

    NOTE(review): this appears to be the ℓ∞ projection of the Fast Adaptive Boundary (FAB) attack; confirm the exact
    optimality guarantees against the FAB reference before relying on them.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane.clone()

    # Flip w and b row-wise so that w·t - b >= 0 for every row.
    sign = ((w * t).sum(1).sub_(b) >= 0).float().mul_(2).sub_(1)
    w.mul_(sign.unsqueeze(1))
    b.mul_(sign)

    # a = 1 where w < 0; d moves every coordinate with w != 0 to the corner value a (1 where w < 0, 0 where w > 0).
    a = (w < 0).float()
    d = (a - t).mul_(w != 0)

    # p = a - t * (2a - 1): per-coordinate distance to that corner (equals |d| when t is in [0, 1] — assumption).
    p = (2 * a).sub_(1).mul_(t).neg_().add_(a)
    indp = torch.argsort(p, dim=1)

    # b <- b - w·t: required change of w·x; b0: change of w·x produced by the full corner move d.
    b.sub_((w * t).sum(1))
    b0 = (w * d).sum(1)

    # Coordinates in decreasing order of p.
    indp2 = indp.flip((1,))
    ws = w.gather(1, indp2)
    bs2 = -ws * d.gather(1, indp2)

    s = torch.cumsum(ws.abs_(), dim=1)
    sb = torch.cumsum(bs2, dim=1).add_(b0.unsqueeze(1))

    # c_l: rows handled by a uniform ±λ move on every coordinate (closed form below).
    # c2: rows solved by the binary search over the sorted radii p; other rows keep the full corner move.
    b2 = sb[:, -1] - s[:, -1] * p.gather(1, indp[:, 0:1]).squeeze(1)
    c_l = b - b2 > 0
    c2 = (b - b0 > 0) & (~c_l)
    lb = torch.zeros(c2.sum(), device=device)
    ub = torch.full_like(lb, w.shape[1] - 1)
    nitermax = math.ceil(math.log2(w.shape[1]))

    indp_, sb_, s_, p_, b_ = indp[c2], sb[c2], s[c2], p[c2], b[c2]
    for counter in range(nitermax):
        counter4 = (lb + ub).mul_(0.5).floor_()

        counter2 = counter4.long().unsqueeze(1)
        indcurr = indp_.gather(1, indp_.size(1) - 1 - counter2)
        b2 = sb_.gather(1, counter2).sub_(s_.gather(1, counter2).mul_(p_.gather(1, indcurr))).squeeze(1)
        c = b_ - b2 > 0

        lb = torch.where(c, counter4, lb)
        ub = torch.where(c, ub, counter4)

    lb = lb.long()

    if c_l.any():
        lmbd_opt = (b[c_l] - sb[c_l, -1]).div_(-s[c_l, -1]).clamp_(min=0).unsqueeze_(-1)
        d[c_l] = (2 * a[c_l]).sub_(1).mul_(lmbd_opt)

    # Radius λ for the c2 rows, then clip the corner move to [-λ, λ] (upper side where w < 0, lower side where w > 0).
    lmbd_opt = (b[c2] - sb[c2, lb]).div_(-s[c2, lb]).clamp_(min=0).unsqueeze_(-1)
    d[c2] = torch.minimum(lmbd_opt, d[c2]).mul_(a[c2]).add_(torch.maximum(-lmbd_opt, d[c2]).mul_(1 - a[c2]))

    return d.mul_(w != 0)
165 |
--------------------------------------------------------------------------------
/adv_lib/attacks/fast_minimum_norm.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/maurapintor/Fast-Minimum-Norm-FMN-Attack
2 |
3 | import math
4 | from functools import partial
5 | from typing import Optional
6 |
7 | import torch
8 | from torch import Tensor, nn
9 | from torch.autograd import grad
10 |
11 | from adv_lib.utils.losses import difference_of_logits
12 | from adv_lib.utils.projections import clamp_, l1_ball_euclidean_projection
13 |
14 |
def l0_projection_(δ: Tensor, ε: Tensor) -> Tensor:
    """
    In-place l0 projection: keep, for each sample, its ε largest-magnitude entries and zero the others.

    Parameters
    ----------
    δ : Tensor
        Perturbations; flattened from dim 1 and modified in place.
    ε : Tensor
        Per-sample number of entries to keep (truncated to an integer).

    Returns
    -------
    Tensor
        The flattened ``δ`` after projection.

    Notes
    -----
    Entries tied with the ε-th largest magnitude are all kept. A budget of 0 behaves like a budget of 1 (the largest
    entry is kept) because the threshold index is clamped at 0. ``topk`` is always called with ``1 <= k <= n_features``,
    so a batch in which every budget is 0 no longer crashes.
    """
    δ = δ.flatten(1)
    δ_abs = δ.abs()
    budgets = ε.long()
    # topk(k=0) returns an empty tensor (gather would then index out of bounds) and topk(k > n_features) raises.
    k = min(max(int(budgets.max()), 1), δ.size(1))
    kth_index = (budgets.unsqueeze(1) - 1).clamp_(min=0, max=k - 1)
    thresholds = δ_abs.topk(k=k, dim=1).values.gather(1, kth_index)
    δ[δ_abs < thresholds] = 0
    return δ
22 |
23 |
def l1_projection_(δ: Tensor, ε: Tensor) -> Tensor:
    """In-place l1 projection: Euclidean projection of each flattened sample of δ onto the l1 ball of radius ε."""
    return l1_ball_euclidean_projection(x=δ.flatten(1), ε=ε, inplace=True)
28 |
29 |
def l2_projection_(δ: Tensor, ε: Tensor) -> Tensor:
    """
    In-place l2 projection of each flattened sample of δ onto the l2 ball of radius ε.

    Samples whose l2 norm is at most ε are left unchanged; the others are rescaled to norm ε.

    Parameters
    ----------
    δ : Tensor
        Perturbations; flattened from dim 1 and modified in place.
    ε : Tensor
        Per-sample radius.

    Returns
    -------
    Tensor
        The flattened ``δ`` after projection.
    """
    δ = δ.flatten(1)
    l2_norms = δ.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-12)
    # Clamp the scaling factor, not δ: a factor > 1 would push points lying inside the ball out to its boundary.
    δ.mul_((ε.unsqueeze(1) / l2_norms).clamp_(max=1))
    return δ
36 |
37 |
def linf_projection_(δ: Tensor, ε: Tensor) -> Tensor:
    """In-place linf projection: clip every entry of each flattened sample of δ to [-ε_i, ε_i]."""
    flat = δ.flatten(1)
    radius = ε.unsqueeze(1)
    return clamp_(flat, lower=-radius, upper=radius)
43 |
44 |
def l0_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
    """Point between x0 and x1 keeping only a fraction ε of the features of x1 - x0 (largest magnitudes first)."""
    features_per_sample = x0[0].numel()
    kept = l0_projection_(δ=x1 - x0, ε=features_per_sample * ε)
    return kept.view_as(x0).add_(x0)
49 |
50 |
def l1_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
    """
    Point between x0 and x1 obtained by soft-thresholding x1 - x0 at 1 - ε.

    Entries of x1 - x0 with magnitude at most 1 - ε are dropped; the others are shrunk towards 0 by 1 - ε.
    ε = 1 gives x1 and ε = 0 drops every entry whose magnitude is at most 1.
    """
    shrink = (1 - ε).unsqueeze(1)
    difference = (x1 - x0).flatten(1)
    magnitude = difference.abs()
    dropped = magnitude <= shrink
    shrunk = (magnitude - shrink).copysign(difference)
    shrunk[dropped] = 0
    return shrunk.view_as(x0).add_(x0)
59 |
60 |
def l2_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
    """Linear interpolation x0 + ε (x1 - x0), with one weight ε per sample."""
    start, end = x0.flatten(1), x1.flatten(1)
    interpolated = torch.lerp(start, end, ε.unsqueeze(1))
    return interpolated.view_as(x0)
63 |
64 |
def linf_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor:
    """Point between x0 and x1 obtained by clipping every entry of x1 - x0 to [-ε_i, ε_i]."""
    radius = ε.unsqueeze(1)
    clipped = clamp_((x1 - x0).flatten(1), lower=-radius, upper=radius)
    return clipped.view_as(x0).add_(x0)
69 |
70 |
def fmn(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        norm: float,
        targeted: bool = False,
        steps: int = 10,
        α_init: float = 1.0,
        α_final: Optional[float] = None,
        γ_init: float = 0.05,
        γ_final: float = 0.001,
        starting_points: Optional[Tensor] = None,
        binary_search_steps: int = 10) -> Tensor:
    """
    Fast Minimum-Norm attack from https://arxiv.org/abs/2102.12827.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    norm : float
        Norm to minimize in {0, 1, 2, float('inf')}.
    targeted : bool
        Whether to perform a targeted attack or not.
    steps : int
        Number of optimization steps.
    α_init : float
        Initial step size.
    α_final : float
        Final step size after cosine annealing. Defaults to α_init / 100.
    γ_init : float
        Initial factor by which ε is modified: ε = ε * (1 + or - γ).
    γ_final : float
        Final factor, after cosine annealing, by which ε is modified.
    starting_points : Tensor
        Optional warm-start for the attack. Every starting point must already be adversarial.
    binary_search_steps : int
        Number of binary search steps to find the decision boundary between inputs and starting_points.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model. Samples for which no adversarial was found are returned
        unchanged.

    Raises
    ------
    ValueError
        If inputs are outside [0, 1] or if some starting points are not adversarial.

    """
    # norm -> (dual norm used for the distance-to-boundary estimate, in-place projection, mid-point function)
    _dual_projection_mid_points = {
        0: (None, l0_projection_, l0_mid_points),
        1: (float('inf'), l1_projection_, l1_mid_points),
        2: (2, l2_projection_, l2_mid_points),
        float('inf'): (1, linf_projection_, linf_mid_points),
    }
    if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    dual, projection, mid_point = _dual_projection_mid_points[norm]
    α_final = α_init / 100 if α_final is None else α_final
    # Sign applied to the difference of logits so that the gradient ascent step moves towards the goal.
    multiplier = 1 if targeted else -1

    # If starting_points is provided, search for the boundary
    if starting_points is not None:
        start_preds = model(starting_points).argmax(dim=1)
        is_adv = (start_preds == labels) if targeted else (start_preds != labels)
        if not is_adv.all():
            raise ValueError('Starting points are not all adversarial.')
        # Bisection on the mid-point weight: upper_bound always stays on the adversarial side.
        lower_bound = torch.zeros(batch_size, device=device)
        upper_bound = torch.ones(batch_size, device=device)
        for _ in range(binary_search_steps):
            ε = (lower_bound + upper_bound) / 2
            mid_points = mid_point(x0=inputs, x1=starting_points, ε=ε)
            pred_labels = model(mid_points).argmax(dim=1)
            is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
            lower_bound = torch.where(is_adv, lower_bound, ε)
            upper_bound = torch.where(is_adv, ε, upper_bound)

        δ = mid_point(x0=inputs, x1=starting_points, ε=upper_bound).sub_(inputs)
    else:
        δ = torch.zeros_like(inputs)
    δ.requires_grad_(True)

    # ε is the norm budget: an integer count for l0, unbounded for the other norms until first updated.
    if norm == 0:
        ε = torch.ones(batch_size, device=device) if starting_points is None else δ.flatten(1).norm(p=0, dim=1)
    else:
        ε = torch.full((batch_size,), float('inf'), device=device)

    # Init trackers
    # worst_norm: norm of the largest perturbation that keeps inputs + δ in [0, 1]; used as initial best and ε cap.
    worst_norm = torch.maximum(inputs, 1 - inputs).flatten(1).norm(p=norm, dim=1)
    best_norm = worst_norm.clone()
    best_adv = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for i in range(steps):
        # Cosine annealing of the step size α and of the ε update factor γ.
        cosine = (1 + math.cos(math.pi * i / steps)) / 2
        α = α_final + (α_init - α_final) * cosine
        γ = γ_final + (γ_init - γ_final) * cosine

        δ_norm = δ.data.flatten(1).norm(p=norm, dim=1)
        adv_inputs = inputs + δ
        logits = model(adv_inputs)
        pred_labels = logits.argmax(dim=1)

        if i == 0:
            labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))
            logit_diff_func = partial(difference_of_logits, labels=labels, labels_infhot=labels_infhot)

        logit_diffs = logit_diff_func(logits=logits)
        loss = multiplier * logit_diffs
        δ_grad = grad(loss.sum(), δ, only_inputs=True)[0]

        # Keep the smallest-norm adversarial found so far for each sample.
        is_adv = (pred_labels == labels) if targeted else (pred_labels != labels)
        is_smaller = δ_norm < best_norm
        is_both = is_adv & is_smaller
        adv_found.logical_or_(is_adv)
        best_norm = torch.where(is_both, δ_norm, best_norm)
        best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv)

        # Shrink ε when adversarial, enlarge it otherwise (integer updates of at least 1 for l0).
        if norm == 0:
            ε = torch.where(is_adv,
                            torch.minimum(torch.minimum(ε - 1, (ε * (1 - γ)).floor_()), best_norm),
                            torch.maximum(ε + 1, (ε * (1 + γ)).floor_()))
            ε.clamp_(min=0)
        else:
            # Until a first adversarial is found, ε jumps to the linear estimate of the distance to the boundary.
            distance_to_boundary = loss.detach().abs_().div_(δ_grad.flatten(1).norm(p=dual, dim=1).clamp_(min=1e-12))
            ε = torch.where(is_adv,
                            torch.minimum(ε * (1 - γ), best_norm),
                            torch.where(adv_found, ε * (1 + γ), δ_norm + distance_to_boundary))

        # clip ε
        ε = torch.minimum(ε, worst_norm)
        # gradient ascent step with normalized gradient
        grad_l2_norms = δ_grad.flatten(1).norm(p=2, dim=1).clamp_(min=1e-12)
        δ.data.addcdiv_(δ_grad, batch_view(grad_l2_norms), value=α)
        # project in place
        projection(δ=δ.data, ε=ε)
        # clamp
        δ.data.add_(inputs).clamp_(min=0, max=1).sub_(inputs)

    return best_adv
212 |
--------------------------------------------------------------------------------
/adv_lib/attacks/perceptual_color_attacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .perceptual_color_distance_al import perc_al
--------------------------------------------------------------------------------
/adv_lib/attacks/perceptual_color_attacks/differential_color_functions.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial
2 | import numpy as np
3 | import torch
4 | from torch import Tensor
5 |
6 |
def rgb2xyz(rgb_image):
    """
    Convert a batch of RGB images to CIE XYZ, scaled by 100 (Y = 100 for a white pixel).

    Parameters
    ----------
    rgb_image : Tensor
        Batch of images of shape (N, 3, H, W) with values in [0, 1].

    Returns
    -------
    Tensor
        XYZ images of shape (N, 3, H, W).
    """
    device = rgb_image.device
    # Inverse sRGB companding. NOTE: the sRGB specification uses 0.04045 as threshold; 0.0405 is kept unchanged to
    # stay consistent with the implementation this module was adapted from.
    mask1 = (rgb_image > 0.0405).float()
    mask1_no = 1 - mask1
    temp_img = mask1 * (((rgb_image + 0.055) / 1.055) ** 2.4)
    temp_img = temp_img + mask1_no * (rgb_image / 12.92)
    temp_img = 100 * temp_img

    # Linear RGB -> XYZ matrix, built with the dtype of the linearised image so that the matmul below also works
    # for float64 inputs (float32 / float16 inputs are promoted to float32 by the float masks, as before).
    mt = torch.tensor([[0.4124, 0.3576, 0.1805],
                       [0.2126, 0.7152, 0.0722],
                       [0.0193, 0.1192, 0.9504]], device=device, dtype=temp_img.dtype)

    res = torch.matmul(mt, temp_img.permute(1, 0, 2, 3).contiguous().view(3, -1)).view(
        3, rgb_image.size(0), rgb_image.size(2), rgb_image.size(3)).permute(1, 0, 2, 3)
    return res
21 |
22 |
def xyz_lab(xyz_image):
    """
    Elementwise CIELAB companding function: t ** (1/3) above 0.008856, 7.787 t + 16/116 otherwise.

    Exact zeros are nudged by 0.0001 before the computation and mapped to 0 in the output.
    """
    is_zero = (xyz_image == 0).float()
    shifted = xyz_image + 0.0001 * is_zero
    above = (shifted > 0.008856).float()
    cube_root = above * shifted ** (1 / 3)
    linear = (1 - above) * ((7.787 * shifted) + (16 / 116))
    return (cube_root + linear) * (1 - is_zero)
33 |
34 |
def rgb2lab_diff(rgb_image):
    """
    Convert a batch of image tensors from RGB space to CIELAB space with differentiable operations.

    The reference white (xn, yn, zn) holds the CIE XYZ tristimulus values of the standard Illuminant D65 with
    normalization Y = 100.

    Parameters
    ----------
    rgb_image : Tensor
        Batch of RGB images indexed as (N, 3, H, W), values in [0, 1].

    Returns
    -------
    Tensor
        CIELAB images with the shape and dtype of ``rgb_image``; channels are L, a, b.
    """
    res = torch.zeros_like(rgb_image)
    xyz_image = rgb2xyz(rgb_image)

    xn = 95.0489
    yn = 100
    zn = 108.8840

    # f(Y / Yn) appears in all three channels: compute it once instead of three times.
    fx = xyz_lab(xyz_image[:, 0, :, :] / xn)
    fy = xyz_lab(xyz_image[:, 1, :, :] / yn)
    fz = xyz_lab(xyz_image[:, 2, :, :] / zn)

    res[:, 0, :, :] = 116 * fy - 16
    res[:, 1, :, :] = 500 * (fx - fy)
    res[:, 2, :, :] = 200 * (fy - fz)

    return res
61 |
62 |
def degrees(n):
    """Convert an angle (scalar, array or tensor) from radians to degrees."""
    scale = 180. / np.pi
    return n * scale
64 |
65 |
def radians(n):
    """Convert an angle (scalar, array or tensor) from degrees to radians."""
    scale = np.pi / 180.
    return n * scale
67 |
68 |
def hpf_diff(x, y):
    """
    Angle ``atan2(x, y)`` converted to degrees and wrapped into [0, 360).

    Where both ``x`` and ``y`` are zero, the arguments passed to ``atan2`` are zeroed.
    """
    both_zero = ((x == 0) * (y == 0)).float()
    keep = 1 - both_zero

    angle = degrees(torch.atan2(x * keep, y * keep))
    non_negative = angle * (angle >= 0).float()
    wrapped = (360 + angle) * (angle < 0).float()
    return non_negative + wrapped
78 |
79 |
def dhpf_diff(c1, c2, h1p, h2p):
    """
    Hue difference h2p - h1p (degrees) wrapped into [-180, 180]; 0 wherever ``c1 * c2 == 0``.
    """
    difference = h2p - h1p
    has_chroma = 1 - ((c1 * c2) == 0).float()

    within = difference * has_chroma * (torch.abs(difference) <= 180).float()
    wrap_down = (difference - 360) * (difference > 180).float() * has_chroma
    wrap_up = (difference + 360) * (difference < -180).float() * has_chroma
    return within + wrap_down + wrap_up
88 |
89 |
def ahpf_diff(c1, c2, h1p, h2p):
    """
    Mean hue angle (degrees) of two colors, accounting for the 360° wrap-around.

    NOTE: where ``c1 * c2 == 0`` the result is 0. Every term is multiplied by the complementary mask there, so the
    ``total * zero_chroma`` term adds nothing. The CIEDE2000 definition uses ``h1p + h2p`` in that case; the behavior
    is kept unchanged here.
    """
    zero_chroma = ((c1 * c2) == 0).float()
    has_chroma = 1 - zero_chroma
    close = (torch.abs(h2p - h1p) <= 180).float()
    far = 1 - close
    small_sum = (torch.abs(h2p + h1p) < 360).float()
    large_sum = 1 - small_sum

    direct = (h1p + h2p) * has_chroma * close
    shifted_up = (h1p + h2p + 360.) * has_chroma * far * small_sum
    shifted_down = (h1p + h2p - 360.) * has_chroma * far * large_sum
    total = direct + shifted_up + shifted_down
    return (total + total * zero_chroma) * 0.5
103 |
104 |
def ciede2000_diff(lab1, lab2):
    """
    CIEDE2000 metric to calculate the color distance map for a batch of image tensors defined in CIELAB space

    This version contains errors:
    - Typo in the formula for T: "- 39" should be "- 30"
    - 0.0001s added for numerical stability change the conditional values

    Parameters
    ----------
    lab1, lab2 : Tensor
        Batches of CIELAB images, indexed as ``lab[:, c, :, :]`` with c = 0 (L), 1 (a), 2 (b).

    Returns
    -------
    Tensor
        Per-pixel color difference map, without the channel dimension.
    """
    L1 = lab1[:, 0, :, :]
    A1 = lab1[:, 1, :, :]
    B1 = lab1[:, 2, :, :]
    L2 = lab2[:, 0, :, :]
    A2 = lab2[:, 1, :, :]
    B2 = lab2[:, 2, :, :]
    # Parametric weighting factors.
    kL = 1
    kC = 1
    kH = 1

    # Achromatic pixels (a == b == 0): b is nudged by 0.0001 to keep the sqrt/atan2 gradients finite.
    mask_value_0_input1 = ((A1 == 0) * (B1 == 0)).float()
    mask_value_0_input2 = ((A2 == 0) * (B2 == 0)).float()
    mask_value_0_input1_no = 1 - mask_value_0_input1
    mask_value_0_input2_no = 1 - mask_value_0_input2
    B1 = B1 + 0.0001 * mask_value_0_input1
    B2 = B2 + 0.0001 * mask_value_0_input2

    # Chroma C.
    C1 = torch.sqrt((A1 ** 2.) + (B1 ** 2.))
    C2 = torch.sqrt((A2 ** 2.) + (B2 ** 2.))

    # G factor from the mean chroma, modified a' = (1 + G) a and modified chroma C'.
    aC1C2 = (C1 + C2) / 2.
    G = 0.5 * (1. - torch.sqrt((aC1C2 ** 7.) / ((aC1C2 ** 7.) + (25 ** 7.))))
    a1P = (1. + G) * A1
    a2P = (1. + G) * A2
    c1P = torch.sqrt((a1P ** 2.) + (B1 ** 2.))
    c2P = torch.sqrt((a2P ** 2.) + (B2 ** 2.))

    # Hue angles h' in degrees, set to 0 for achromatic pixels.
    h1P = hpf_diff(B1, a1P)
    h2P = hpf_diff(B2, a2P)
    h1P = h1P * mask_value_0_input1_no
    h2P = h2P * mask_value_0_input2_no

    # Differences ΔL', ΔC', Δh' and ΔH' (ΔH' zeroed where either pixel is achromatic).
    dLP = L2 - L1
    dCP = c2P - c1P
    dhP = dhpf_diff(C1, C2, h1P, h2P)
    dHP = 2. * torch.sqrt(c1P * c2P) * torch.sin(radians(dhP) / 2.)
    mask_0_no = 1 - torch.max(mask_value_0_input1, mask_value_0_input2)
    dHP = dHP * mask_0_no

    # Means L̄, C̄', h̄' and the weighting terms T, Δθ (dRO), R_C, S_L, S_C, S_H, R_T.
    aL = (L1 + L2) / 2.
    aCP = (c1P + c2P) / 2.
    aHP = ahpf_diff(C1, C2, h1P, h2P)
    T = 1. - 0.17 * torch.cos(radians(aHP - 39)) + 0.24 * torch.cos(radians(2. * aHP)) + 0.32 * torch.cos(
        radians(3. * aHP + 6.)) - 0.2 * torch.cos(radians(4. * aHP - 63.))
    dRO = 30. * torch.exp(-1. * (((aHP - 275.) / 25.) ** 2.))
    rC = torch.sqrt((aCP ** 7.) / ((aCP ** 7.) + (25. ** 7.)))
    sL = 1. + ((0.015 * ((aL - 50.) ** 2.)) / torch.sqrt(20. + ((aL - 50.) ** 2.)))

    sC = 1. + 0.045 * aCP
    sH = 1. + 0.015 * aCP * T
    rT = -2. * rC * torch.sin(radians(2. * dRO))

    # Unmasked reference formula:
    # res_square=((dLP / (sL * kL)) ** 2.) + ((dCP / (sC * kC)) ** 2.) + ((dHP / (sH * kH)) ** 2.) + rT * (dCP / (sC * kC)) * (dHP / (sH * kH))

    # Chroma / hue terms are dropped where either pixel is achromatic.
    res_square = ((dLP / (sL * kL)) ** 2.) + ((dCP / (sC * kC)) ** 2.) * mask_0_no + (
            (dHP / (sH * kH)) ** 2.) * mask_0_no + rT * (dCP / (sC * kC)) * (dHP / (sH * kH)) * mask_0_no
    # mask_0_no is reused here for a different purpose: non-positive squared distances are nudged before the sqrt
    # (finite gradient) and then reported as 0.
    mask_0 = (res_square <= 0).float()
    mask_0_no = 1 - mask_0
    res_square = res_square + 0.0001 * mask_0
    res = torch.sqrt(res_square)
    res = res * mask_0_no

    return res
177 |
178 |
def ciede2000_loss(input: Tensor, target: Tensor) -> Tensor:
    """Per-sample l2 norm of the CIEDE2000 color-difference map between two batches of RGB images."""
    lab_input = rgb2lab_diff(input)
    lab_target = rgb2lab_diff(target)
    difference_map = ciede2000_diff(lab_input, lab_target)
    return difference_map.flatten(1).norm(p=2, dim=1)
182 |
--------------------------------------------------------------------------------
/adv_lib/attacks/perceptual_color_attacks/perceptual_color_distance_al.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial
2 |
3 | from math import pi, cos
4 |
5 | import torch
6 | from torch import nn, Tensor
7 | from torch.autograd import grad
8 |
9 | from .differential_color_functions import rgb2lab_diff, ciede2000_diff
10 |
11 |
def quantization(x):
    """Quantize continuous image tensors in [0, 1] to the 256 levels of an 8-bit encoding."""
    levels = 255
    return (x * levels).round() / levels
16 |
17 |
def perc_al(model: nn.Module,
            images: Tensor,
            labels: Tensor,
            num_classes: int,
            targeted: bool = False,
            max_iterations: int = 1000,
            alpha_l_init: float = 1.,
            alpha_c_init: float = 0.5,
            confidence: float = 0, **kwargs) -> Tensor:
    """
    PerC_AL: Alternating Loss of Classification and Color Differences to achieve imperceptible perturbations with few
    iterations. Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial.

    Parameters
    ----------
    model : nn.Module
        Model to fool.
    images : Tensor
        Batch of image examples in the range of [0,1].
    labels : Tensor
        Original labels if untargeted, else labels of targets.
    num_classes : int
        Number of classes of the model (used to build the label mask when ``confidence != 0``).
    targeted : bool, optional
        Whether to perform a targeted adversary or not.
    max_iterations : int
        Number of iterations for the optimization.
    alpha_l_init: float
        step size for updating perturbations with respect to classification loss
    alpha_c_init: float
        step size for updating perturbations with respect to perceptual color differences. for relatively easy
        untargeted case, alpha_c_init is adjusted to a smaller value (e.g., 0.1 is used in the paper)
    confidence : float, optional
        Confidence of the adversary for Carlini's loss, in term of distance between logits: when non-zero, a sample
        only counts as adversarial once the largest other logit exceeds the true-class logit by at least
        ``confidence``. Only supported in the untargeted case.

    Returns
    -------
    Tensor
        Batch of image samples modified to be adversarial (quantized to 8 bits); samples for which no adversarial
        was found are returned unchanged.

    Raises
    ------
    ValueError
        If images are outside [0, 1], or if a non-zero confidence is requested for a targeted attack.
    """

    if images.min() < 0 or images.max() > 1: raise ValueError('Input values should be in the [0, 1] range.')
    if targeted and confidence != 0:
        raise ValueError('Only support setting confidence in untargeted case!')
    device = images.device
    # Targeted + confidence is rejected above, so a non-zero confidence implies an untargeted attack.
    use_confidence = confidence != 0

    alpha_l_min = alpha_l_init / 100
    alpha_c_min = alpha_c_init / 10
    multiplier = -1 if targeted else 1

    X_adv_round_best = images.clone()
    inputs_LAB = rgb2lab_diff(images)
    batch_size = images.shape[0]
    delta = torch.zeros_like(images, requires_grad=True)
    color_l2_delta_bound_best = torch.full((batch_size,), 100000, dtype=torch.float, device=device)

    if use_confidence:
        # +inf on the true class so that it is excluded from the max over the other logits.
        labels_infhot = torch.zeros(labels.size(0), num_classes, device=device).scatter_(
            1, labels.unsqueeze(1), float('inf'))

    def is_adversarial(x: Tensor) -> Tensor:
        """Per-sample success criterion, shared by the initial check and the per-iteration check."""
        logits = model(x)
        if use_confidence:
            real = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
            other = (logits - labels_infhot).amax(dim=1)
            return (real - other) <= -confidence
        preds = logits.argmax(1)
        return (preds == labels) if targeted else (preds != labels)

    # check if some images are already adversarial
    mask_isadv = is_adversarial(images)
    color_l2_delta_bound_best[mask_isadv] = 0
    X_adv_round_best[mask_isadv] = images[mask_isadv]

    for i in range(max_iterations):
        # cosine annealing for alpha_l_init and alpha_c_init
        alpha_c = alpha_c_min + 0.5 * (alpha_c_init - alpha_c_min) * (1 + cos(i / max_iterations * pi))
        alpha_l = alpha_l_min + 0.5 * (alpha_l_init - alpha_l_min) * (1 + cos(i / max_iterations * pi))

        # Not yet adversarial samples: normalized step on the classification loss.
        loss = multiplier * nn.CrossEntropyLoss(reduction='sum')(model(images + delta), labels)
        grad_a = grad(loss, delta, only_inputs=True)[0]
        delta.data[~mask_isadv] = delta.data[~mask_isadv] + alpha_l * (grad_a.permute(1, 2, 3, 0) / torch.norm(
            grad_a.flatten(1), dim=1)).permute(3, 0, 1, 2)[~mask_isadv]

        # Adversarial samples: normalized step decreasing the perceptual color distance.
        d_map = ciede2000_diff(inputs_LAB, rgb2lab_diff(images + delta)).unsqueeze(1)
        color_dis = torch.norm(d_map.flatten(1), dim=1)
        grad_color = grad(color_dis.sum(), delta, only_inputs=True)[0]
        delta.data[mask_isadv] = delta.data[mask_isadv] - alpha_c * (grad_color.permute(1, 2, 3, 0) / torch.norm(
            grad_color.flatten(1), dim=1)).permute(3, 0, 1, 2)[mask_isadv]

        delta.data = (images + delta.data).clamp_(min=0, max=1) - images
        X_adv_round = quantization(images + delta.data)

        # Success is evaluated on the quantized images; keep the best (smallest color distance) adversarial.
        mask_isadv = is_adversarial(X_adv_round)
        mask_best = (color_dis.data < color_l2_delta_bound_best)
        mask = mask_best * mask_isadv
        color_l2_delta_bound_best[mask] = color_dis.data[mask]
        X_adv_round_best[mask] = X_adv_round[mask]

    return X_adv_round_best
129 |
--------------------------------------------------------------------------------
/adv_lib/attacks/projected_gradient_descent.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from functools import partial
3 | from typing import Optional, Tuple, Union
4 |
5 | import torch
6 | from torch import Tensor, nn
7 | from torch.autograd import grad
8 | from torch.nn import functional as F
9 |
10 | from adv_lib.utils.losses import difference_of_logits, difference_of_logits_ratio
11 | from adv_lib.utils.projections import clamp_
12 | from adv_lib.utils.visdom_logger import VisdomLogger
13 |
14 |
def pgd_linf(model: nn.Module,
             inputs: Tensor,
             labels: Tensor,
             ε: Union[float, Tensor],
             targeted: bool = False,
             steps: int = 40,
             random_init: bool = True,
             restarts: int = 1,
             loss_function: str = 'ce',
             relative_step_size: float = 0.01 / 0.3,
             absolute_step_size: Optional[float] = None,
             callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    ℓ∞ PGD with restarts. Each restart only attacks the samples that have not been fooled yet.

    ``ε`` is either a single radius or a per-sample tensor of radii. Returns the adversarial examples found; samples
    that were never fooled are returned unchanged.
    """
    batch_size = len(inputs)
    adv_inputs = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=inputs.device)

    if isinstance(ε, numbers.Real):
        ε = torch.full_like(adv_found, ε, dtype=inputs.dtype)

    single_run = partial(_pgd_linf, model=model, targeted=targeted, steps=steps, random_init=random_init,
                         loss_function=loss_function, relative_step_size=relative_step_size,
                         absolute_step_size=absolute_step_size)

    for restart in range(restarts):
        remaining = ~adv_found
        run_found, run_adv = single_run(inputs=inputs[remaining], labels=labels[remaining], ε=ε[remaining])
        adv_inputs[remaining] = run_adv
        adv_found[remaining] = run_found

        if callback:
            callback.line('success', restart + 1, adv_found.float().mean())

        if adv_found.all():
            break

    return adv_inputs
54 |
55 |
def _pgd_linf(model: nn.Module,
              inputs: Tensor,
              labels: Tensor,
              ε: Tensor,
              targeted: bool = False,
              steps: int = 40,
              random_init: bool = True,
              loss_function: str = 'ce',
              relative_step_size: float = 0.01 / 0.3,
              absolute_step_size: Optional[float] = None) -> Tuple[Tensor, Tensor]:
    """
    Single run of ℓ∞ PGD with signed-gradient steps and per-sample radius ε.

    Returns
    -------
    adv_found : Tensor
        Boolean tensor marking the samples for which an adversarial iterate was found.
    best_adv : Tensor
        Last adversarial iterate found for each sample (the clean input when none was found).
    """
    # loss name -> (loss callable, sign applied to the loss)
    available_losses = {
        'ce': (partial(F.cross_entropy, reduction='none'), 1),
        'dl': (difference_of_logits, -1),
        'dlr': (partial(difference_of_logits_ratio, targeted=targeted), -1),
    }
    loss_key = loss_function.lower()
    loss_func, multiplier = available_losses[loss_key]

    n = len(inputs)
    expand = lambda tensor: tensor.view(n, *[1] * (inputs.ndim - 1))
    radius = expand(ε)
    # Box for δ: inside the ℓ∞ ball and keeping inputs + δ in [0, 1].
    lower = torch.maximum(-inputs, -radius)
    upper = torch.minimum(1 - inputs, radius)

    if absolute_step_size is None:
        step_size = ε * relative_step_size
    else:
        step_size = torch.full_like(ε, absolute_step_size)
    if targeted:
        step_size = step_size.neg()

    δ = torch.zeros_like(inputs, requires_grad=True)
    best_adv = inputs.clone()
    adv_found = torch.zeros(n, dtype=torch.bool, device=inputs.device)

    if random_init:
        δ.data.uniform_(-1, 1).mul_(radius)
        clamp_(δ, lower=lower, upper=upper)

    needs_infhot = loss_key in ['dl', 'dlr']
    for step in range(steps):
        adv_inputs = inputs + δ
        logits = model(adv_inputs)

        if step == 0 and needs_infhot:
            labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))
            loss_func = partial(loss_func, labels_infhot=labels_infhot)

        loss = multiplier * loss_func(logits, labels)
        δ_grad = grad(loss.sum(), δ, only_inputs=True)[0].sign_()

        predictions = logits.argmax(1)
        is_adv = (predictions == labels) if targeted else (predictions != labels)
        best_adv = torch.where(expand(is_adv), adv_inputs.detach(), best_adv)
        adv_found.logical_or_(is_adv)

        δ.data.addcmul_(expand(step_size), δ_grad)
        clamp_(δ, lower=lower, upper=upper)

    return adv_found, best_adv
110 |
--------------------------------------------------------------------------------
/adv_lib/attacks/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from .alma_prox import alma_prox
2 | from .asma import asma
3 | from .dense_adversary import dag
4 | from .primal_dual_gradient_descent import pdgd, pdpgd
5 |
--------------------------------------------------------------------------------
/adv_lib/attacks/segmentation/asma.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from adv_lib.utils.visdom_logger import VisdomLogger
5 | from torch import Tensor, nn
6 | from torch.autograd import grad
7 |
8 |
def iou_masks(mask1: Tensor, mask2: Tensor, n: int):
    """
    Mean intersection-over-union between two label maps over ``n`` classes.

    Only pixels where ``mask1`` is in [0, n) are counted. Classes absent from both maps contribute an IoU of 0.
    """
    valid = (mask1 >= 0) & (mask1 < n)
    pair_index = mask1[valid].to(torch.int64) * n + mask2[valid]
    confusion = torch.bincount(pair_index, minlength=n * n).view(n, n)
    intersection = confusion.diagonal()
    union = confusion.sum(dim=1) + confusion.sum(dim=0) - intersection
    return (intersection / (union + 1e-6)).mean().item()
15 |
16 |
def asma(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         masks: Tensor = None,
         targeted: bool = False,
         adv_threshold: float = 0.99,
         num_steps: int = 1000,
         τ: float = 1e-7,
         β: float = 1e-6,
         callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    ASMA attack from https://arxiv.org/abs/1907.13124

    Parameters
    ----------
    model : nn.Module
        Segmentation model; its output is used as per-pixel logits with the class dimension at dim 1.
    inputs : Tensor
        Batch of inputs; ``inputs + δ`` is kept inside [0, 1].
    labels : Tensor
        Per-pixel labels: original labels if untargeted, else target labels.
    masks : Tensor, optional
        Boolean mask of the pixels taken into account. Defaults to ``labels < num_classes``.
    targeted : bool
        Whether a pixel counts as adversarial when predicted as ``labels`` (targeted) or not (untargeted).
    adv_threshold : float
        Fraction of masked pixels that must be adversarial for a sample to be considered adversarial.
    num_steps : int
        Number of optimization steps.
    τ : float
        Base step size (used alone at the first step).
    β : float
        Step-size increment proportional to the IoU between labels and predictions.
    callback : VisdomLogger, optional
        Logger used to plot the attack progress.

    Returns
    -------
    best_adv : Tensor
        Smallest-ℓ2 adversarial inputs found, or the inputs reaching the best adversarial-pixel ratio for samples that
        never reached ``adv_threshold``.
    """
    attack_name = 'ASMA'
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    # Refresh the plots ~20 times per run; at every step when num_steps < 20 (avoids a modulo by zero).
    update_every = max(num_steps // 20, 1)

    # Setup variables
    δ = torch.zeros_like(inputs, requires_grad=True)
    lower, upper = -inputs, 1 - inputs
    pert_mul = τ

    # Init trackers
    best_dist = torch.full((batch_size,), float('inf'), device=device)
    best_adv_percent = torch.zeros_like(best_dist)
    adv_found = torch.zeros_like(best_dist, dtype=torch.bool)
    best_adv = inputs.clone()

    for i in range(num_steps):

        adv_inputs = inputs + δ
        logits = model(adv_inputs)
        l2_squared = δ.flatten(1).square().sum(dim=1)

        if i == 0:
            # initialize variables based on model's output
            num_classes = logits.size(1)
            if masks is None:
                masks = labels < num_classes
            masks_sum = masks.flatten(1).sum(dim=1)
            labels_ = labels * masks

        # track progress
        pred = logits.argmax(dim=1)
        pixel_is_adv = (pred == labels) if targeted else (pred != labels)
        adv_percent = (pixel_is_adv & masks).flatten(1).sum(dim=1) / masks_sum
        is_adv = adv_percent >= adv_threshold
        is_smaller = l2_squared <= best_dist
        improves_constraints = adv_percent >= best_adv_percent.clamp_max(adv_threshold)
        is_better_adv = (is_smaller & is_adv) | (~adv_found & improves_constraints)
        adv_found.logical_or_(is_adv)
        best_dist = torch.where(is_better_adv, l2_squared.detach(), best_dist)
        best_adv_percent = torch.where(is_better_adv, adv_percent, best_adv_percent)
        best_adv = torch.where(batch_view(is_better_adv), adv_inputs.detach(), best_adv)

        # Adaptive step size: grows with the IoU between labels and current predictions.
        iou = iou_masks(labels, pred, n=num_classes)
        if i:
            pert_mul = β * iou + τ

        logit_loss = logits.gather(1, labels_.unsqueeze(1)).squeeze(1).mul(masks & (pred != labels_)).sum()
        # logit_loss is already summed over the batch: subtract the summed ℓ2 penalty instead of broadcasting the
        # scalar against every sample and summing again, which scaled the logit gradient by batch_size.
        loss = logit_loss - l2_squared.sum()
        δ_grad = grad(loss, δ, only_inputs=True)[0]

        δ.data.add_(δ_grad, alpha=pert_mul).clamp_(min=lower, max=upper)

        if callback:
            callback.accumulate_line('logit_loss', i, logit_loss.mean(), title=attack_name + ' - Logit loss')
            callback.accumulate_line(['adv%', 'best_adv%'], i, [adv_percent.mean(), best_adv_percent.mean()],
                                     title=attack_name + ' - APSR')
            callback.accumulate_line(['ℓ2', 'best ℓ2'], i, [l2_squared.detach().sqrt().mean(), best_dist.sqrt().mean()],
                                     title=attack_name + ' - L2 Norms')
            callback.accumulate_line('lr', i, pert_mul, title=attack_name + ' - Step size')
            callback.accumulate_line('IoU', i, iou, title=attack_name + ' - IoU')
            if (i + 1) % update_every == 0 or (i + 1) == num_steps:
                callback.update_lines()

    return best_adv
93 |
--------------------------------------------------------------------------------
/adv_lib/attacks/segmentation/dense_adversary.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from adv_lib.utils.losses import difference_of_logits
5 | from adv_lib.utils.visdom_logger import VisdomLogger
6 | from torch import Tensor, nn
7 | from torch.autograd import grad
8 |
9 |
def dag(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        masks: Optional[Tensor] = None,
        targeted: bool = False,
        adv_threshold: float = 0.99,
        max_iter: int = 200,
        γ: float = 0.5,
        p: float = float('inf'),
        callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    DAG (Dense Adversary Generation) attack from https://arxiv.org/abs/1703.08603.

    Parameters
    ----------
    model : nn.Module
        Model to attack. Its output is indexed as (batch, classes, ...) per-pixel logits.
    inputs : Tensor
        Inputs to attack. Perturbed inputs are clamped to [0, 1].
    labels : Tensor
        Per-pixel labels corresponding to the inputs if untargeted, else per-pixel target labels.
    masks : Tensor, optional
        Boolean masks of the pixels taken into account. If None, pixels whose label is >= the number of classes
        are ignored.
    targeted : bool
        Whether to perform a targeted attack or not.
    adv_threshold : float
        Fraction of masked pixels that must be adversarial for an input to be considered adversarial.
    max_iter : int
        Maximum number of iterations.
    γ : float
        Step size applied to the p-norm-normalized gradient.
    p : float
        Norm used to normalize the gradient (also the norm reported to the callback).
    callback : Optional[VisdomLogger]
        If provided, used to plot the difference of logits, perturbation norm and adversarial percentage.

    Returns
    -------
    best_adv : Tensor
        Perturbed inputs for the samples that reached `adv_threshold`, original inputs otherwise.

    """
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))
    multiplier = -1 if targeted else 1
    # refresh plots ~20 times; at least every step to avoid a modulo by zero when max_iter < 20
    log_interval = max(max_iter // 20, 1)

    # Setup variables
    r = torch.zeros_like(inputs)

    # Init trackers
    adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)
    best_adv = inputs.clone()

    for i in range(max_iter):

        # only optimize the inputs for which no adversarial example has been found yet
        active_inputs = ~adv_found
        inputs_ = inputs[active_inputs]
        r_ = r[active_inputs]
        r_.requires_grad_(True)

        adv_inputs_ = (inputs_ + r_).clamp(0, 1)
        logits = model(adv_inputs_)

        if i == 0:
            # initialize variables based on the model's output
            num_classes = logits.size(1)
            if masks is None:
                masks = labels < num_classes
            masks_sum = masks.flatten(1).sum(dim=1)
            masked_labels = labels * masks
            labels_infhot = torch.zeros_like(logits.detach()).scatter(1, masked_labels.unsqueeze(1), float('inf'))

        dl = multiplier * difference_of_logits(logits, labels=masked_labels[active_inputs],
                                               labels_infhot=labels_infhot[active_inputs])
        pixel_is_adv = dl < 0

        active_masks = masks[active_inputs]
        adv_percent = (pixel_is_adv & active_masks).flatten(1).sum(dim=1) / masks_sum[active_inputs]
        is_adv = adv_percent >= adv_threshold
        adv_found[active_inputs] = is_adv
        best_adv[active_inputs] = torch.where(batch_view(is_adv), adv_inputs_.detach(), best_adv[active_inputs])

        if callback:
            callback.accumulate_line('dl', i, dl[active_masks].mean(), title=f'DAG (p={p}, γ={γ}) - DL')
            callback.accumulate_line(f'L{p}', i, r.flatten(1).norm(p=p, dim=1).mean(), title=f'DAG (p={p}, γ={γ}) - Norm')
            callback.accumulate_line('adv%', i, adv_percent.mean(), title=f'DAG (p={p}, γ={γ}) - Adv percent')

            if (i + 1) % log_interval == 0 or (i + 1) == max_iter:
                callback.update_lines()

        if is_adv.all():
            break

        # only pixels that are not yet adversarial (and inside the mask) contribute to the loss
        loss = (dl[~is_adv] * active_masks[~is_adv]).relu()
        r_grad = grad(loss.sum(), r_, only_inputs=True)[0]
        r_grad.div_(batch_view(r_grad.flatten(1).norm(p=p, dim=1).clamp_min_(1e-8)))
        r_.data.sub_(r_grad, alpha=γ)

        # write back a detached copy: storing the grad-requiring leaf into r would attach an autograd history to r
        # that grows with every iteration
        r[active_inputs] = r_.detach()

    if callback:
        callback.update_lines()

    return best_adv
84 |
--------------------------------------------------------------------------------
/adv_lib/attacks/sigma_zero.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/Cinofix/sigma-zero-adversarial-attack
2 | import math
3 | import warnings
4 |
5 | import torch
6 | from torch import Tensor, nn
7 | from torch.autograd import grad
8 |
9 | from adv_lib.utils.losses import difference_of_logits
10 |
11 |
def sigma_zero(model: nn.Module,
               inputs: Tensor,
               labels: Tensor,
               num_steps: int = 1000,
               η_0: float = 1.0,
               σ: float = 0.001,
               τ_0: float = 0.3,
               τ_factor: float = 0.01,
               grad_norm: float = float('inf'),
               targeted: bool = False) -> Tensor:
    """
    σ-zero attack from https://arxiv.org/abs/2402.01879.

    The perturbation is optimized with Adam-style updates on a normalized gradient of the clamped difference of
    logits plus a smooth ℓ0 approximation. The step size follows a cosine schedule, and a per-sample threshold
    zeroes out small perturbation components after every step.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    num_steps : int
        Number of optimization steps. Corresponds to the number of forward and backward propagations.
    η_0 : float
        Initial step size.
    σ : float
        \ell_0 approximation parameter: smaller values produce sharper approximations while larger values produce a
        smoother approximation.
    τ_0 : float
        Initial sparsity threshold.
    τ_factor : float
        Threshold adjustment factor w.r.t. step size η.
    grad_norm : float
        Norm to use for gradient normalization.
    targeted : bool
        Attack is untargeted only: will raise a warning and return inputs if targeted is True.

    Returns
    -------
    best_adv : Tensor
        Perturbed inputs (inputs + perturbation) that are adversarial and have smallest distance with the original
        inputs.

    """
    if targeted:
        warnings.warn('σ-zero attack is untargeted only. Returning inputs.')
        return inputs

    n_inputs = len(inputs)
    n_features = inputs[0].numel()

    def per_sample(t: Tensor) -> Tensor:
        # reshape a (batch,) tensor so that it broadcasts against the inputs
        return t.view(n_inputs, *[1] * (inputs.ndim - 1))

    delta = torch.zeros_like(inputs, requires_grad=True)
    # Adam first / second moment estimates
    m = torch.zeros_like(inputs)
    v = torch.zeros_like(inputs)
    beta1, beta2 = 0.9, 0.999

    best_l0 = inputs.new_full((n_inputs,), n_features)
    best_adv = inputs.clone()
    threshold = torch.full_like(best_l0, τ_0)

    lr = η_0
    for step in range(num_steps):
        x_adv = inputs + delta

        # objective terms: clamped difference of logits and normalized smooth ℓ0 approximation
        logits = model(x_adv)
        margin = difference_of_logits(logits, labels).clamp_(min=0)
        delta_sq = delta.square()
        smooth_l0 = (delta_sq / (delta_sq + σ)).flatten(1).mean(dim=1)

        # remember the sparsest adversarial example seen so far (the attack is untargeted here)
        success = logits.argmax(dim=1) != labels
        l0 = delta.data.flatten(1).norm(p=0, dim=1)
        improved = success & (l0 < best_l0)
        best_l0 = torch.where(improved, l0, best_l0)
        best_adv = torch.where(per_sample(improved), x_adv.detach(), best_adv)

        # gradient, normalized per sample with the requested norm
        g = grad((margin + smooth_l0).sum(), inputs=delta, only_inputs=True)[0]
        g_norm = g.flatten(1).norm(p=grad_norm, dim=1).clamp_(min=1e-12)
        g.div_(per_sample(g_norm))

        # Adam moments with bias correction
        t = step + 1
        m.lerp_(g, weight=1 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        bias_correction1 = 1 - beta1 ** t
        bias_correction2 = 1 - beta2 ** t
        denom = v.sqrt().div_(bias_correction2 ** 0.5).add_(1e-8)

        # descent step, then projection so that inputs + delta stays in [0, 1]
        delta.data.addcdiv_(m, denom, value=-lr / bias_correction1)
        delta.data.add_(inputs).clamp_(min=0, max=1).sub_(inputs)

        # cosine-annealed step size, decaying from η_0 towards 0.1 * η_0
        lr = 0.1 * η_0 + 0.9 * η_0 * (1 + math.cos(math.pi * step / num_steps)) / 2
        # raise the threshold for adversarial samples, lower it otherwise
        shift = τ_factor * lr
        threshold.add_(torch.where(success, shift, -shift)).clamp_(min=0, max=1)

        # drop components smaller than the per-sample threshold
        delta.data[delta.data.abs() < per_sample(threshold)] = 0

    return best_adv
120 |
--------------------------------------------------------------------------------
/adv_lib/attacks/stochastic_sparse_attacks.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import Tensor, nn
5 | from torch.autograd import grad
6 |
7 |
def fga(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        targeted: bool,
        increasing: bool = True,
        max_iter: Optional[int] = None,
        n_samples: int = 10,
        large_memory: bool = False) -> Tensor:
    """Folded Gaussian Attack (FGA) attack from https://arxiv.org/abs/2011.12423.

    At each iteration, the not-yet-modified feature with the most promising gradient-weighted range is selected.
    `n_samples` folded Gaussian perturbations are drawn for it, and the best candidate is kept.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    increasing : bool
        Whether to add positive or negative perturbations.
    max_iter : int
        Maximum number of iterations for the attack. If None is provided, the attack will run as long as adversarial
        examples are not found and non-modified pixels are left.
    n_samples : int
        Number of random samples to draw in each iteration.
    large_memory : bool
        If True, performs forward propagations on all randomly perturbed inputs in one batch. This is slightly faster
        for small models but also uses `n_samples` times more memory. This should only be used when working on small
        models. For larger models, the speed gain is negligible, so this option should be left to False.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    batch_size, *input_shape = inputs.shape
    device = inputs.device
    sign = 1 if targeted else -1

    def to_batch(t: Tensor) -> Tensor:
        # (batch,) -> broadcastable against inputs
        return t.view(-1, *[1] * (inputs.ndim - 1))

    def to_inputs(t: Tensor) -> Tensor:
        # any tensor with the right number of elements -> (-1, *input_shape)
        return t.view(-1, *input_shape)

    def predict_probs(t: Tensor) -> Tensor:
        return model(t).softmax(dim=1)

    current = inputs.clone()
    best_adv = inputs.clone()
    done = torch.zeros(batch_size, device=device, dtype=torch.bool)

    n_iters = inputs[0].numel() if max_iter is None else max_iter
    for _ in range(n_iters):
        untouched = current == inputs  # features still equal to the original input
        exhausted = ~untouched.flatten(1).any(1)
        finished = exhausted | done
        if finished.all():
            break
        to_attack = ~finished

        x = current[to_attack]
        y = labels[to_attack]
        untouched_ = untouched[to_attack]
        n = len(x)
        x.requires_grad_(True)
        # +inf on already-modified features so that they are never selected again
        forbidden = torch.zeros_like(untouched_, dtype=torch.float).masked_fill_(~untouched_, float('inf'))

        label_probs = predict_probs(x).gather(1, y.unsqueeze(1)).squeeze(1)
        g = grad(sign * label_probs.sum(), inputs=x, only_inputs=True)[0]
        x.detach_()

        # feature whose full-range move (up to 1 or down to 0) best improves the objective
        if increasing:
            idx = g.mul_(1 - x).sub_(forbidden).flatten(1).argmax(dim=1, keepdim=True)
        else:
            idx = g.mul_(x).add_(forbidden).flatten(1).argmin(dim=1, keepdim=True)

        # folded Gaussian scale: distance to 1 when increasing, -value when decreasing
        θ = x.flatten(1).gather(1, idx).neg_()
        if increasing:
            θ.add_(1)
        noise = torch.randn(n, n_samples, device=device).abs_().mul_(θ)

        candidates = x.flatten(1).unsqueeze(1).repeat(1, n_samples, 1)
        candidates.scatter_add_(2, idx.repeat_interleave(n_samples, dim=1).unsqueeze(2), noise.unsqueeze(2))
        candidates.clamp_(min=0, max=1)

        flat_candidates = to_inputs(candidates)
        if large_memory:
            probs = predict_probs(flat_candidates)
        else:
            probs = torch.cat([predict_probs(chunk) for chunk in torch.chunk(flat_candidates, chunks=n_samples)],
                              dim=0)
        probs = probs.view(n, n_samples, -1)
        preds = probs.argmax(dim=2)

        cand_label_probs = probs.gather(2, y.view(-1, 1, 1).expand(-1, n_samples, 1)).squeeze(2)
        # adversarial candidates get a +/-1 bonus so they always win, ties broken by label probability
        if targeted:
            success = preds == y.unsqueeze(1)
            best_idx = (cand_label_probs + success.float()).argmax(dim=1)
        else:
            success = preds != y.unsqueeze(1)
            best_idx = (cand_label_probs - success.float()).argmin(dim=1)

        current[to_attack] = to_inputs(candidates[range(n), best_idx])
        chosen_preds = preds.gather(1, best_idx.unsqueeze(1)).squeeze(1)
        is_adv = (chosen_preds == y) if targeted else (chosen_preds != y)
        done[to_attack] = is_adv
        best_adv[to_attack] = torch.where(to_batch(is_adv), current[to_attack], best_adv[to_attack])

    return best_adv
124 |
125 |
def vfga(model: nn.Module,
         inputs: Tensor,
         labels: Tensor,
         targeted: bool,
         max_iter: Optional[int] = None,
         n_samples: int = 10,
         large_memory: bool = False) -> Tensor:
    """Voting Folded Gaussian Attack (VFGA) attack from https://arxiv.org/abs/2011.12423.

    At each iteration, one candidate feature to increase and one to decrease are selected among the not-yet-modified
    features. `n_samples` folded Gaussian perturbations are drawn for each, and the best of the `2 * n_samples`
    candidates is kept.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    max_iter : int
        Maximum number of iterations for the attack. If None is provided, the attack will run as long as adversarial
        examples are not found and non-modified pixels are left.
    n_samples : int
        Number of random samples to draw in each iteration.
    large_memory : bool
        If True, performs forward propagations on all randomly perturbed inputs in one batch. This is slightly faster
        for small models but also uses `n_samples` times more memory. This should only be used when working on small
        models. For larger models, the speed gain is negligible, so this option should be left to False.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    batch_size, *input_shape = inputs.shape
    device = inputs.device
    sign = 1 if targeted else -1
    n_candidates = 2 * n_samples

    def to_batch(t: Tensor) -> Tensor:
        # (batch,) -> broadcastable against inputs
        return t.view(-1, *[1] * (inputs.ndim - 1))

    def to_inputs(t: Tensor) -> Tensor:
        # any tensor with the right number of elements -> (-1, *input_shape)
        return t.view(-1, *input_shape)

    def predict_probs(t: Tensor) -> Tensor:
        return model(t).softmax(dim=1)

    current = inputs.clone()
    best_adv = inputs.clone()
    done = torch.zeros(batch_size, device=device, dtype=torch.bool)

    n_iters = inputs[0].numel() if max_iter is None else max_iter
    for _ in range(n_iters):
        untouched = current == inputs  # features still equal to the original input
        exhausted = ~untouched.flatten(1).any(1)
        finished = exhausted | done
        if finished.all():
            break
        to_attack = ~finished

        x = current[to_attack]
        y = labels[to_attack]
        untouched_ = untouched[to_attack]
        n = len(x)
        x.requires_grad_(True)
        # +inf on already-modified features so that they are never selected again
        forbidden = torch.zeros_like(untouched_, dtype=torch.float).masked_fill_(~untouched_, float('inf'))

        label_probs = predict_probs(x).gather(1, y.unsqueeze(1)).squeeze(1)
        g = grad(sign * label_probs.sum(), inputs=x, only_inputs=True)[0]
        x.detach_()

        # the "up" score is computed out-of-place; the "down" score then reuses g in place
        idx_up = (1 - x).mul_(g).sub_(forbidden).flatten(1).argmax(dim=1, keepdim=True)
        idx_down = g.mul_(x).add_(forbidden).flatten(1).argmin(dim=1, keepdim=True)

        # folded Gaussian scales: distance to 1 for the "up" feature, value of the "down" feature
        flat_x = x.flatten(1)
        θ_up = 1 - flat_x.gather(1, idx_up)
        θ_down = flat_x.gather(1, idx_down)
        noise_up = torch.randn(n, n_samples, device=device).abs_().mul_(θ_up)
        noise_down = torch.randn(n, n_samples, device=device).abs_().neg_().mul_(θ_down)

        # candidates: first n_samples perturb the "up" feature, last n_samples the "down" feature
        candidates = flat_x.unsqueeze(1).repeat(1, n_candidates, 1)
        indices = torch.cat([idx_up, idx_down], dim=1).repeat_interleave(n_samples, dim=1)
        noise = torch.cat([noise_up, noise_down], dim=1)
        candidates.scatter_add_(2, indices.unsqueeze(2), noise.unsqueeze(2))
        candidates.clamp_(min=0, max=1)

        flat_candidates = to_inputs(candidates)
        if large_memory:
            probs = predict_probs(flat_candidates)
        else:
            probs = torch.cat([predict_probs(chunk) for chunk in torch.chunk(flat_candidates, chunks=n_samples)],
                              dim=0)
        probs = probs.view(n, n_candidates, -1)
        preds = probs.argmax(dim=2)

        cand_label_probs = probs.gather(2, y.view(-1, 1, 1).expand(-1, n_candidates, 1)).squeeze(2)
        # adversarial candidates get a +/-1 bonus so they always win, ties broken by label probability
        if targeted:
            success = preds == y.unsqueeze(1)
            best_idx = (cand_label_probs + success.float()).argmax(dim=1)
        else:
            success = preds != y.unsqueeze(1)
            best_idx = (cand_label_probs - success.float()).argmin(dim=1)

        current[to_attack] = to_inputs(candidates[range(n), best_idx])
        chosen_preds = preds.gather(1, best_idx.unsqueeze(1)).squeeze(1)
        is_adv = (chosen_preds == y) if targeted else (chosen_preds != y)
        done[to_attack] = is_adv
        best_adv[to_attack] = torch.where(to_batch(is_adv), current[to_attack], best_adv[to_attack])

    return best_adv
238 |
--------------------------------------------------------------------------------
/adv_lib/attacks/structured_adversarial_attack.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/KaidiXu/StrAttack
2 | import math
3 | from itertools import product
4 | from typing import Optional
5 |
6 | import numpy as np
7 | import torch
8 | from torch import Tensor, nn
9 | from torch.autograd import grad
10 | from torch.nn import functional as F
11 |
12 | from adv_lib.utils.losses import difference_of_logits
13 | from adv_lib.utils.visdom_logger import VisdomLogger
14 |
15 |
def str_attack(model: nn.Module,
               inputs: Tensor,
               labels: Tensor,
               targeted: bool = False,
               confidence: float = 0,
               initial_const: float = 1,
               binary_search_steps: int = 6,
               max_iterations: int = 2000,
               ρ: float = 1,
               α: float = 5,
               τ: float = 2,
               γ: float = 1,
               group_size: int = 2,
               stride: int = 2,
               retrain: bool = True,
               σ: float = 3,
               fix_y_step: bool = False,
               callback: Optional[VisdomLogger] = None) -> Tensor:
    """
    StrAttack from https://arxiv.org/abs/1808.01664.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1], with shape (batch, C, H, W).
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    confidence : float
        Confidence of adversarial examples: higher produces examples that are farther away, but more strongly classified
        as adversarial.
    initial_const : float
        The initial tradeoff-constant to use to tune the relative importance of distance and confidence. If
        binary_search_steps is large, the initial constant is not important.
    binary_search_steps : int
        The number of times we perform binary search to find the optimal tradeoff-constant between distance and
        confidence.
    max_iterations : int
        The maximum number of iterations. Larger values are more accurate; setting too small will require a large
        learning rate and will produce poor results.
    ρ : float
        Penalty parameter adjusting the trade-off between convergence speed and value. Larger ρ leads to faster
        convergence but larger perturbations.
    α : float
        Initial learning rate (η_1 in section F of the paper).
    τ : float
        Weight of the group sparsity penalty.
    γ : float
        Weight of the l2-norm penalty.
    group_size : int
        Size of the groups for the sparsity penalty.
    stride : int
        Stride of the groups for the sparsity penalty. If stride < group_size, then the groups are overlapping.
    retrain : bool
        If True, refines the best perturbation found while keeping its smallest components fixed to 0.
    σ : float
        Percentile (in [0, 100]) used in the refinement procedure: perturbation components whose magnitude is at or
        below the σ-th percentile of the non-zero magnitudes (computed over the whole batch) are fixed to 0.
    fix_y_step : bool
        Fix the group folding of the original implementation by correctly summing over groups for pixels in the
        overlapping regions. Typically results in smaller perturbations and is faster.
    callback : Optional[VisdomLogger]
        If provided, used to plot the logit distance and l2 norm during optimization, and the success rate, mean best
        l2 norm and mean trade-off constant after each binary search step.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model.

    """
    device = inputs.device
    batch_size, C, H, W = inputs.shape
    batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
    multiplier = -1 if targeted else 1
    zeros = torch.zeros_like(inputs)
    # refresh plots ~20 times per binary search step; at least 1 to avoid a modulo by zero when max_iterations < 20
    log_interval = max(max_iterations // 20, 1)

    # set the lower and upper bounds accordingly
    const = torch.full((batch_size,), initial_const, device=device, dtype=torch.float)
    lower_bound = torch.zeros_like(const)
    upper_bound = torch.full_like(const, 1e10)

    # bounds for the perturbations to get valid inputs
    sup_bound = 1 - inputs
    inf_bound = -inputs

    # number of group positions along the height (P) and the width (Q). F.unfold enumerates the groups in row-major
    # order, so group index g corresponds to (p, q) = divmod(g, Q).
    P = math.floor((H - group_size) / stride) + 1
    Q = math.floor((W - group_size) / stride) + 1
    overlap = group_size > stride

    z = torch.zeros_like(inputs)
    v = torch.zeros_like(inputs)
    u = torch.zeros_like(inputs)
    s = torch.zeros_like(inputs)

    o_best_l2 = torch.full_like(const, float('inf'))
    o_best_adv = inputs.clone()
    o_adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

    i_total = 0
    for outer_step in range(binary_search_steps):

        best_l2 = torch.full_like(const, float('inf'))
        adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

        # The last iteration (if we run many steps) repeat the search once.
        if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1):
            const = upper_bound.clone()  # clone to avoid aliasing const and upper_bound

        for i in range(max_iterations):  # max_iterations + outer_step * 1000 in the original implementation

            z.requires_grad_(True)
            adv_inputs = z + inputs
            logits = model(adv_inputs)

            if outer_step == 0 and i == 0:
                # setup the target variable, we need it to be in one-hot form for the loss function
                labels_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
                labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))

            logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)
            loss = const * (logit_dists + confidence).clamp(min=0)
            z_grad = grad(loss.sum(), inputs=z, only_inputs=True)[0]
            z.detach_()

            # δ step (equation 10)
            a = z - u
            δ = ρ / (ρ + 2 * γ) * a

            # w step (equation 11)
            b = z - s
            w = torch.minimum(torch.maximum(b, inf_bound), sup_bound)

            # y step (equation 17)
            c = z - v
            groups = F.unfold(c, kernel_size=group_size, stride=stride)
            group_norms = groups.norm(dim=1, p=2, keepdim=True)
            temp = torch.where(group_norms != 0, 1 - τ / (ρ * group_norms), torch.zeros_like(group_norms))
            temp_ = groups * temp.clamp(min=0)

            if overlap and not fix_y_step:  # match original implementation when overlapping
                y = c
                # use a dedicated index: reusing `i` here would clobber the iteration counter used for η below
                for g, (p, q) in enumerate(product(range(P), range(Q))):
                    p_start, p_end = p * stride, p * stride + group_size
                    q_start, q_end = q * stride, q * stride + group_size
                    y[:, :, p_start:p_end, q_start:q_end] = temp_[:, :, g].view(-1, C, group_size, group_size)
            else:  # faster folding (matches original implementation when groups are not overlapping)
                y = F.fold(temp_, output_size=(H, W), kernel_size=group_size, stride=stride)

            # MODIFIED: add projection for valid perturbation
            y = torch.minimum(torch.maximum(y, inf_bound), sup_bound)

            # z step (equation 18)
            a_prime = δ + u
            b_prime = w + s
            c_prime = y + v
            η = α * (i + 1) ** 0.5
            z = (z * η + ρ * (2 * a_prime + b_prime + c_prime) - z_grad) / (η + 4 * ρ)

            # MODIFIED: add projection for valid perturbation
            z = torch.minimum(torch.maximum(z, inf_bound), sup_bound)

            # update steps
            u.add_(δ - z)
            v.add_(y - z)
            s.add_(w - z)

            # new predictions
            adv_inputs = y + inputs
            l2 = (adv_inputs - inputs).flatten(1).norm(p=2, dim=1)
            logits = model(adv_inputs)
            logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)

            # adjust the best result found so far
            predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \
                (logits + labels_onehot * confidence).argmax(1)

            is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels)
            is_smaller = l2 < best_l2
            o_is_smaller = l2 < o_best_l2
            is_both = is_adv & is_smaller
            o_is_both = is_adv & o_is_smaller

            best_l2 = torch.where(is_both, l2, best_l2)
            adv_found.logical_or_(is_both)
            o_best_l2 = torch.where(o_is_both, l2, o_best_l2)
            o_adv_found.logical_or_(is_both)
            o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv)

            if callback:
                i_total += 1
                callback.accumulate_line('logit_dist', i_total, logit_dists.mean())
                callback.accumulate_line('l2_norm', i_total, l2.mean())
                if i_total % log_interval == 0:
                    callback.update_lines()

        if callback:
            mean_best_l2 = o_best_l2[o_adv_found].mean() if o_adv_found.any() else \
                torch.tensor(float('nan'), device=device)
            callback.line(['success', 'best_l2', 'c'], outer_step,
                          [o_adv_found.float().mean(), mean_best_l2, const.mean()])

        # adjust the constant as needed
        upper_bound[adv_found] = torch.min(upper_bound[adv_found], const[adv_found])
        adv_not_found = ~adv_found
        lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], const[adv_not_found])
        is_smaller = upper_bound < 1e9
        const[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2
        const[(~is_smaller) & adv_not_found] *= 5

    if retrain:
        lower_bound = torch.zeros_like(const)
        const = torch.full_like(const, initial_const)
        upper_bound = torch.full_like(const, 1e10)
        tmpC = ρ / (ρ + γ / 100)

        for _ in range(8):  # taken from the original implementation

            best_l2 = torch.full_like(const, float('inf'))
            adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)

            # components of the best perturbation at or below the σ-th percentile of the non-zero magnitudes
            # (computed over the whole batch) are frozen to 0
            o_best_y = o_best_adv - inputs
            np_o_best_y = o_best_y.cpu().numpy()
            Nz = np.abs(np_o_best_y[np.nonzero(np_o_best_y)])
            if Nz.size == 0:  # nothing to refine (no perturbation found): np.percentile is undefined on empty input
                break
            threshold = np.percentile(Nz, σ)

            S_σ = o_best_y.abs() <= threshold
            z = o_best_y.clone()
            u = torch.zeros_like(inputs)

            for _ in range(400):  # taken from the original implementation
                # δ step (equation 21)
                temp_a = (z - u) * tmpC
                δ = torch.where(S_σ, zeros, torch.minimum(torch.maximum(temp_a, inf_bound), sup_bound))

                # new predictions
                δ.requires_grad_(True)
                adv_inputs = δ + inputs
                logits = model(adv_inputs)
                l2 = (adv_inputs.detach() - inputs).flatten(1).norm(p=2, dim=1)
                logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)
                loss = const * (logit_dists + confidence).clamp(min=0)
                z_grad = grad(loss.sum(), inputs=δ, only_inputs=True)[0]
                δ.detach_()

                # z step (equation 22)
                a_prime = δ + u
                z = torch.where(S_σ, zeros, (α * z + ρ * a_prime - z_grad) / (α + 2 * ρ))
                u.add_(δ - z)

                # adjust the best result found so far
                predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \
                    (logits + labels_onehot * confidence).argmax(1)

                is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels)
                is_smaller = l2 < best_l2
                o_is_smaller = l2 < o_best_l2
                is_both = is_adv & is_smaller
                o_is_both = is_adv & o_is_smaller

                best_l2 = torch.where(is_both, l2, best_l2)
                adv_found.logical_or_(is_both)
                o_best_l2 = torch.where(o_is_both, l2, o_best_l2)
                o_adv_found.logical_or_(is_both)
                o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv)

            # adjust the constant as needed
            upper_bound[adv_found] = torch.min(upper_bound[adv_found], const[adv_found])
            adv_not_found = ~adv_found
            lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], const[adv_not_found])
            is_smaller = upper_bound < 1e9
            const[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2
            const[(~is_smaller) & adv_not_found] *= 5

    # return the best solution found
    return o_best_adv
290 |
--------------------------------------------------------------------------------
/adv_lib/attacks/superdeepfool.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | from torch import Tensor, nn
5 | from torch.autograd import grad
6 |
7 | from .deepfool import df
8 |
9 |
def sdf(model: nn.Module,
        inputs: Tensor,
        labels: Tensor,
        targeted: bool = False,
        steps: int = 100,
        df_steps: int = 100,
        overshoot: float = 0.02,
        search_iter: int = 10) -> Tensor:
    """
    SuperDeepFool attack from https://arxiv.org/abs/2303.12481.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Should be in [0, 1].
    labels : Tensor
        Labels corresponding to the inputs if untargeted, else target labels.
    targeted : bool
        Whether to perform a targeted attack or not.
    steps : int
        Number of steps. The iterate produced by the last step is also evaluated.
    df_steps : int
        Maximum number of steps for DeepFool attack at each iteration of SuperDeepFool.
    overshoot : float
        overshoot parameter in DeepFool.
    search_iter : int
        Number of binary search steps at the end of the attack.

    Returns
    -------
    adv_inputs : Tensor
        Modified inputs to be adversarial to the model. Samples for which no adversarial was found are returned
        unchanged.

    Raises
    ------
    ValueError
        If inputs are outside the [0, 1] range.

    """
    if targeted:
        warnings.warn('DeepFool attack is untargeted only. Returning inputs.')
        return inputs

    if inputs.min() < 0 or inputs.max() > 1:
        raise ValueError('Input values should be in the [0, 1] range.')
    device = inputs.device
    batch_size = len(inputs)
    batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1))

    # Setup variables: *_ tensors only hold the samples that are still being attacked
    adv_inputs = inputs_ = inputs
    labels_ = labels
    adv_out = inputs.clone()
    adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # steps + 1 evaluations: each iteration first checks the current iterate, so one extra pass is needed to check
    # the iterate produced by the final step (previously it was computed and then discarded)
    for i in range(steps + 1):
        logits = model(adv_inputs)
        pred_labels = logits.argmax(dim=1)

        is_adv = pred_labels != labels_
        if is_adv.any():
            # write newly found adversarials into the slots of samples that were still being attacked
            adv_not_found = ~adv_found
            adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs, adv_out[adv_not_found])
            adv_found.masked_scatter_(adv_not_found, is_adv)
            if is_adv.all():
                break

        if i == steps:  # step budget exhausted, last iterate has just been evaluated
            break

        not_adv = ~is_adv
        inputs_, adv_inputs, labels_ = inputs_[not_adv], adv_inputs[not_adv], labels_[not_adv]

        # start by doing deepfool -> need to return adv_inputs even for unsuccessful attacks
        df_adv_inputs, df_targets = df(model=model, inputs=adv_inputs, labels=labels_, steps=df_steps, norm=2,
                                       overshoot=overshoot, return_unsuccessful=True, return_targets=True)

        r_df = df_adv_inputs - inputs_
        df_adv_inputs.requires_grad_(True)
        logits = model(df_adv_inputs)
        pred_labels = logits.argmax(dim=1)
        # use the predicted class if DeepFool succeeded, else the class DeepFool was heading towards
        pred_labels = torch.where(pred_labels != labels_, pred_labels, df_targets)

        logit_diff = logits.gather(1, pred_labels.unsqueeze(1)) - logits.gather(1, labels_.unsqueeze(1))
        w = grad(logit_diff.sum(), inputs=df_adv_inputs, only_inputs=True)[0]
        w.div_(batch_view(w.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6)))  # w / ||w||_2
        scale = torch.linalg.vecdot(r_df.flatten(1), w.flatten(1), dim=1)  # (\tilde{x} - x_0)^T w / ||w||_2

        adv_inputs = adv_inputs.addcmul(batch_view(scale), w)
        adv_inputs.clamp_(min=0, max=1)  # added compared to original implementation to produce valid adv

    if search_iter:  # binary search to bring perturbation as close to the decision boundary as possible
        low, high = torch.zeros(batch_size, device=device), torch.ones(batch_size, device=device)
        for _ in range(search_iter):
            mid = (low + high) / 2
            x_mid = torch.lerp(inputs, adv_out, weight=batch_view(mid))
            pred_labels = model(x_mid).argmax(dim=1)
            is_adv = pred_labels != labels
            high = torch.where(is_adv, mid, high)
            low = torch.where(is_adv, low, mid)
        adv_out = torch.lerp(inputs, adv_out, weight=batch_view(high))

    return adv_out
106 |
--------------------------------------------------------------------------------
/adv_lib/attacks/trust_region.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/amirgholami/trattack
2 |
3 | import warnings
4 | from functools import partial
5 | from typing import Tuple
6 |
7 | import torch
8 | from torch import Tensor, nn
9 | from torch.autograd import grad
10 |
11 |
def select_index(model: nn.Module,
                 inputs: Tensor,
                 c: int = 9,
                 p: float = 2,
                 worst_case: bool = False) -> Tensor:
    """
    Select the attack target class.

    For each of the ``c`` classes ranked just below the top-1 prediction, a linearized distance to that class is
    estimated as ``(logit_top - logit_other) / ||grad(logit_top) - grad(logit_other)||_dual``.

    Parameters
    ----------
    model : nn.Module
        Model to attack.
    inputs : Tensor
        Inputs to attack. Left untouched (a detached copy is differentiated).
    c : int
        Number of candidate classes below the top-1 prediction (capped at num_classes - 1).
    p : float
        Norm of the attack, 2 or float('inf').
    worst_case : bool
        If True, select the farthest candidate; else the closest one (negative estimates clamped to 0).

    Returns
    -------
    Tensor
        Selected target class for each sample, never the top-1 predicted class.

    """
    _duals = {2: 2, float('inf'): 1}
    dual = _duals[p]

    # differentiate a detached copy so the caller's tensor keeps its requires_grad / autograd state
    inputs = inputs.detach().requires_grad_(True)
    logits = model(inputs)
    logits, indices = torch.sort(logits, descending=True)

    top_logits = logits[:, 0]
    top_grad = grad(top_logits.sum(), inputs, only_inputs=True, retain_graph=True)[0]
    pers = []

    c = min(logits.size(1) - 1, c)
    for i in range(c):
        other_logits = logits[:, i + 1]
        # keep the graph alive until the last candidate has been differentiated
        other_grad = grad(other_logits.sum(), inputs, only_inputs=True, retain_graph=i + 1 != c)[0]
        # clamped like in _step / _adaptive_step to avoid inf / NaN when the gradients coincide
        grad_dual_norm = (top_grad - other_grad).flatten(1).norm(p=dual, dim=1).clamp_(min=1e-6)
        pers.append((top_logits.detach() - other_logits.detach()).div_(grad_dual_norm))

    pers = torch.stack(pers, dim=1)

    if worst_case:
        index = pers.argmax(dim=1, keepdim=True)
    else:
        index = pers.clamp_(min=0).argmin(dim=1, keepdim=True)

    # +1 because candidates start at rank 1 (rank 0 is the predicted class)
    return indices.gather(1, index + 1).squeeze(1)
47 |
48 |
49 | def _step(model: nn.Module,
50 | inputs: Tensor,
51 | labels: Tensor,
52 | target_labels: Tensor,
53 | eps: Tensor,
54 | p: float = 2) -> Tuple[Tensor, Tensor]:
55 | _duals = {2: 2, float('inf'): 1}
56 | dual = _duals[p]
57 |
58 | inputs.requires_grad_(True)
59 | logits = model(inputs)
60 |
61 | logit_diff = (logits.gather(1, target_labels.unsqueeze(1)) - logits.gather(1, labels.unsqueeze(1))).squeeze(1)
62 |
63 | grad_inputs = grad(logit_diff.sum(), inputs, only_inputs=True)[0].flatten(1)
64 | inputs.detach_()
65 | per = logit_diff.detach().neg_().div_(grad_inputs.norm(p=dual, dim=1).clamp_(min=1e-6))
66 |
67 | if p == float('inf'):
68 | grad_inputs.sign_()
69 | elif p == 2:
70 | grad_inputs.div_(grad_inputs.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-6))
71 |
72 | per = torch.min(per, eps)
73 | adv_inputs = grad_inputs.mul_(per.add_(1e-4).mul_(1.02).unsqueeze(1)).view_as(inputs).add_(inputs)
74 | adv_inputs.clamp_(min=0, max=1)
75 | return adv_inputs, eps
76 |
77 |
78 | def _adaptive_step(model: nn.Module,
79 | inputs: Tensor,
80 | labels: Tensor,
81 | target_labels: Tensor,
82 | eps: Tensor,
83 | p: float = 2) -> Tuple[Tensor, Tensor]:
84 | _duals = {2: 2, float('inf'): 1}
85 | dual = _duals[p]
86 |
87 | inputs.requires_grad_(True)
88 | logits = model(inputs)
89 |
90 | class_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
91 | target_logits = logits.gather(1, target_labels.unsqueeze(1)).squeeze(1)
92 | logit_diff = target_logits - class_logits
93 |
94 | grad_inputs = grad(logit_diff.sum(), inputs, only_inputs=True)[0].flatten(1)
95 | inputs.detach_()
96 | per = logit_diff.detach().neg_().div_(grad_inputs.norm(p=dual, dim=1).clamp_(min=1e-6))
97 |
98 | if p == float('inf'):
99 | grad_inputs.sign_()
100 | elif p == 2:
101 | grad_inputs.div_(grad_inputs.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-6))
102 |
103 | new_eps = torch.min(per, eps)
104 |
105 | adv_inputs = grad_inputs.mul_(new_eps.add(1e-4).mul_(1.02).unsqueeze_(1)).view_as(inputs).add_(inputs)
106 | adv_inputs.clamp_(min=0, max=1)
107 |
108 | adv_logits = model(adv_inputs)
109 | class_adv_logits = adv_logits.gather(1, labels.unsqueeze(1)).squeeze(1)
110 |
111 | obj_diff = (class_logits - class_adv_logits).div_(new_eps)
112 | increase = obj_diff > 0.9
113 | decrease = obj_diff < 0.5
114 | new_eps = torch.where(increase, new_eps * 1.2, torch.where(decrease, new_eps / 1.2, new_eps))
115 | if p == 2:
116 | new_eps.clamp_(min=0.0005, max=0.05)
117 | elif p == float('inf'):
118 | new_eps.clamp_(min=0.0001, max=0.01)
119 |
120 | return adv_inputs, new_eps
121 |
122 |
def tr(model: nn.Module,
       inputs: Tensor,
       labels: Tensor,
       iter: int = 100,
       adaptive: bool = False,
       p: float = 2,
       eps: float = 0.001,
       c: int = 9,
       worst_case: bool = False,
       targeted: bool = False) -> Tensor:
    """
    Trust Region (TR) attack, adapted from https://github.com/amirgholami/trattack. Untargeted only.

    A target class is chosen once per sample with ``select_index``. Trust-region steps (``_step``, or
    ``_adaptive_step`` when ``adaptive``) are then applied to the samples that are still classified as their
    label, for at most ``iter`` iterations.

    Returns
    -------
    Tensor
        Adversarial inputs (the last iterate for samples that were never misclassified).
    """
    if targeted:
        warnings.warn('TR attack is untargeted only. Returning inputs.')
        return inputs

    adv_inputs = inputs.clone()
    target_labels = select_index(model, inputs, c=c, p=p, worst_case=worst_case)
    step_fn = _adaptive_step if adaptive else _step
    attack_step = partial(step_fn, model=model, p=p)

    active = torch.ones(len(inputs), dtype=torch.bool, device=inputs.device)
    radius = torch.full_like(active, eps, dtype=torch.float, device=inputs.device)

    for _ in range(iter):
        still_correct = model(adv_inputs[active]).argmax(dim=1) == labels[active]
        active.masked_scatter_(active, still_correct)
        if not active.any():
            break
        adv_inputs[active], radius[active] = attack_step(inputs=adv_inputs[active],
                                                         labels=labels[active],
                                                         target_labels=target_labels[active],
                                                         eps=radius[active])

    return adv_inputs
154 |
--------------------------------------------------------------------------------
/adv_lib/distances/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeromerony/adversarial-library/4ca9b77bb6c909d47e161ced0e257c6003fb4116/adv_lib/distances/__init__.py
--------------------------------------------------------------------------------
/adv_lib/distances/color_difference.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from adv_lib.utils.color_conversions import rgb_to_cielab
7 |
8 |
def cie94_color_difference(Lab_1: Tensor, Lab_2: Tensor, k_L: float = 1, k_C: float = 1, k_H: float = 1,
                           K_1: float = 0.045, K_2: float = 0.015, squared: bool = False, ε: float = 0) -> Tensor:
    """
    CIE94 color difference ΔE_94 between two images in L*a*b* space (channels on dim 1: L*, a*, b*).
    Stars from the formulas are omitted for conciseness.

    Parameters
    ----------
    Lab_1 : Tensor
        First image in L*a*b* space. First image is intended to be the reference image (its chroma is used
        in S_C and S_H).
    Lab_2 : Tensor
        Second image in L*a*b* space. Second image is intended to be the modified one.
    k_L : float
        Weighting factor for S_L.
    k_C : float
        Weighting factor for S_C.
    k_H : float
        Weighting factor for S_H.
    K_1 : float
        Chroma coefficient in S_C = 1 + K_1 * C_1.
    K_2 : float
        Chroma coefficient in S_H = 1 + K_2 * C_1.
    squared : bool
        Return the squared ΔE_94 (not clamped, so tiny negative values from rounding are possible).
    ε : float
        Small value for numerical stability when computing gradients. Default to 0 for most accurate evaluation.

    Returns
    -------
    ΔE_94 : Tensor
        The CIE94 color difference for each pixel, with the channel dimension reduced to size 1.

    """
    ΔL = Lab_1.narrow(1, 0, 1) - Lab_2.narrow(1, 0, 1)
    C_1 = torch.norm(Lab_1.narrow(1, 1, 2), p=2, dim=1, keepdim=True)
    C_2 = torch.norm(Lab_2.narrow(1, 1, 2), p=2, dim=1, keepdim=True)
    ΔC = C_1 - C_2
    Δa = Lab_1.narrow(1, 1, 1) - Lab_2.narrow(1, 1, 1)
    Δb = Lab_1.narrow(1, 2, 1) - Lab_2.narrow(1, 2, 1)
    # hue difference is only needed squared: ΔH² = Δa² + Δb² - ΔC²
    ΔH_squared = Δa ** 2 + Δb ** 2 - ΔC ** 2
    S_L = 1
    S_C = 1 + K_1 * C_1
    S_H = 1 + K_2 * C_1
    ΔE_94_squared = (ΔL / (k_L * S_L)) ** 2 + (ΔC / (k_C * S_C)) ** 2 + ΔH_squared / ((k_H * S_H) ** 2)
    if squared:
        return ΔE_94_squared
    return ΔE_94_squared.clamp(min=ε).sqrt()
51 |
52 |
def rgb_cie94_color_difference(input: Tensor, target: Tensor, **kwargs) -> Tensor:
    """Computes the CIE94 Color-Difference from RGB inputs, converting both to CIELAB first."""
    lab_input = rgb_to_cielab(input)
    lab_target = rgb_to_cielab(target)
    return cie94_color_difference(lab_input, lab_target, **kwargs)
56 |
57 |
def cie94_loss(x1: Tensor, x2: Tensor, squared: bool = False, **kwargs) -> Tensor:
    """
    Computes, per sample, the L2-norm over all pixels of the CIE94 Color-Difference for two RGB inputs.

    Parameters
    ----------
    x1 : Tensor:
        First input.
    x2 : Tensor:
        Second input (of size matching x1).
    squared : bool
        Returns the squared L2-norm.
    **kwargs
        Forwarded to ``rgb_cie94_color_difference``; ``ε`` is also used to clamp before the final sqrt.

    Returns
    -------
    ΔE_94_l2 : Tensor
        The L2-norm over all pixels of the CIE94 Color-Difference.

    """
    per_pixel_squared = rgb_cie94_color_difference(x1, x2, squared=True, **kwargs).flatten(1)
    total = per_pixel_squared.sum(1)
    if not squared:
        total = total.clamp(min=kwargs.get('ε', 0)).sqrt()
    return total
82 |
83 |
def ciede2000_color_difference(Lab_1: Tensor, Lab_2: Tensor, k_L: float = 1, k_C: float = 1, k_H: float = 1,
                               squared: bool = False, ε: float = 0) -> Tensor:
    """
    Inputs should be L*, a*, b*. Primes from formulas in
    http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf are omitted for conciseness.
    This version is based on the matlab implementation from Gaurav Sharma
    http://www2.ece.rochester.edu/~gsharma/ciede2000/dataNprograms/deltaE2000.m modified to have non NaN gradients.

    Parameters
    ----------
    Lab_1 : Tensor
        First image in L*a*b* space. First image is intended to be the reference image.
        Channels must be on dim 1 and of size 3 (asserted).
    Lab_2 : Tensor
        Second image in L*a*b* space. Second image is intended to be the modified one.
        Must have the same dtype as Lab_1 (asserted).
    k_L : float
        Weighting factor for S_L.
    k_C : float
        Weighting factor for S_C.
    k_H : float
        Weighting factor for S_H.
    squared : bool
        Return the squared ΔE_00.
    ε : float
        Small value for numerical stability when computing gradients. Default to 0 for most accurate evaluation.

    Returns
    -------
    ΔE_00 : Tensor
        The CIEDE2000 color difference for each pixel (squared if ``squared``). The channel dimension (dim 1) is
        removed from the output.

    """
    assert Lab_1.size(1) == 3 and Lab_2.size(1) == 3
    assert Lab_1.dtype == Lab_2.dtype
    dtype = Lab_1.dtype
    π = torch.tensor(math.pi, dtype=dtype, device=Lab_1.device)
    # float64 π used only for the hue-difference branch decision (h_abs_diff_compare) below
    π_compare = π if dtype == torch.float64 else torch.tensor(math.pi, dtype=torch.float64, device=Lab_1.device)

    L_star_1, a_star_1, b_star_1 = Lab_1.unbind(dim=1)
    L_star_2, a_star_2, b_star_2 = Lab_2.unbind(dim=1)

    # Chroma of the inputs and G factor used to rescale a* into a'
    C_star_1 = torch.norm(torch.stack((a_star_1, b_star_1), dim=1), p=2, dim=1)
    C_star_2 = torch.norm(torch.stack((a_star_2, b_star_2), dim=1), p=2, dim=1)
    C_star_bar = (C_star_1 + C_star_2) / 2
    C7 = C_star_bar ** 7
    G = 0.5 * (1 - (C7 / (C7 + 25 ** 7)).clamp(min=ε).sqrt())

    # a', C' and hue angles h' (in [0, 2π) after the remainder below)
    scale = 1 + G
    a_1 = scale * a_star_1
    a_2 = scale * a_star_2
    C_1 = torch.norm(torch.stack((a_1, b_star_1), dim=1), p=2, dim=1)
    C_2 = torch.norm(torch.stack((a_2, b_star_2), dim=1), p=2, dim=1)
    C_1_C_2_zero = (C_1 == 0) | (C_2 == 0)
    # ε is added to a' only where a' == 0; presumably to keep atan2 gradients finite at (0, 0) — confirm
    h_1 = torch.atan2(b_star_1, a_1 + ε * (a_1 == 0))
    h_2 = torch.atan2(b_star_2, a_2 + ε * (a_2 == 0))

    # required to match the test data: the |h'_1 - h'_2| <= π decision is made in float64 and without the ε shift
    h_abs_diff_compare = (torch.atan2(b_star_1.to(dtype=torch.float64),
                                      a_1.to(dtype=torch.float64)).remainder(2 * π_compare) -
                          torch.atan2(b_star_2.to(dtype=torch.float64),
                                      a_2.to(dtype=torch.float64)).remainder(2 * π_compare)).abs() <= π_compare

    h_1 = h_1.remainder(2 * π)
    h_2 = h_2.remainder(2 * π)
    h_diff = h_2 - h_1
    h_sum = h_1 + h_2

    # Differences ΔL', ΔC' and Δh' (Δh' wrapped into [-π, π], zero when either chroma is zero)
    ΔL = L_star_2 - L_star_1
    ΔC = C_2 - C_1
    Δh = torch.where(C_1_C_2_zero, torch.zeros_like(h_1),
                     torch.where(h_abs_diff_compare, h_diff,
                                 torch.where(h_diff > π, h_diff - 2 * π, h_diff + 2 * π)))

    # ΔH (with ε-clamped sqrt) is used only in the rotation term; ΔH_squared avoids the sqrt in the hue term
    ΔH = 2 * (C_1 * C_2).clamp(min=ε).sqrt() * torch.sin(Δh / 2)
    ΔH_squared = 4 * C_1 * C_2 * torch.sin(Δh / 2) ** 2

    L_bar = (L_star_1 + L_star_2) / 2
    C_bar = (C_1 + C_2) / 2

    # mean hue h̄', handling the wrap-around at 2π
    h_bar = torch.where(C_1_C_2_zero, h_sum,
                        torch.where(h_abs_diff_compare, h_sum / 2,
                                    torch.where(h_sum < 2 * π, h_sum / 2 + π, h_sum / 2 - π)))

    # Weighting functions; angles of the reference formula (30°, 6°, 63°, 275°) are written in radians
    T = 1 - 0.17 * (h_bar - π / 6).cos() + 0.24 * (2 * h_bar).cos() + \
        0.32 * (3 * h_bar + π / 30).cos() - 0.20 * (4 * h_bar - 63 * π / 180).cos()

    Δθ = π / 6 * (torch.exp(-((180 / π * h_bar - 275) / 25) ** 2))
    C7 = C_bar ** 7
    R_C = 2 * (C7 / (C7 + 25 ** 7)).clamp(min=ε).sqrt()
    S_L = 1 + 0.015 * (L_bar - 50) ** 2 / torch.sqrt(20 + (L_bar - 50) ** 2)
    S_C = 1 + 0.045 * C_bar
    S_H = 1 + 0.015 * C_bar * T
    R_T = -torch.sin(2 * Δθ) * R_C

    # squared ΔE_00 (including the chroma-hue rotation term)
    ΔE_00 = (ΔL / (k_L * S_L)) ** 2 + (ΔC / (k_C * S_C)) ** 2 + ΔH_squared / (k_H * S_H) ** 2 + \
            R_T * (ΔC / (k_C * S_C)) * (ΔH / (k_H * S_H))
    if squared:
        return ΔE_00
    return ΔE_00.clamp(min=ε).sqrt()
182 |
183 |
def rgb_ciede2000_color_difference(input: Tensor, target: Tensor, **kwargs) -> Tensor:
    """Computes the CIEDE2000 Color-Difference from RGB inputs, converting both to CIELAB first."""
    lab_input = rgb_to_cielab(input)
    lab_target = rgb_to_cielab(target)
    return ciede2000_color_difference(lab_input, lab_target, **kwargs)
187 |
188 |
def ciede2000_loss(x1: Tensor, x2: Tensor, squared: bool = False, **kwargs) -> Tensor:
    """
    Computes, per sample, the L2-norm over all pixels of the CIEDE2000 Color-Difference for two RGB inputs.

    Parameters
    ----------
    x1 : Tensor:
        First input.
    x2 : Tensor:
        Second input (of size matching x1).
    squared : bool
        Returns the squared L2-norm.
    **kwargs
        Forwarded to ``rgb_ciede2000_color_difference``; ``ε`` is also used to clamp before the final sqrt.

    Returns
    -------
    ΔE_00_l2 : Tensor
        The L2-norm over all pixels of the CIEDE2000 Color-Difference.

    """
    per_pixel_squared = rgb_ciede2000_color_difference(x1, x2, squared=True, **kwargs).flatten(1)
    total = per_pixel_squared.sum(1)
    if not squared:
        total = total.clamp(min=kwargs.get('ε', 0)).sqrt()
    return total
213 |
--------------------------------------------------------------------------------
/adv_lib/distances/lp_norms.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Union
3 |
4 | from torch import Tensor
5 |
6 |
def lp_distances(x1: Tensor, x2: Tensor, p: Union[float, int] = 2, dim: int = 1) -> Tensor:
    """Lp-norm of ``x1 - x2`` computed over all dimensions from ``dim`` onwards (flattened)."""
    difference = x1 - x2
    flat = difference.flatten(dim)
    return flat.norm(p=p, dim=dim)
9 |
10 |
# Per-sample distance shortcuts: lp_distances with a fixed norm order (difference flattened from dim=1 by default).
l0_distances = partial(lp_distances, p=0)  # p=0: torch's "norm" counting the non-zero entries
l1_distances = partial(lp_distances, p=1)
l2_distances = partial(lp_distances, p=2)
linf_distances = partial(lp_distances, p=float('inf'))
15 |
16 |
def squared_l2_distances(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor:
    """Squared L2 distance between ``x1`` and ``x2`` over all dimensions from ``dim`` onwards."""
    flat_difference = (x1 - x2).flatten(dim)
    return flat_difference.square().sum(dim)
19 |
--------------------------------------------------------------------------------
/adv_lib/distances/lpips.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, Union
2 |
3 | import torch
4 | from torch import nn, Tensor
5 | from torchvision import models
6 |
7 | from adv_lib.utils import ImageNormalizer, requires_grad_
8 |
9 |
class AlexNetFeatures(nn.Module):
    """Frozen, ImageNet-normalized AlexNet feature extractor returning the activations of 5 stages."""

    def __init__(self) -> None:
        super().__init__()
        self.normalize = ImageNormalizer(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        self.model = models.alexnet(pretrained=True)
        self.model.eval()

        # consecutive slices of ``self.model.features``; presumably each ends after a ReLU of torchvision's
        # AlexNet (relu1..relu5) — confirm against the torchvision version in use
        bounds = (0, 2, 5, 8, 10, 12)
        self.features_layers = nn.ModuleList(
            self.model.features[start:stop] for start, stop in zip(bounds[:-1], bounds[1:]))

        requires_grad_(self, False)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        return self.features(x)

    def features(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Returns the output of every stage, each stage being fed the previous stage's output."""
        hidden = self.normalize(x)
        outputs = []
        for stage in self.features_layers:
            hidden = stage(hidden)
            outputs.append(hidden)
        return tuple(outputs)
38 |
39 |
40 | def _normalize_features(x: Tensor, ε: float = 1e-12) -> Tensor:
41 | """Normalize by norm and sqrt of spatial size."""
42 | norm = torch.norm(x, dim=1, p=2, keepdim=True)
43 | return x / (norm[0].numel() ** 0.5 * norm.clamp(min=ε))
44 |
45 |
def _feature_difference(features_1: Tensor, features_2: Tensor, linear_mapping: Optional[nn.Module] = None) -> Tensor:
    """
    Normalizes each feature map, optionally applies the matching module of ``linear_mapping`` to it, flattens and
    concatenates everything per sample, and returns the difference between the two resulting vectors.
    """
    def _to_vector(feature_maps) -> Tensor:
        normalized = [_normalize_features(f) for f in feature_maps]
        if linear_mapping is not None:  # per-layer linear scaling
            normalized = [module(f) for module, f in zip(linear_mapping, normalized)]
        return torch.cat([f.flatten(1) for f in normalized], dim=1)

    return _to_vector(features_1) - _to_vector(features_2)
52 |
53 |
class LPIPS(nn.Module):
    """
    LPIPS-style perceptual distance: L2 distance between normalized (and optionally linearly scaled) deep features.

    Parameters
    ----------
    model : str or nn.Module
        Feature extractor, either a key of ``_models`` or a module returning a tuple of feature maps.
    linear_mapping : str, optional
        Path to a state dict whose weights become bias-free 1x1 convolutions applied to each normalized feature map.
    target : Tensor, optional
        Reference images whose features are precomputed and used when ``forward`` gets no target.
    squared : bool
        Return the squared distance.
    """
    _models = {'alexnet': AlexNetFeatures}

    def __init__(self,
                 model: Union[str, nn.Module] = 'alexnet',
                 linear_mapping: Optional[str] = None,
                 target: Optional[Tensor] = None,
                 squared: bool = False) -> None:
        super(LPIPS, self).__init__()

        if isinstance(model, str):
            self.features = self._models[model]()
        else:
            self.features = model

        self.linear_mapping = None
        if linear_mapping is not None:
            convs = []
            # NOTE: torch.load unpickles the file, only load trusted weights. map_location='cpu' allows loading
            # weights saved from GPU on a CPU-only machine (the convs below are created on CPU anyway).
            sd = torch.load(linear_mapping, map_location='cpu')
            for k, weight in sd.items():
                out_channels, in_channels = weight.shape[:2]
                conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False)
                conv.weight.data.copy_(weight)
                convs.append(conv)
            self.linear_mapping = nn.ModuleList(convs)

        self.target_features = None
        if target is not None:
            self.to(target.device)
            self.target_features = self.features(target)

        self.squared = squared

    def forward(self, input: Tensor, target: Optional[Tensor] = None) -> Tensor:
        """
        Per-sample distance between ``input`` and ``target`` (or the target given at init).

        Raises
        ------
        ValueError
            If no target is given here nor at init.
        """
        input_features = self.features(input)
        if target is None and self.target_features is not None:
            target_features = self.target_features
        elif target is not None:
            target_features = self.features(target)
        else:
            raise ValueError('Must provide targets (either in init or in forward).')

        # the loaded linear mapping was previously never used: pass it so that it actually scales the features
        difference = _feature_difference(input_features, target_features, linear_mapping=self.linear_mapping)
        if self.squared:
            return difference.square().sum(dim=1)

        return torch.norm(difference, p=2, dim=1)
100 |
--------------------------------------------------------------------------------
/adv_lib/distances/structural_similarity.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/pytorch/pytorch/pull/22289
2 | from functools import lru_cache
3 | from typing import Tuple
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import _reduction as _Reduction
8 | from torch.nn.functional import avg_pool2d, conv2d
9 |
10 |
11 | @lru_cache()
12 | def _fspecial_gaussian(size: int, channel: int, sigma: float, device: torch.device, dtype: torch.dtype,
13 | max_size: Tuple[int, int]) -> Tensor:
14 | coords = -(torch.arange(size, device=device, dtype=dtype) - (size - 1) / 2) ** 2 / (2. * sigma ** 2)
15 | if max(max_size) <= size:
16 | coords_x, coords_y = torch.zeros(max_size[0], device=device, dtype=dtype), torch.zeros(max_size[1],
17 | device=device,
18 | dtype=dtype)
19 | elif max_size[0] <= size:
20 | coords_x, coords_y = torch.zeros(max_size[0], device=device, dtype=dtype), coords
21 | elif max_size[1] <= size:
22 | coords_x, coords_y = coords, torch.zeros(max_size[1], device=device, dtype=dtype)
23 | else:
24 | coords_x = coords_y = coords
25 | final_size = (min(max_size[0], size), min(max_size[1], size))
26 |
27 | grid = coords_x.view(-1, 1) + coords_y.view(1, -1)
28 | kernel = grid.view(1, -1).softmax(-1).view(1, 1, *final_size).expand(channel, 1, -1, -1).contiguous()
29 | return kernel
30 |
31 |
32 | def _ssim(input: Tensor, target: Tensor, max_val: float, k1: float, k2: float, channel: int,
33 | kernel: Tensor) -> Tuple[Tensor, Tensor]:
34 | c1 = (k1 * max_val) ** 2
35 | c2 = (k2 * max_val) ** 2
36 |
37 | mu1 = conv2d(input, kernel, groups=channel)
38 | mu2 = conv2d(target, kernel, groups=channel)
39 |
40 | mu1_sq = mu1 ** 2
41 | mu2_sq = mu2 ** 2
42 | mu1_mu2 = mu1 * mu2
43 |
44 | sigma1_sq = conv2d(input * input, kernel, groups=channel) - mu1_sq
45 | sigma2_sq = conv2d(target * target, kernel, groups=channel) - mu2_sq
46 | sigma12 = conv2d(input * target, kernel, groups=channel) - mu1_mu2
47 |
48 | v1 = 2 * sigma12 + c2
49 | v2 = sigma1_sq + sigma2_sq + c2
50 |
51 | ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
52 | return ssim, v1 / v2
53 |
54 |
def ssim(input: Tensor, target: Tensor, max_val: float, filter_size: int = 11, k1: float = 0.01, k2: float = 0.03,
         sigma: float = 1.5, size_average=None, reduce=None, reduction: str = 'mean') -> Tensor:
    """
    Measures the structural similarity index (SSIM) between ``input`` and ``target``.

    Parameters
    ----------
    input, target : Tensor
        4D tensors of identical size, channels on dim 1.
    max_val : float
        Dynamic range of the values.
    filter_size, sigma : int, float
        Size and standard deviation of the Gaussian window.
    k1, k2 : float
        SSIM stabilization constants.
    size_average, reduce
        Deprecated reduction arguments, translated into ``reduction`` when given.
    reduction : str
        ``'none'`` returns the SSIM map; ``'mean'`` its mean; any other value its sum.

    Raises
    ------
    ValueError
        If the inputs are not 4D or their sizes differ.
    """
    dim = input.dim()
    if dim != 4:
        raise ValueError('Expected 4 dimensions (got {})'.format(dim))

    if input.size() != target.size():
        # report full sizes: previously only the batch dimension was printed, which often looked identical
        raise ValueError('Expected input size ({}) to match target size ({}).'.format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    channel = input.size(1)
    kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype,
                                max_size=input.shape[-2:])
    ret, _ = _ssim(input, target, max_val, k1, k2, channel, kernel)

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret
76 |
77 |
def compute_ssim(input: Tensor, target: Tensor, **kwargs) -> Tensor:
    """Per-sample SSIM with ``max_val=1``: the SSIM map is averaged over space, then over channels."""
    ssim_map = ssim(input=input, target=target, max_val=1, reduction='none', **kwargs)
    per_channel = ssim_map.mean([2, 3])
    return per_channel.mean(1)
81 |
82 |
def ssim_loss(*args, **kwargs) -> Tensor:
    """Per-sample SSIM dissimilarity: ``1 - compute_ssim(...)``."""
    similarity = compute_ssim(*args, **kwargs)
    return 1 - similarity
85 |
86 |
@lru_cache()
def ms_weights(device: torch.device):
    """Per-scale exponents of MS-SSIM (5 scales), cached per device."""
    scale_weights = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
    return torch.tensor(scale_weights, device=device)
90 |
91 |
def ms_ssim(input: Tensor, target: Tensor, max_val: float, filter_size: int = 11, k1: float = 0.01, k2: float = 0.03,
            sigma: float = 1.5, size_average=None, reduce=None, reduction: str = 'mean') -> Tensor:
    """
    Measures the multi-scale structural similarity index (MS-SSIM) between ``input`` and ``target``.

    One scale per weight of ``ms_weights``; between scales both images are downsampled with 2x2 average pooling
    (``ceil_mode=True``). The result is the product of the contrast-structure terms of all scales but the last and
    the SSIM of the last scale, each raised to its scale weight.

    Parameters
    ----------
    input, target : Tensor
        4D tensors of identical size, channels on dim 1.
    max_val : float
        Dynamic range of the values.
    filter_size, sigma : int, float
        Size and standard deviation of the Gaussian window (shrunk when an image gets smaller than the window).
    k1, k2 : float
        SSIM stabilization constants.
    size_average, reduce
        Deprecated reduction arguments, translated into ``reduction`` when given.
    reduction : str
        ``'none'`` returns a (batch, channel) tensor; ``'mean'`` its mean; any other value its sum.

    Raises
    ------
    ValueError
        If the inputs are not 4D or their sizes differ.
    """
    dim = input.dim()
    if dim != 4:
        raise ValueError('Expected 4 dimensions (got {}) from input'.format(dim))

    if input.size() != target.size():
        # report full sizes: previously only the batch dimension was printed, which often looked identical
        raise ValueError('Expected input size ({}) to match target size ({}).'.format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    channel = input.size(1)
    kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype,
                                max_size=input.shape[-2:])

    weights = ms_weights(input.device).unsqueeze(-1).unsqueeze(-1)
    levels = weights.size(0)
    mssim = []
    mcs = []
    for i in range(levels):

        if i:
            input = avg_pool2d(input, kernel_size=2, ceil_mode=True)
            target = avg_pool2d(target, kernel_size=2, ceil_mode=True)

        # rebuild a smaller window once the image no longer exceeds the filter size
        if min(size := input.shape[-2:]) <= filter_size:
            kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype,
                                        max_size=size)

        ssim, cs = _ssim(input, target, max_val, k1, k2, channel, kernel)
        ssim = ssim.mean((2, 3))
        cs = cs.mean((2, 3))
        mssim.append(ssim)
        mcs.append(cs)

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)
    p1 = mcs ** weights
    p2 = mssim ** weights

    ret = torch.prod(p1[:-1], 0) * p2[-1]

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret
139 |
140 |
def compute_ms_ssim(input: Tensor, target: Tensor, **kwargs) -> Tensor:
    """Per-sample MS-SSIM with ``max_val=1``, averaged over channels."""
    per_channel = ms_ssim(input=input, target=target, max_val=1, reduction='none', **kwargs)
    return per_channel.mean(1)
144 |
145 |
def ms_ssim_loss(*args, **kwargs) -> Tensor:
    """Per-sample MS-SSIM dissimilarity: ``1 - compute_ms_ssim(...)``."""
    similarity = compute_ms_ssim(*args, **kwargs)
    return 1 - similarity
148 |
--------------------------------------------------------------------------------
/adv_lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import ForwardCounter, BackwardCounter, ImageNormalizer, normalize_model, predict_inputs, requires_grad_
--------------------------------------------------------------------------------
/adv_lib/utils/attack_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 | from distutils.version import LooseVersion
4 | from functools import partial
5 | from inspect import isclass
6 | from typing import Callable, Optional, Dict, Union
7 |
8 | import numpy as np
9 | import torch
10 | import tqdm
11 | from torch import Tensor, nn
12 | from torch.nn import functional as F
13 |
14 | from adv_lib.distances.lp_norms import l0_distances, l1_distances, l2_distances, linf_distances
15 | from adv_lib.utils import ForwardCounter, BackwardCounter, predict_inputs
16 |
17 |
def generate_random_targets(labels: Tensor, num_classes: int) -> Tensor:
    """
    Generates one random target in (num_classes - 1) possibilities for each label that is different from the original
    label.

    Parameters
    ----------
    labels: Tensor
        Original labels (1D). Generated targets will be different from labels.
    num_classes: int
        Number of classes to generate the random targets from.

    Returns
    -------
    targets: Tensor
        Random target for each label. Has the same shape as labels.

    """
    scores = torch.rand(len(labels), num_classes, device=labels.device, dtype=torch.float)
    # scores lie in [0, 1): zeroing the original label's score effectively excludes it from the argmax
    scores[torch.arange(len(labels), device=labels.device), labels] = 0
    return scores.argmax(1)
39 |
40 |
def get_all_targets(labels: Tensor, num_classes: int) -> Tensor:
    """
    Generates all possible targets that are different from the original labels.

    Parameters
    ----------
    labels: Tensor
        Original labels (must be 1D). Generated targets will be different from labels.
    num_classes: int
        Number of classes to generate the targets from.

    Returns
    -------
    targets: Tensor
        For each label, every other class in increasing order. shape: (len(labels), num_classes - 1).

    """
    assert labels.ndim == 1
    candidates = torch.arange(num_classes - 1, dtype=torch.long, device=labels.device)
    # shift every candidate at or above the label by one, so the label itself is skipped
    skip_label = (candidates.unsqueeze(0) >= labels.unsqueeze(1)).long()
    return candidates + skip_label
62 |
63 |
def run_attack(model: nn.Module,
               inputs: Tensor,
               labels: Tensor,
               attack: Callable,
               targets: Optional[Tensor] = None,
               batch_size: Optional[int] = None) -> dict:
    """
    Runs ``attack`` on the samples that are not already adversarial and collects timing / query statistics.

    Parameters
    ----------
    model : nn.Module
        Model to attack. The device of its parameters is used for all computations.
    inputs : Tensor
        Clean inputs.
    labels : Tensor
        True labels of the inputs.
    attack : Callable
        Called as ``attack(model, inputs, labels, targeted=targeted)`` and returning adversarial inputs.
    targets : Tensor, optional
        Target labels. When given, the attack is run in targeted mode.
    batch_size : int, optional
        Number of samples per attack call. Defaults to all inputs at once.

    Returns
    -------
    dict
        ``inputs``, ``labels``, ``targets`` (None if untargeted), ``adv_inputs`` (already adversarial samples are
        left unchanged), ``time`` (seconds spent in ``attack``), ``num_forwards`` / ``num_backwards`` (average
        number of forward / backward calls per attacked sample, 0 if nothing was attacked).

    """
    import time  # local import: wall-clock timing for models that are not on a CUDA device

    device = next(model.parameters()).device
    to_device = lambda tensor: tensor.to(device)
    targeted, adv_labels = False, labels
    if targets is not None:
        targeted, adv_labels = True, targets
    batch_size = batch_size or len(inputs)

    # run attack only on non already adversarial samples (prediction only: no autograd graph needed)
    already_adv = []
    chunks = [tensor.split(batch_size) for tensor in [inputs, adv_labels]]
    with torch.no_grad():
        for (inputs_chunk, label_chunk) in zip(*chunks):
            batch_chunk_d, label_chunk_d = [to_device(tensor) for tensor in [inputs_chunk, label_chunk]]
            preds = model(batch_chunk_d).argmax(1)
            is_adv = (preds == label_chunk_d) if targeted else (preds != label_chunk_d)
            already_adv.append(is_adv.cpu())
    not_adv = ~torch.cat(already_adv, 0)

    # CUDA events only measure GPU work: use them for CUDA models, wall-clock time otherwise
    use_cuda_events = device.type == 'cuda'
    if use_cuda_events:
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

    forward_counter, backward_counter = ForwardCounter(), BackwardCounter()
    hook_handles = [model.register_forward_pre_hook(forward_counter)]
    if hasattr(model, 'register_full_backward_hook'):  # available since PyTorch 1.8
        hook_handles.append(model.register_full_backward_hook(backward_counter))
    else:
        hook_handles.append(model.register_backward_hook(backward_counter))

    average_forwards, average_backwards = [], []  # number of forward and backward calls per sample
    advs_chunks = []
    # splitting an empty tensor still yields one empty chunk: skip the attack entirely if nothing is left to attack
    if not_adv.any():
        chunks = [tensor.split(batch_size) for tensor in [inputs[not_adv], adv_labels[not_adv]]]
    else:
        chunks = [(), ()]
    total_time = 0
    try:
        for (inputs_chunk, label_chunk) in tqdm.tqdm(zip(*chunks), ncols=80, total=len(chunks[0])):
            batch_chunk_d, label_chunk_d = [to_device(tensor.clone()) for tensor in [inputs_chunk, label_chunk]]

            if use_cuda_events:
                start.record()
            else:
                start_time = time.perf_counter()
            advs_chunk_d = attack(model, batch_chunk_d, label_chunk_d, targeted=targeted)

            # performance monitoring
            if use_cuda_events:
                end.record()
                torch.cuda.synchronize()
                total_time += (start.elapsed_time(end)) / 1000  # times for cuda Events are in milliseconds
            else:
                total_time += time.perf_counter() - start_time
            average_forwards.append(forward_counter.num_samples_called / len(batch_chunk_d))
            average_backwards.append(backward_counter.num_samples_called / len(batch_chunk_d))
            forward_counter.reset(), backward_counter.reset()

            advs_chunks.append(advs_chunk_d.cpu())
            if isinstance(attack, partial) and (callback := attack.keywords.get('callback')) is not None:
                callback.reset_windows()
    finally:
        # remove the counters so repeated calls do not keep stacking hooks on the model
        for handle in hook_handles:
            handle.remove()

    adv_inputs = inputs.clone()
    if advs_chunks:
        adv_inputs[not_adv] = torch.cat(advs_chunks, 0)

    num_chunks = len(chunks[0])
    data = {
        'inputs': inputs,
        'labels': labels,
        'targets': adv_labels if targeted else None,
        'adv_inputs': adv_inputs,
        'time': total_time,
        'num_forwards': sum(average_forwards) / num_chunks if num_chunks else 0.,
        'num_backwards': sum(average_backwards) / num_chunks if num_chunks else 0.,
    }

    return data
130 |
131 |
# Distances computed by compute_attack_metrics when no `metrics` argument is given (insertion order is kept).
_default_metrics = OrderedDict(
    linf=linf_distances,
    l0=l0_distances,
    l1=l1_distances,
    l2=l2_distances,
)
138 |
139 |
def compute_attack_metrics(model: nn.Module,
                           attack_data: Dict[str, Union[Tensor, float]],
                           batch_size: Optional[int] = None,
                           metrics: Dict[str, Callable] = _default_metrics) -> Dict[str, Union[Tensor, float]]:
    """Evaluate the adversarial examples produced by an attack.

    Parameters
    ----------
    model : nn.Module
        Model on which clean and adversarial inputs are evaluated; its device is used for all computations.
    attack_data : dict
        Result of an attack run. Reads 'inputs', 'labels', 'adv_inputs', 'targets' (None for untargeted attacks),
        'time', 'num_forwards' and 'num_backwards'.
    batch_size : int, optional
        Number of samples evaluated at once. Defaults to all samples in a single batch.
    metrics : dict
        Maps a distance name to a callable ``f(adv_inputs, inputs)`` returning one value per sample. A
        ``functools.partial`` wrapping a class is instantiated and moved to the model's device before use.

    Returns
    -------
    dict
        Timing and call counts copied from ``attack_data``, whether the attack is targeted, clean / adversarial
        predictions, clean accuracy, success mask, probability and NLL of the true (or target) class on clean and
        adversarial inputs, the logit margin of that class over the best other class on adversarial inputs, and the
        per-sample distances.

    Notes
    -----
    If ``attack_data['adv_inputs']`` leaves the [0, 1] range, it is clamped in place (with a warning).
    """
    inputs, labels, targets, adv_inputs = map(attack_data.get, ['inputs', 'labels', 'targets', 'adv_inputs'])
    if adv_inputs.min() < 0 or adv_inputs.max() > 1:
        warnings.warn('Values of produced adversarials are not in the [0, 1] range -> Clipping to [0, 1].')
        adv_inputs.clamp_(min=0, max=1)
    device = next(model.parameters()).device

    batch_size = batch_size or len(inputs)
    chunks = [tensor.split(batch_size) for tensor in [inputs, adv_inputs]]

    # A partial wrapping a class (e.g. a distance implemented as an nn.Module) is instantiated on the model's device;
    # any other callable is used as is.
    metric_funcs = {}
    for name, metric in metrics.items():
        if isinstance(metric, partial) and isclass(metric.func):
            metric = metric().to(device)
        metric_funcs[name] = metric

    # Order: clean logits, probs, preds, then adversarial logits, probs, preds.
    all_predictions = [[] for _ in range(6)]
    distances = {name: [] for name in metric_funcs}
    # Evaluation only. Without no_grad, the CPU copies below keep the autograd graph (and the model's activations)
    # alive for the whole loop, and the returned tensors require grad, which makes `.numpy()` on them fail.
    with torch.no_grad():
        for inputs_chunk, adv_chunk in zip(*chunks):
            inputs_chunk, adv_chunk = inputs_chunk.to(device), adv_chunk.to(device)
            clean_preds, adv_preds = [predict_inputs(model, chunk) for chunk in [inputs_chunk, adv_chunk]]
            for storage, prediction in zip(all_predictions, [*clean_preds, *adv_preds]):
                storage.append(prediction.cpu())
            for name, metric_func in metric_funcs.items():
                distances[name].append(metric_func(adv_chunk, inputs_chunk).detach().cpu())

    logits, probs, preds, logits_adv, probs_adv, preds_adv = [torch.cat(l) for l in all_predictions]
    distances = {name: torch.cat(values, 0) for name, values in distances.items()}

    accuracy_orig = (preds == labels).float().mean().item()
    if targets is not None:
        success = (preds_adv == targets)
        labels = targets  # probabilities, margins and NLLs below are then w.r.t. the target class
    else:
        success = (preds_adv != labels)

    class_index = labels.unsqueeze(1)
    prob_orig = probs.gather(1, class_index).squeeze(1)
    prob_adv = probs_adv.gather(1, class_index).squeeze(1)
    # Margin between the class logit and the largest other logit on the adversarial inputs.
    labels_infhot = torch.zeros_like(logits_adv).scatter_(1, class_index, float('inf'))
    real = logits_adv.gather(1, class_index).squeeze(1)
    other = (logits_adv - labels_infhot).amax(dim=1)
    diff_vs_max_adv = real - other
    nll = F.cross_entropy(logits, labels, reduction='none')
    nll_adv = F.cross_entropy(logits_adv, labels, reduction='none')

    return {
        'time': attack_data['time'],
        'num_forwards': attack_data['num_forwards'],
        'num_backwards': attack_data['num_backwards'],
        'targeted': targets is not None,
        'preds': preds,
        'adv_preds': preds_adv,
        'accuracy_orig': accuracy_orig,
        'success': success,
        'probs_orig': prob_orig,
        'probs_adv': prob_adv,
        'logit_diff_adv': diff_vs_max_adv,
        'nll': nll,
        'nll_adv': nll_adv,
        'distances': distances,
    }
204 |
205 |
def print_metrics(metrics: dict) -> None:
    """Print a human-readable summary of the dict returned by ``compute_attack_metrics``.

    Reads 'accuracy_orig', 'time', 'num_forwards', 'num_backwards', 'success', 'distances', 'targeted',
    'logit_diff_adv' and 'nll_adv'. Tensors are converted with ``.numpy()``, so they must be CPU tensors that do not
    require grad.
    """
    # Print arrays with less precision, without leaking these options into numpy's global state.
    with np.printoptions(formatter={'float': '{:0.3f}'.format}, threshold=16, edgeitems=3, linewidth=120):
        print('Original accuracy: {:.2%}'.format(metrics['accuracy_orig']))
        print('Attack done in: {:.2f}s with {:.4g} forwards and {:.4g} backwards.'.format(
            metrics['time'], metrics['num_forwards'], metrics['num_backwards']))
        success = metrics['success'].numpy()
        fail = bool(success.mean() != 1)
        print('Attack success: {:.2%}'.format(success.mean()) + fail * ' - {}'.format(success))
        # The average over successful samples is only informative when some, but not all, samples succeeded
        # (with no success it would be the mean of an empty array).
        report_success_avg = fail and bool(success.any())
        for distance, values in metrics['distances'].items():
            data = values.numpy()
            line = '{}: {} - Average: {:.3f} - Median: {:.3f}'.format(distance, data, data.mean(), np.median(data))
            if report_success_avg:
                line += ' | Avg over success: {:.3f}'.format(data[success].mean())
            print(line)
        attack_type = 'targets' if metrics['targeted'] else 'correct'
        logit_diff = metrics['logit_diff_adv'].numpy()
        print('Logit({} class) - max_Logit(other classes): {} - Average: {:.2f}'.format(
            attack_type, logit_diff, logit_diff.mean()))
        print('NLL of target/pred class: {:.3f}'.format(metrics['nll_adv'].numpy().mean()))
223 |
--------------------------------------------------------------------------------
/adv_lib/utils/color_conversions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
# Luma coefficients (Kr, Kg, Kb) of the RGB -> Y'CbCr conversion, keyed by standard name.
_ycbcr_conversions = dict(
    rec_601=(0.299, 0.587, 0.114),
    rec_709=(0.2126, 0.7152, 0.0722),
    rec_2020=(0.2627, 0.678, 0.0593),
    smpte_240m=(0.212, 0.701, 0.087),
)
10 |
11 |
def rgb_to_ycbcr(input: Tensor, standard: str = 'rec_2020'):
    """Convert a batch of RGB images to Y'CbCr.

    Parameters
    ----------
    input : Tensor
        Images of shape (N, 3, H, W) (channels on dimension 1, as required by the einsum below).
    standard : str
        Key of ``_ycbcr_conversions`` selecting the (Kr, Kg, Kb) luma coefficients.

    Returns
    -------
    Tensor of shape (N, 3, H, W) with channels (Y, Cb, Cr), in the dtype of ``input`` when it is floating point.
    """
    kr, kg, kb = _ycbcr_conversions[standard]
    # einsum needs both operands in the same dtype: build the matrix in the input's floating dtype (float64, float16,
    # ...) instead of always using the default float32.
    dtype = input.dtype if input.is_floating_point() else None
    conversion_matrix = torch.tensor([[kr, kg, kb],
                                      [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb), 0.5],
                                      [0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)]],
                                     dtype=dtype, device=input.device)
    return torch.einsum('mc,nchw->nmhw', conversion_matrix, input)
18 |
19 |
def ycbcr_to_rgb(input: Tensor, standard: str = 'rec_2020'):
    """Convert a batch of Y'CbCr images back to RGB (inverse of ``rgb_to_ycbcr`` for the same standard).

    Parameters
    ----------
    input : Tensor
        Images of shape (N, 3, H, W) with channels (Y, Cb, Cr).
    standard : str
        Key of ``_ycbcr_conversions`` selecting the (Kr, Kg, Kb) luma coefficients.

    Returns
    -------
    Tensor of shape (N, 3, H, W) with RGB channels, in the dtype of ``input`` when it is floating point.
    """
    kr, kg, kb = _ycbcr_conversions[standard]
    # Same dtype handling as the forward conversion: einsum does not mix float32 and float64 operands.
    dtype = input.dtype if input.is_floating_point() else None
    conversion_matrix = torch.tensor([[1, 0, 2 - 2 * kr],
                                      [1, -kb / kg * (2 - 2 * kb), -kr / kg * (2 - 2 * kr)],
                                      [1, 2 - 2 * kb, 0]],
                                     dtype=dtype, device=input.device)
    return torch.einsum('mc,nchw->nmhw', conversion_matrix, input)
26 |
27 |
# Linear RGB -> CIE XYZ matrices (rows give X, Y and Z), keyed by RGB space name.
_xyz_conversions = dict(
    CIE_RGB=(
        (0.4887180, 0.3106803, 0.2006017),
        (0.1762044, 0.8129847, 0.0108109),
        (0.0000000, 0.0102048, 0.9897952),
    ),
    sRGB=(
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    ),
)
36 |
37 |
def rgb_to_xyz(input: Tensor, rgb_space: str = 'sRGB'):
    """Convert a batch of RGB images to CIE XYZ.

    Parameters
    ----------
    input : Tensor
        RGB images of shape (N, 3, H, W). The 0.04045 companding threshold is the sRGB one for values in [0, 1].
    rgb_space : str
        Key of ``_xyz_conversions`` selecting the linear RGB -> XYZ matrix.

    Returns
    -------
    Tensor of shape (N, 3, H, W) with channels (X, Y, Z). With 'sRGB', an all-ones input maps to Y ≈ 1.
    """
    # Inverse sRGB companding (linearisation).
    # NOTE(review): this sRGB curve is applied whatever `rgb_space` is; confirm it is intended for 'CIE_RGB'.
    linear = torch.where(input <= 0.04045, input / 12.92, ((input + 0.055) / 1.055) ** 2.4)
    # Build the matrix in the dtype of the linearised values: einsum needs both operands in the same dtype, so a
    # float32 matrix would fail on float64 / float16 images.
    conversion_matrix = torch.tensor(_xyz_conversions[rgb_space], dtype=linear.dtype, device=input.device)
    return torch.einsum('mc,nchw->nmhw', conversion_matrix, linear)
43 |
44 |
# δ = 6/29: breakpoint parameter shared by the CIELAB transfer function f and its inverse.
_delta = 6.0 / 29.0
46 |
47 |
def cielab_func(input: Tensor) -> Tensor:
    """CIELAB transfer function f(t).

    Returns t^(1/3) for t > δ³, and the linear segment t / (3 δ²) + 4/29 otherwise.
    """
    threshold = _delta ** 3
    # torch.where back-propagates through both branches. The cube root is therefore taken on values clamped to the
    # threshold, so that its infinite derivative at 0 cannot turn the gradients into NaNs.
    cube_root_branch = input.clamp(min=threshold).pow(1 / 3)
    linear_branch = input / (3 * _delta ** 2) + 4 / 29
    return torch.where(input > threshold, cube_root_branch, linear_branch)
51 |
52 |
def cielab_inverse_func(input: Tensor) -> Tensor:
    """Inverse of ``cielab_func``: t³ for t > δ, and 3 δ² (t - 4/29) otherwise."""
    cubic_branch = input.pow(3)
    linear_branch = 3 * _delta ** 2 * (input - 4 / 29)
    return torch.where(input > _delta, cubic_branch, linear_branch)
55 |
56 |
# Reference white tristimulus values (Xn, Yn, Zn), normalised so that Yn = 100, keyed by illuminant.
_cielab_conversions = dict(
    illuminant_d50=(96.4212, 100, 82.5188),
    illuminant_d65=(95.0489, 100, 108.884),
)
61 |
62 |
def rgb_to_cielab(input: Tensor, standard: str = 'illuminant_d65') -> Tensor:
    """Convert a batch of RGB images of shape (N, 3, H, W) to CIE L*a*b*.

    Parameters
    ----------
    input : Tensor
        RGB images; converted to XYZ with ``rgb_to_xyz`` (default sRGB space).
    standard : str
        Key of ``_cielab_conversions`` selecting the reference white.

    Returns
    -------
    Tensor of shape (N, 3, H, W) with channels (L*, a*, b*). The reference white gets L* = 100.
    """
    # rgb_to_xyz puts white at Y ≈ 1, whereas the reference whites in _cielab_conversions are normalised to Yn = 100.
    # Rescale XYZ to that same scale; otherwise white would get L* ≈ 9 instead of 100.
    XYZ_input = rgb_to_xyz(input=input) * 100

    Xn, Yn, Zn = _cielab_conversions[standard]
    f_X = cielab_func(XYZ_input.narrow(1, 0, 1) / Xn)
    f_Y = cielab_func(XYZ_input.narrow(1, 1, 1) / Yn)
    f_Z = cielab_func(XYZ_input.narrow(1, 2, 1) / Zn)
    L_star = 116 * f_Y - 16
    a_star = 500 * (f_X - f_Y)
    b_star = 200 * (f_Y - f_Z)
    return torch.cat((L_star, a_star, b_star), 1)
72 |
--------------------------------------------------------------------------------
/adv_lib/utils/image_selection.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 |
def select_images(model: nn.Module, dataset: Dataset, num_images: int, correct_only: bool = False,
                  random: bool = False) -> Tuple[Tensor, Tensor]:
    """Collect up to ``num_images`` samples from ``dataset``.

    Parameters
    ----------
    model : nn.Module
        Model used to check predictions when ``correct_only`` is True (inputs are moved to its device).
    dataset : Dataset
        Dataset yielding (image, label) pairs.
    num_images : int
        Number of samples to collect.
    correct_only : bool
        If True, keep only samples whose prediction by ``model`` matches the label.
    random : bool
        If True, iterate over the dataset in random order.

    Returns
    -------
    images, labels : Tuple[Tensor, Tensor]
        Selected samples concatenated along dimension 0, as produced by the DataLoader (not moved to ``model``'s
        device). If the dataset is exhausted first, a message is printed and fewer samples are returned.
    """
    device = next(model.parameters()).device
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=random)
    selected_images, selected_labels = [], []
    for image, label in loader:
        keep = True
        if correct_only:
            # Only the prediction is needed: do not build the autograd graph.
            with torch.no_grad():
                keep = bool((model(image.to(device)).argmax(1) == label.to(device)).all())
        if keep:
            selected_images.append(image)
            selected_labels.append(label)

        if len(selected_images) == num_images:
            break
    else:
        qualifier = 'correctly classified ' if correct_only else ''
        print('Could only find {} {}images.'.format(len(selected_images), qualifier))

    return torch.cat(selected_images, 0), torch.cat(selected_labels, 0)
28 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/__init__.py:
--------------------------------------------------------------------------------
1 | from .all_penalties import all_penalties
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/all_penalties.py:
--------------------------------------------------------------------------------
1 | from .penalty_functions import *
2 | from .univariate_functions import *
3 |
# Univariate functions θ compatible with each family of generic penalty-Lagrangian functions.
univariates_P4 = dict(
    Quad=Quadratic,
    FourThirds=FourThirds,
    Cosh=Cosh,
)

univariates_P5_P6_P7 = dict(
    Exp=Exp,
    LogExp=LogExp,
    LogQuad_1=LogQuad_1,
    LogQuad_2=LogQuad_2,
    HyperExp=HyperExp,
    HyperQuad=HyperQuad,
    DualLogQuad=DualLogQuad,
    CubicQuad=CubicQuad,
    ExpQuad=ExpQuad,
    LogBarrierQuad=LogBarrierQuad,
    HyperBarrierQuad=HyperBarrierQuad,
    HyperLogQuad=HyperLogQuad,
    SmoothPlus=SmoothPlus,
    NNSmoothPlus=NNSmoothPlus,
    ExpSmoothPlus=ExpSmoothPlus,
)

univariates_P8 = dict(
    LogExp=LogExp,
    LogQuad_1=LogQuad_1,
    LogQuad_2=LogQuad_2,
    HyperExp=HyperExp,
    HyperQuad=HyperQuad,
    DualLogQuad=DualLogQuad,
    CubicQuad=CubicQuad,
    ExpQuad=ExpQuad,
    LogBarrierQuad=LogBarrierQuad,
    HyperBarrierQuad=HyperBarrierQuad,
    HyperLogQuad=HyperLogQuad,
)

univariates_P9 = dict(
    SmoothPlus=SmoothPlus,
    NNSmoothPlus=NNSmoothPlus,
    ExpSmoothPlus=ExpSmoothPlus,
)

# Penalty name -> either a ready-to-use penalty function, or a (generic penalty class, compatible univariates) pair.
combinations = dict(
    PHRQuad=PHRQuad,
    P1=P1,
    P2=P2,
    P3=P3,
    P4=(P4, univariates_P4),
    P5=(P5, univariates_P5_P6_P7),
    P6=(P6, univariates_P5_P6_P7),
    P7=(P7, univariates_P5_P6_P7),
    P8=(P8, univariates_P8),
    P9=(P9, univariates_P9),
)

# Flat registry: name -> callable penalty(y, ρ, μ). Each generic penalty gets one entry per compatible univariate
# function, named '<penalty>_<univariate>', holding the penalty class instantiated with that univariate function.
all_penalties = {}
for p_name, penalty in combinations.items():
    if not isinstance(penalty, tuple):
        all_penalties[p_name] = penalty
        continue
    penalty_class, univariates = penalty
    for θ_name, θ in univariates.items():
        all_penalties[p_name + '_' + θ_name] = penalty_class(θ())
68 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/penalty_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
# Names exported by `from .penalty_functions import *` (used by all_penalties.py).
__all__ = ['PHRQuad', 'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9']
16 |
17 |
def PHRQuad(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
    """Powell-Hestenes-Rockafellar quadratic penalty-Lagrangian: (max(0, μ + ρ y)² - μ²) / (2 ρ)."""
    shifted = (μ + ρ * y).relu()
    return (shifted.square() - μ ** 2).div(2 * ρ)
20 |
21 |
def P1(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
    """Piecewise penalty-Lagrangian P1.

    Returns μ y + ρ y²/2 + ρ² y³ for y >= 0, the constant -μ² / (2 ρ) for y <= -μ/ρ, and μ y + ρ y²/2 in between.
    """
    middle = μ * y + 0.5 * ρ * y ** 2
    upper = middle + ρ ** 2 * y ** 3
    floor = -(μ ** 2) / (2 * ρ)
    below_or_middle = torch.where(y <= -μ / ρ, floor, middle)
    return torch.where(y >= 0, upper, below_or_middle)
27 |
28 |
def P2(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
    """Penalty-Lagrangian P2: cubic polynomial for y >= 0, rational μ y / (1 - ρ y) for y < 0."""
    negative_part = y.clamp(max=0)
    below = μ * y / (1 - ρ * negative_part)
    above = μ * y + μ * ρ * y ** 2 + 1 / 6 * ρ ** 2 * y ** 3
    return torch.where(y >= 0, above, below)
33 |
34 |
def P3(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
    """Penalty-Lagrangian P3: quadratic μ y + μ ρ y² for y >= 0, rational μ y / (1 - ρ y) for y < 0."""
    negative_part = y.clamp(max=0)
    below = μ * y / (1 - ρ * negative_part)
    above = μ * y + μ * ρ * y ** 2
    return torch.where(y >= 0, above, below)
39 |
40 |
class GenericPenaltyLagrangian:
    """Base class of the penalty-Lagrangian functions built on top of a univariate function θ.

    Subclasses implement ``__call__(y, ρ, μ)`` using ``self.θ``.
    """

    def __init__(self, θ) -> None:
        # univariate function instance (all_penalties.py passes an instance of a univariate class)
        self.θ = θ
44 |
45 |
class P4(GenericPenaltyLagrangian):
    """P4: μ y + θ(ρ y) / ρ where ``θ.sup`` holds, and the constant ``θ.min(ρ, μ)`` elsewhere."""

    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        increasing_part = μ * y + θ(ρ * y) / ρ
        return torch.where(θ.sup(y, ρ, μ), increasing_part, θ.min(ρ, μ))
51 |
52 |
class P5(GenericPenaltyLagrangian):
    """P5: θ.mul · θ(ρ y) · μ / ρ."""

    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        value = θ(ρ * y)
        return θ.mul * value * μ / ρ
56 |
57 |
class P6(GenericPenaltyLagrangian):
    """P6: θ.mul · θ(ρ μ y) / ρ."""

    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        value = θ(ρ * μ * y)
        return θ.mul * value / ρ
61 |
62 |
class P7(GenericPenaltyLagrangian):
    """P7: θ.mul · θ(ρ y / μ) · μ² / ρ."""

    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        value = θ(ρ * y / μ)
        return θ.mul * value * (μ ** 2) / ρ
66 |
67 |
class P8(GenericPenaltyLagrangian):
    """P8: (θ(ρ y + x̃) - θ(x̃)) / ρ with the shift x̃ = θ.tilde(μ)."""

    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        shift = θ.tilde(μ)
        return (θ(ρ * y + shift) - θ(shift)) / ρ
72 |
73 |
class P9(GenericPenaltyLagrangian):
    """P9: θ(ρ' y + x̃) - θ(x̃) with ρ' = max(ρ, 2 μ) and the shift x̃ = θ.tilde(μ, ρ')."""

    # Penalty-Lagrangian functions associated with P9 are not well defined for ρ <= μ, so ρ is replaced by
    # max(ρ, 2μ).
    def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        θ = self.θ
        ρ_safe = torch.max(ρ, 2 * μ)
        shift = θ.tilde(μ, ρ_safe)
        return θ(ρ_safe * y + shift) - θ(shift)
80 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/scripts/plot_penalties.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | from cycler import cycler
4 | from torch.autograd import grad
5 |
6 | from adv_lib.utils.lagrangian_penalties.all_penalties import all_penalties
7 |
fig, ax = plt.subplots(figsize=(12, 12))
# Sample [-10, 5] with a 1e-3 step and force the sample at `zero_idx` to be exactly 0 (linspace may be off by a
# rounding error there). Write it before enabling gradients: an in-place write into a leaf tensor that requires grad
# raises a RuntimeError.
x = torch.linspace(-10, 5, 15001)
zero_idx = 10000
x[zero_idx] = 0
x.requires_grad_(True)

styles = (cycler(linestyle=['-', '--', '-.']) * cycler(color=plt.rcParams['axes.prop_cycle']))
ax.set_prop_cycle(styles)

ρ = torch.tensor(1.)
μ = torch.tensor(1.)
for penalty, penalty_function in all_penalties.items():
    y = penalty_function(x, ρ, μ)
    if y.isnan().any():
        print('nan in y for {}'.format(penalty))
    grads = grad(y.sum(), x)[0]
    if grads.isnan().any():
        print('nan in grad for {}'.format(penalty))
    # Penalty-Lagrangian functions should satisfy P'(0, ρ, μ) = μ: report the ones that do not.
    if not torch.allclose(grads[zero_idx], μ):
        print("{}: P'(0, ρ, μ) = {:.3g} ≠ μ".format(penalty, grads[zero_idx].item()))
    ax.plot(x.detach().numpy(), y.detach().numpy(),
            label=r'{}: $\nabla P(0)$:{:.3g}'.format(penalty, grads[zero_idx].item()))

ax.set_xlim(-10, 5)
ax.set_ylim(-5, 10)
ax.legend(loc=2, prop={'size': 6})
ax.set_aspect('equal')
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$Penalty(x)$')
ax.grid(True, linestyle='--')

plt.tight_layout()
plt.show()
42 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/scripts/plot_univariates.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 |
3 | import matplotlib.pyplot as plt
4 | import torch
5 | from cycler import cycler
6 |
7 | from adv_lib.utils.lagrangian_penalties import univariate_functions
8 |
fig, ax = plt.subplots(figsize=(8, 8))
x = torch.linspace(-10, 5, 15001)

# Combine two line styles with the default colour cycle.
ax.set_prop_cycle(cycler(linestyle=['-', '--']) * cycler(color=plt.rcParams['axes.prop_cycle']))

for name in univariate_functions.__all__:
    candidate = univariate_functions.__dict__[name]
    # Plain functions are called directly; classes are instantiated first.
    univariate_function = candidate if isfunction(candidate) else candidate()
    ax.plot(x, univariate_function(x), label=name)

ax.set_xlim(-10, 5)
ax.set_ylim(-5, 10)
ax.legend(loc=2)
ax.set_aspect('equal')
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$Univariate(x)$')
ax.grid(True, linestyle='--')

plt.tight_layout()
plt.show()
33 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/univariate_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import functional as F
6 |
# Univariate functions exported by `from .univariate_functions import *` (also iterated by the plotting script).
__all__ = [
    'Quadratic', 'FourThirds', 'Cosh',
    'Exp', 'LogExp', 'LogQuad_1', 'LogQuad_2', 'HyperExp', 'HyperQuad', 'DualLogQuad', 'CubicQuad', 'ExpQuad',
    'LogBarrierQuad', 'HyperBarrierQuad', 'HyperLogQuad',
    'SmoothPlus', 'NNSmoothPlus', 'ExpSmoothPlus',
]
27 |
28 |
def safe_exp(x: Tensor) -> Tensor:
    """exp(x) with x capped at 87.5 (exp(87.5) ≈ 1.0e38 is still finite in float32)."""
    capped = x.clamp(max=87.5)
    return capped.exp()
31 |
32 |
class Quadratic:
    """θ(t) = t² / 2, for the P4 penalty-Lagrangian."""

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        return t.square() * 0.5

    @staticmethod
    def min(ρ: Tensor, μ: Tensor) -> Tensor:
        # minimum over y of μ y + θ(ρ y) / ρ, reached at y = -μ / ρ
        return -(μ ** 2) / (2 * ρ)

    @staticmethod
    def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        # True where the derivative μ + θ'(ρ t) = μ + ρ t is non-negative
        return ρ * t + μ >= 0
45 |
46 |
class FourThirds:
    """θ(t) = 3/4 |t|^(4/3), for the P4 penalty-Lagrangian."""

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        magnitude = t.abs()
        return 0.75 * magnitude.pow(4 / 3)

    @staticmethod
    def min(ρ: Tensor, μ: Tensor) -> Tensor:
        # minimum over y of μ y + θ(ρ y) / ρ, reached at ρ y = -μ³
        return -(μ ** 4) / (4 * ρ)

    @staticmethod
    def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        # True where μ + θ'(ρ t) = μ + (ρ t)^(1/3) (signed cube root) is non-negative
        signed_cube_root = t.sign() * t.abs().pow(1 / 3)
        return ρ.pow(1 / 3) * signed_cube_root + μ >= 0
59 |
60 |
class Cosh:
    """θ(t) = cosh(t) - 1, for the P4 penalty-Lagrangian.

    P4 uses μ y + θ(ρ y) / ρ where ``sup`` holds, and the constant ``min`` elsewhere.
    """

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        return t.cosh() - 1

    @staticmethod
    def min(ρ: Tensor, μ: Tensor) -> Tensor:
        # minimum over y of μ y + θ(ρ y) / ρ, reached where μ + sinh(ρ y) = 0, i.e. ρ y = asinh(-μ)
        return μ / ρ * torch.asinh(-μ) + 1 / ρ * (torch.cosh(torch.asinh(-μ)) - 1)

    @staticmethod
    def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor:
        # True where the derivative μ + θ'(ρ t) = μ + sinh(ρ t) is non-negative, i.e. t >= asinh(-μ) / ρ. This matches
        # the minimiser used in `min`, so the P4 penalty is continuous and non-decreasing. (Testing sinh(μ + ρ t) >= 0
        # instead switches at t = -μ / ρ and creates a jump.)
        return μ + torch.sinh(ρ * t) >= 0
73 |
74 |
class Exp:
    """θ(t) = exp(t) - 1, with the exponent capped by ``safe_exp``."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        return safe_exp(t) - 1

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = exp(x) = μ
        return μ.log()
85 |
86 |
class LogExp:
    """θ(t) = -log(1 - t) for t < 0.5, joined at t = 0.5 to exp(2t - 1) + log(2) - 1 (value and slope match)."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        exponential = safe_exp(2 * t - 1) + math.log(2) - 1
        barrier = -torch.log(1 - t.clamp(max=0.5))
        return torch.where(t >= 0.5, exponential, barrier)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = 2)
        from_exponential = 0.5 * (torch.log(μ / 2) + 1)
        from_barrier = (μ - 1) / μ
        return torch.where(μ >= 2, from_exponential, from_barrier)
101 |
102 |
class LogQuad_1:
    """θ(t) = -log(1 - t) for t < 0.5, joined at t = 0.5 to 2t² + log(2) - 0.5."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 2 * t ** 2 + math.log(2) - 0.5
        barrier = -torch.log(1 - t.clamp(max=0.5))
        return torch.where(t >= 0.5, quadratic, barrier)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = 2)
        from_quadratic = μ / 4
        from_barrier = (μ - 1) / μ
        return torch.where(μ >= 2, from_quadratic, from_barrier)
117 |
118 |
class LogQuad_2:
    """θ(t) = t²/2 + t for t >= -0.5, joined to -log(-2t)/4 - 0.375 below."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 0.5 * t ** 2 + t
        logarithmic = -0.25 * torch.log(-2 * t.clamp(max=-0.5)) - 0.375
        return torch.where(t >= -0.5, quadratic, logarithmic)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(-0.5) = 0.5)
        from_quadratic = μ - 1
        from_logarithmic = -1 / (4 * μ)
        return torch.where(μ >= 0.5, from_quadratic, from_logarithmic)
133 |
134 |
class HyperExp:
    """θ(t) = t / (1 - t) for t < 0.5, joined at t = 0.5 to exp(4t - 2)."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        exponential = safe_exp(4 * t - 2)
        hyperbolic = t / (1 - t.clamp(max=0.5))
        return torch.where(t >= 0.5, exponential, hyperbolic)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = 4)
        from_exponential = torch.log(μ / 4) / 4 + 0.5
        from_hyperbolic = 1 - 1 / torch.sqrt(μ)
        return torch.where(μ >= 4, from_exponential, from_hyperbolic)
149 |
150 |
class HyperQuad:
    """θ(t) = t / (1 - t) for t < 0.5, joined at t = 0.5 to 8t² - 4t + 1."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 8 * t ** 2 - 4 * t + 1
        hyperbolic = t / (1 - t.clamp(max=0.5))
        return torch.where(t >= 0.5, quadratic, hyperbolic)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = 4)
        from_quadratic = (μ + 4) / 16
        from_hyperbolic = 1 - 1 / torch.sqrt(μ)
        return torch.where(μ >= 4, from_quadratic, from_hyperbolic)
165 |
166 |
class DualLogQuad:
    """θ(t) = u²/16 + log(u/4) - 1 with u = (1 + t) + sqrt((1 + t)² + 8); θ(0) = 0 and θ'(t) = u / 4."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        shifted = 1 + t
        u = shifted + torch.sqrt(shifted ** 2 + 8)
        return u ** 2 / 16 + torch.log(0.25 * u) - 1

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ
        return 2 * μ - 1 / μ - 1
178 |
179 |
class CubicQuad:
    """θ(t) = (t + 0.5)₊³ / 6 - 1/24 for t < 0.5, joined at t = 0.5 to t² / 2."""

    mul = 8

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 0.5 * t ** 2
        cubic = 1 / 6 * (t + 0.5).clamp(min=0) ** 3 - 1 / 24
        return torch.where(t >= 0.5, quadratic, cubic)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = 0.5)
        from_quadratic = μ
        from_cubic = torch.sqrt(2 * μ) - 0.5
        return torch.where(μ >= 0.5, from_quadratic, from_cubic)
194 |
195 |
class ExpQuad:
    """θ(t) = exp(t) for t < 0.5, joined at t = 0.5 to exp(0.5) (t²/2 + t/2 + 0.625)."""

    mul = 1

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = math.exp(0.5) * (0.5 * t ** 2 + 0.5 * t + 0.625)
        exponential = safe_exp(t)
        return torch.where(t >= 0.5, quadratic, exponential)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(0.5) = exp(0.5))
        from_quadratic = μ * math.exp(-0.5) - 0.5
        from_exponential = torch.log(μ)
        return torch.where(μ >= math.exp(0.5), from_quadratic, from_exponential)
210 |
211 |
class LogBarrierQuad:
    """θ(t) = -log(-t) - 1 for t < -0.5, joined at t = -0.5 to 2t² + 4t + 0.5 + log(2)."""

    mul = 0.25

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 2 * t ** 2 + 4 * t + 0.5 + math.log(2)
        barrier = -torch.log(-t.clamp(max=-0.5)) - 1
        return torch.where(t >= -0.5, quadratic, barrier)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(-0.5) = 2)
        from_quadratic = (μ - 4) / 4
        from_barrier = -1 / μ
        return torch.where(μ >= 2, from_quadratic, from_barrier)
226 |
227 |
class HyperBarrierQuad:
    """θ(t) = -1 / t for t < -0.5, joined at t = -0.5 to 8t² + 12t + 6."""

    mul = 1 / 12

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 8 * t ** 2 + 12 * t + 6
        barrier = -1 / t.clamp(max=-0.5)
        return torch.where(t >= -0.5, quadratic, barrier)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(-0.5) = 4)
        from_quadratic = (μ - 12) / 16
        from_barrier = -1 / torch.sqrt(μ)
        return torch.where(μ >= 4, from_quadratic, from_barrier)
242 |
243 |
class HyperLogQuad:
    """Three-piece θ: 4 / (1 - t) - 2 for t <= -1, -log(-t) on (-1, -0.25), 8t² + 8t + 1.5 + 2 log(2) above."""

    mul = 0.125

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        quadratic = 8 * t ** 2 + 8 * t + 1.5 + 2 * math.log(2)
        logarithmic = -torch.log(-t.clamp(-1, -0.25))
        hyperbolic = 4 / (1 - t.clamp(max=-1)) - 2
        lower_pieces = torch.where(t <= -1, hyperbolic, logarithmic)
        return torch.where(t >= -0.25, quadratic, lower_pieces)

    @staticmethod
    def tilde(μ: Tensor) -> Tensor:
        # solves θ'(x) = μ on the branch selected by μ (θ'(-0.25) = 4, θ'(-1) = 1)
        from_quadratic = (μ - 8) / 16
        from_logarithmic = -1 / μ
        from_hyperbolic = 1 - 2 / torch.sqrt(μ)
        lower_pieces = torch.where(μ <= 1, from_hyperbolic, from_logarithmic)
        return torch.where(μ >= 4, from_quadratic, lower_pieces)
260 |
261 |
class SmoothPlus:
    """θ(t) = (t + sqrt(t² + 4)) / 2, a smooth approximation of max(t, 0) (for P9)."""

    mul = 2

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        root = torch.sqrt(t ** 2 + 4)
        return 0.5 * (t + root)

    @staticmethod
    def tilde(μ: Tensor, ρ: Tensor) -> Tensor:
        # solves θ'(x) = μ / ρ
        numerator = 2 * μ - ρ
        return numerator / torch.sqrt(μ * ρ - μ ** 2)
272 |
273 |
class NNSmoothPlus:
    """θ(t) = softplus(t) = log(1 + exp(t)) (for P9)."""

    mul = 2

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        return F.softplus(t)

    @staticmethod
    def tilde(μ: Tensor, ρ: Tensor) -> Tensor:
        # solves θ'(x) = sigmoid(x) = μ / ρ
        odds = μ / (ρ - μ)
        return torch.log(odds)
284 |
285 |
class ExpSmoothPlus:
    """θ(t) = exp(t) / 2 for t < 0, joined at t = 0 to t + exp(-t) / 2 (for P9)."""

    mul = 2

    @staticmethod
    def __call__(t: Tensor) -> Tensor:
        positive_side = t + 0.5 * safe_exp(-t)
        negative_side = 0.5 * safe_exp(t)
        return torch.where(t >= 0, positive_side, negative_side)

    @staticmethod
    def tilde(μ: Tensor, ρ: Tensor) -> Tensor:
        # solves θ'(x) = μ / ρ on the branch selected by μ / ρ (θ'(0) = 1/2)
        from_positive = torch.log(ρ / (2 * (ρ - μ)))
        from_negative = torch.log(2 * μ / ρ)
        return torch.where(μ >= 0.5 * ρ, from_positive, from_negative)
300 |
--------------------------------------------------------------------------------
/adv_lib/utils/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
def difference_of_logits(logits: Tensor, labels: Tensor, labels_infhot: Optional[Tensor] = None) -> Tensor:
    """Return, per sample, the logit of ``labels`` minus the largest logit among the other classes.

    ``labels_infhot`` (+inf at the label position, 0 elsewhere) can be passed to avoid rebuilding it.
    """
    label_index = labels.unsqueeze(1)
    if labels_infhot is None:
        labels_infhot = torch.zeros_like(logits).scatter_(1, label_index, float('inf'))
    label_logit = logits.gather(1, label_index).squeeze(1)
    best_other_logit = (logits - labels_infhot).amax(dim=1)
    return label_logit - best_other_logit
14 |
15 |
def difference_of_logits_ratio(logits: Tensor, labels: Tensor, labels_infhot: Optional[Tensor] = None,
                               targeted: bool = False, ε: float = 0) -> Tensor:
    """Difference of Logits Ratio from https://arxiv.org/abs/2003.01690. This version is modified such that the DLR is
    always positive if argmax(logits) == labels.

    The margin (plus ``ε``) is divided by z_π1 - z_π3 (untargeted) or by z_π1 - (z_π3 + z_π4) / 2 (targeted), where
    z_πk is the k-th largest logit.
    """
    margin = difference_of_logits(logits=logits, labels=labels, labels_infhot=labels_infhot)

    if not targeted:
        top_logits = torch.topk(logits, k=3, dim=1).values
        scale = top_logits[:, 0] - top_logits[:, -1]
    else:
        top_logits = torch.topk(logits, k=4, dim=1).values
        scale = top_logits[:, 0] - (top_logits[:, -2] + top_logits[:, -1]) / 2

    return (margin + ε) / (scale + 1e-8)
30 |
--------------------------------------------------------------------------------
/adv_lib/utils/projections.py:
--------------------------------------------------------------------------------
1 | from distutils.version import LooseVersion
2 | from typing import Union
3 |
4 | import torch
5 | from torch import Tensor
6 |
# torch.clamp accepts Tensor bounds only from torch 1.9 on; `clamp` falls back to maximum / minimum otherwise.
use_tensors_in_clamp = LooseVersion(torch.__version__) >= LooseVersion('1.9')
10 |
11 |
@torch.no_grad()
def clamp(x: Tensor, lower: Tensor, upper: Tensor, inplace: bool = False) -> Tensor:
    """Clamp ``x`` element-wise between Tensor bounds ``lower`` and ``upper``, without tracking gradients.

    With ``inplace=True`` the result is written into ``x``. On torch < 1.9, which has no Tensor bounds in
    torch.clamp, the clamp is emulated with torch.maximum followed by torch.minimum.
    """
    out = x if inplace else None
    if not use_tensors_in_clamp:
        out = torch.maximum(x, lower, out=out)
        return torch.minimum(out, upper, out=out)
    return torch.clamp(x, min=lower, max=upper, out=out)
23 |
24 |
def clamp_(x: Tensor, lower: Tensor, upper: Tensor) -> Tensor:
    """In-place variant of ``clamp``: ``x`` itself is clamped and returned."""
    return clamp(x, lower, upper, inplace=True)
28 |
29 |
def simplex_projection(x: Tensor, ε: Union[float, Tensor] = 1, inplace: bool = False) -> Tensor:
    """
    Project each row of a batch onto the simplex {v >= 0, sum(v) = ε}, using a sort of the entries.

    Parameters
    ----------
    x : Tensor
        Batch of vectors (one per row) to project.
    ε : float or Tensor
        Size of the simplex; 1 (the default) gives the probability simplex. A Tensor gives one size per row.
    inplace : bool
        If True, the projection is written into ``x``.

    Returns
    -------
    projected_x : Tensor
        Batch of vectors projected onto the simplex.
    """
    radius = ε.unsqueeze(1) if isinstance(ε, Tensor) else x.new_full((), ε)
    sorted_x, _ = torch.sort(x, dim=1, descending=True)
    ranks = torch.arange(x.size(1), device=x.device)
    # Candidate thresholds: (sum of the k+1 largest entries - ε) / (k+1) for each k.
    thresholds = sorted_x.cumsum(dim=1).sub_(radius).div_(ranks + 1)
    # Last rank whose sorted entry is still above its candidate threshold.
    last_active = (thresholds < sorted_x).long().mul_(ranks).amax(dim=1, keepdim=True)
    τ = thresholds.gather(1, last_active)
    shifted = x.sub_(τ) if inplace else x - τ
    return shifted.clamp_(min=0)
56 |
57 |
def l1_ball_euclidean_projection(x: Tensor, ε: Union[float, Tensor], inplace: bool = False) -> Tensor:
    """
    Euclidean projection of each row of a batch onto the L1 ball of radius ε.

    min ||x - u||_2 s.t. ||u||_1 <= eps

    Rows already inside the ball are left untouched. The other rows are projected by projecting their absolute
    values onto the simplex of size ε and then restoring the original signs.

    Inspired by the corresponding numpy version by Adrien Gaidon.
    Adapted from Tony Duan's implementation https://gist.github.com/tonyduan/1329998205d88c566588e57e3e2c0c55

    Parameters
    ----------
    x: Tensor
        Batch of tensors to project, one vector per row.
    ε: float or Tensor
        Radius of the L1 ball: a single value for the whole batch, or one value per row.
    inplace : bool
        If True, projected rows are written into ``x``.

    Returns
    -------
    projected_x: Tensor
        Batch of projected tensors with the same shape as x. When no row needs projecting, ``x`` itself is returned.

    Notes
    -----
    The complexity of this algorithm is in O(dlogd) as it involves sorting x.

    References
    ----------
    [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions
        John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra.
        International Conference on Machine Learning (ICML 2008)
    """
    outside_ball = x.norm(p=1, dim=1) > ε
    if not outside_ball.any():
        return x

    rows = x[outside_ball]
    radius = ε[outside_ball] if isinstance(ε, Tensor) else rows.new_full((1,), ε)
    result = x if inplace else x.clone()
    magnitudes = simplex_projection(rows.abs(), ε=radius, inplace=True)
    result[outside_ball] = magnitudes.copysign_(rows)
    return result
101 |
--------------------------------------------------------------------------------
/adv_lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch import nn, Tensor
6 |
7 |
class ForwardCounter:
    """Forward pre-hook that accumulates the number of samples passed to a module.

    Each call adds the length of the first positional input. Use ``reset`` to start counting again.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_samples_called = 0

    def __call__(self, module, input) -> None:
        first_input = input[0]
        self.num_samples_called += len(first_input)
17 |
18 |
class BackwardCounter:
    """Backward hook that accumulates the number of samples whose gradients flow through a module.

    Each call adds the length of the first output gradient. Use ``reset`` to start counting again.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_samples_called = 0

    def __call__(self, module, grad_input, grad_output) -> None:
        first_grad = grad_output[0]
        self.num_samples_called += len(first_grad)
28 |
29 |
class ImageNormalizer(nn.Module):
    """Module computing (input - mean) / std with per-channel statistics for 3-channel images.

    ``mean`` and ``std`` are stored as buffers of shape (1, 3, 1, 1), so they follow the module across devices.
    """

    def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None:
        super().__init__()
        for name, stats in (('mean', mean), ('std', std)):
            self.register_buffer(name, torch.as_tensor(stats).view(1, 3, 1, 1))

    def forward(self, input: Tensor) -> Tensor:
        centered = input - self.mean
        return centered / self.std
39 |
40 |
def normalize_model(model: nn.Module, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> nn.Module:
    """Wrap ``model`` in an nn.Sequential whose first stage ('normalize') standardises the inputs.

    The wrapped model is kept under the name 'model'.
    """
    stages = OrderedDict(normalize=ImageNormalizer(mean, std), model=model)
    return nn.Sequential(stages)
47 |
48 |
def requires_grad_(model: nn.Module, requires_grad: bool) -> None:
    """Set ``requires_grad`` on every parameter of ``model``, in place."""
    parameters = list(model.parameters())
    for parameter in parameters:
        parameter.requires_grad_(requires_grad)
52 |
53 |
def predict_inputs(model: nn.Module, inputs: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Run ``model`` on ``inputs`` and return (logits, softmax probabilities, predicted classes) over dimension 1."""
    logits = model(inputs)
    return logits, logits.softmax(1), logits.argmax(1)
59 |
--------------------------------------------------------------------------------
/adv_lib/utils/visdom_logger.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from enum import Enum
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import torch
6 | import visdom
7 | from torch import Tensor
8 |
9 |
class ChartTypes(Enum):
    """Kind of chart shown in a visdom window."""

    # The trailing comma after `1` made this member's value the tuple (1,) instead of the integer 1.
    line = 1
    image = 2
13 |
14 |
class ChartData:
    """Mutable per-window record used by ``VisdomLogger``."""

    def __init__(self):
        # visdom window handle; stays None until something is plotted in this window
        self.window = None
        # ChartTypes member describing the window once it has been plotted
        self.type = None
        # history of flushed line coordinates, one tensor per flush
        self.x_list, self.y_list = [], []
        # last images shown in an image window
        self.other_data = None
        # pending keyword arguments for the next `Visdom.line` call
        self.to_plot = {}
23 |
24 |
class VisdomLogger:
    """Helper around a ``visdom.Visdom`` client for plotting lines and images.

    Every window is tracked in a ``ChartData`` record stored in ``self.windows``, created on first
    access. Line windows are keyed by their ``'$'``-joined line names; image windows by their name.
    The plotted data is also kept locally so it can be written to disk with :meth:`save`.

    Args:
        port: Port of the visdom server to connect to.
    """

    def __init__(self, port: int):
        self.vis = visdom.Visdom(port=port)
        self.windows = defaultdict(lambda: ChartData())

    @staticmethod
    def as_unsqueezed_tensor(data: Union[float, List[float], Tensor]) -> Tensor:
        """Return ``data`` as a detached tensor with at least one dimension (0-d becomes shape ``(1,)``)."""
        data = torch.as_tensor(data).detach()
        return data.unsqueeze(0) if data.ndim == 0 else data

    @staticmethod
    def _unnormalize_images(images: Tensor, mean_std: Tuple[List[float], List[float]]) -> Tensor:
        """Undo a per-channel ``(x - mean) / std`` normalization, i.e. return ``images * std + mean``.

        ``mean_std`` is ``(mean, std)`` with one value per channel. Both are created on the images' device
        and reshaped to ``(C, 1, 1)``, so they broadcast over the channel dimension of ``(C, H, W)`` or
        ``(N, C, H, W)`` images rather than over the width.
        """
        mean = torch.as_tensor(mean_std[0], device=images.device).view(-1, 1, 1)
        std = torch.as_tensor(mean_std[1], device=images.device).view(-1, 1, 1)
        return images * std + mean

    def accumulate_line(self, names: Union[str, List[str]], x: Union[float, Tensor],
                        y: Union[float, Tensor, List[Tensor]], title: str = '', **kwargs) -> None:
        """Queue points for a line window without sending them to visdom (see :meth:`update_lines`).

        Args:
            names: Name of a single line, or names of several lines sharing the window (used as legend).
            x: x-coordinate(s) of the new point(s); shared by all lines when 1-D.
            y: y-value(s). A list holds one entry per line; its entries are stacked along dimension 1,
                giving a ``(num_points, num_lines)`` tensor.
            title: Window title. Only used when the window's pending batch is created.
            **kwargs: Extra visdom ``opts`` entries (same caveat as ``title``).
        """
        if isinstance(names, str):
            names = [names]
        data = self.windows['$'.join(names)]
        # the first plot creates the window; later plots append to it
        update = None if data.window is None else 'append'

        if isinstance(y, (int, float)):
            Y = torch.tensor([y])
        elif isinstance(y, list):
            Y = torch.stack(list(map(self.as_unsqueezed_tensor, y)), 1)
        else:  # tensors, and any other tensor-like input (previously left Y unbound)
            Y = self.as_unsqueezed_tensor(y)

        if isinstance(x, (int, float)):
            X = torch.tensor([x])
        else:  # tensors, and any other tensor-like input (previously left X unbound)
            X = self.as_unsqueezed_tensor(x)

        # A 1-D X is kept as-is even when Y is 2-D: Visdom.line accepts an N-sized X shared by all
        # lines. (The former `X.expand(len(X), Y.shape[1])` discarded its result and raised whenever
        # len(X) > 1 and len(X) != Y.shape[1].)

        if not data.to_plot:
            data.to_plot = {'X': X, 'Y': Y, 'win': data.window, 'update': update,
                            'opts': {'legend': names, 'title': title, **kwargs}}
        else:
            data.to_plot['X'] = torch.cat((data.to_plot['X'], X), 0)
            data.to_plot['Y'] = torch.cat((data.to_plot['Y'], Y), 0)

    def update_lines(self) -> None:
        """Send every pending line batch to visdom and record it in the window's history."""
        for data in self.windows.values():
            if data.to_plot:
                win = self.vis.line(**data.to_plot)

                data.x_list.append(data.to_plot['X'])
                data.y_list.append(data.to_plot['Y'])

                # Update the window
                data.window = win
                data.type = ChartTypes.line

                data.to_plot = {}

    def line(self, names: Union[str, List[str]], x: Union[float, Tensor], y: Union[float, Tensor, List[Tensor]],
             title: str = '', **kwargs) -> None:
        """Queue points with :meth:`accumulate_line` and immediately flush all pending lines."""
        self.accumulate_line(names=names, x=x, y=y, title=title, **kwargs)
        self.update_lines()

    def images(self, name: str, images: Tensor, mean_std: Optional[Tuple[List[float], List[float]]] = None,
               title: str = '') -> None:
        """Show ``images`` in the window ``name``, optionally un-normalizing them first.

        Args:
            name: Window name (also used as legend).
            images: Batch of images to display.
            mean_std: Optional ``(mean, std)`` per-channel statistics; when given, images are mapped back
                with ``images * std + mean`` before display.
            title: Window title.
        """
        data = self.windows[name]

        if mean_std is not None:
            images = self._unnormalize_images(images, mean_std)

        win = self.vis.images(images, win=data.window, opts={'legend': [name], 'title': title})

        # Update the window
        data.window = win
        data.other_data = images
        data.type = ChartTypes.image

    def reset_windows(self):
        """Forget every tracked window (their data will no longer be saved)."""
        self.windows.clear()

    def save(self, filename):
        """Write the plotted data of every window to ``filename`` with ``torch.save``.

        Line windows are stored as ``(ChartTypes.line, X, Y)`` with all flushed batches concatenated along
        dimension 0. Image windows are stored as ``(ChartTypes.image, images)``. Tensors are moved to CPU
        first; windows that were never plotted are skipped.
        """
        to_save = {}
        for name, data in self.windows.items():
            chart_type = data.type
            if chart_type == ChartTypes.line:
                to_save[name] = (chart_type, torch.cat(data.x_list, dim=0).cpu(), torch.cat(data.y_list, dim=0).cpu())
            elif chart_type == ChartTypes.image:
                to_save[name] = (chart_type, data.other_data.cpu())

        torch.save(to_save, filename)
110 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "adv-lib"
7 | dynamic = ["version"]
8 | authors = [
9 | { name = "Jerome Rony", email = "jerome.rony@gmail.com" },
10 | ]
11 | description = "Library of various adversarial attack resources in PyTorch"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | license = { file = "LICENSE" }
15 | classifiers = [
16 | "Programming Language :: Python :: 3",
17 | "Development Status :: 3 - Alpha",
18 | "Intended Audience :: Developers",
19 | "Intended Audience :: Science/Research",
20 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
21 | ]
22 | dependencies = [
23 | "torch>=1.8.0",
24 | "torchvision>=0.9.0",
25 | "tqdm>=4.48.0",
26 | "visdom>=0.1.8",
27 | ]
28 |
29 | [tool.setuptools.packages.find]
30 | include = ["adv_lib*"]
31 | namespaces = false
32 |
33 | [tool.setuptools.dynamic]
34 | version = { attr = "adv_lib.__version__" }
35 |
36 | [project.optional-dependencies]
37 | test = ["scikit-image", "pytest"]
38 |
39 | [project.urls]
40 | Repository = "https://github.com/jeromerony/adversarial-library.git"
--------------------------------------------------------------------------------
/tests/distances/test_color_difference.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.autograd import grad
4 |
5 | from adv_lib.distances.color_difference import ciede2000_color_difference
6 |
# test data from http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf
# Each row holds two CIELAB colors followed by the expected CIEDE2000 difference between them:
# columns 0-2 are the first color (L*, a*, b*), columns 3-5 the second color, column 6 the reference ΔE_00.
TEST_DATA = [
    # L*1 a*1 b*1 L*2 a*2 b*2 ΔE_00
    [50.0000, 2.677200, -79.7751, 50.0000, 0.000000, -82.7485, 2.0425],
    [50.0000, 3.157100, -77.2803, 50.0000, 0.000000, -82.7485, 2.8615],
    [50.0000, 2.836100, -74.0200, 50.0000, 0.000000, -82.7485, 3.4412],
    [50.0000, -1.38020, -84.2814, 50.0000, 0.000000, -82.7485, 1.0000],
    [50.0000, -1.18480, -84.8006, 50.0000, 0.000000, -82.7485, 1.0000],
    [50.0000, -0.90090, -85.5211, 50.0000, 0.000000, -82.7485, 1.0000],
    [50.0000, 0.000000, 0.000000, 50.0000, -1.00000, 2.000000, 2.3669],
    [50.0000, -1.00000, 2.000000, 50.0000, 0.000000, 0.000000, 2.3669],
    [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.000900, 7.1792],
    [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001000, 7.1792],
    [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001100, 7.2195],
    [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001200, 7.2195],
    [50.0000, -0.00100, 2.490000, 50.0000, 0.000900, -2.49000, 4.8045],
    [50.0000, -0.00100, 2.490000, 50.0000, 0.001000, -2.49000, 4.8045],
    [50.0000, -0.00100, 2.490000, 50.0000, 0.001100, -2.49000, 4.7461],
    [50.0000, 2.500000, 0.000000, 50.0000, 0.000000, -2.50000, 4.3065],
    [50.0000, 2.500000, 0.000000, 73.0000, 25.00000, -18.0000, 27.1492],
    [50.0000, 2.500000, 0.000000, 61.0000, -5.00000, 29.00000, 22.8977],
    [50.0000, 2.500000, 0.000000, 56.0000, -27.0000, -3.00000, 31.9030],
    [50.0000, 2.500000, 0.000000, 58.0000, 24.00000, 15.00000, 19.4535],
    [50.0000, 2.500000, 0.000000, 50.0000, 3.173600, 0.585400, 1.0000],
    [50.0000, 2.500000, 0.000000, 50.0000, 3.297200, 0.000000, 1.0000],
    [50.0000, 2.500000, 0.000000, 50.0000, 1.863400, 0.575700, 1.0000],
    [50.0000, 2.500000, 0.000000, 50.0000, 3.259200, 0.335000, 1.0000],
    [60.2574, -34.0099, 36.26770, 60.4626, -34.1751, 39.43870, 1.2644],
    [63.0109, -31.0961, -5.86630, 62.8187, -29.7946, -4.08640, 1.2630],
    [61.2901, 3.719600, -5.39010, 61.4292, 2.248000, -4.96200, 1.8731],
    [35.0831, -44.1164, 3.793300, 35.0232, -40.0716, 1.590100, 1.8645],
    [22.7233, 20.09040, -46.6940, 23.0331, 14.97300, -42.5619, 2.0373],
    [36.4612, 47.85800, 18.38520, 36.2715, 50.50650, 21.22310, 1.4146],
    [90.8027, -2.08310, 1.441000, 91.1528, -1.64350, 0.044700, 1.4441],
    [90.9257, -0.54060, -0.92080, 88.6381, -0.89850, -0.72390, 1.5381],
    [6.77470, -0.29080, -2.42470, 5.87140, -0.09850, -2.22860, 0.6377],
    [2.07760, 0.079500, -1.13500, 0.90330, -0.06360, -0.55140, 0.9082],
]
45 |
46 |
@pytest.mark.parametrize('dtype', [torch.float32, torch.float64])
def test_ciede2000_color_difference_value(dtype: torch.dtype) -> None:
    """ΔE_00 is exactly 0 for identical colors, symmetric, and matches the reference table."""
    table = torch.tensor(TEST_DATA, dtype=dtype)
    lab_a = table[:, 0:3].clone()
    lab_b = table[:, 3:6].clone()
    expected = table[:, -1].clone()

    forward = ciede2000_color_difference(lab_a, lab_b)
    backward = ciede2000_color_difference(lab_b, lab_a)
    self_distance = ciede2000_color_difference(lab_a, lab_a)

    assert torch.equal(self_distance, torch.zeros_like(self_distance))  # identical inputs give exactly 0
    assert torch.equal(forward, backward)  # swapping the arguments gives the same result
    assert torch.allclose(forward, expected, rtol=1e-4, atol=1e-5)  # agrees with the reference values
61 |
62 |
@pytest.mark.parametrize('dtype', [torch.float32, torch.float64])
def test_ciede2000_color_difference_grad(dtype: torch.dtype) -> None:
    """Gradients of ΔE_00 w.r.t. the first color contain no NaN (and are exactly 0 for identical inputs)."""
    table = torch.tensor(TEST_DATA, dtype=dtype)
    lab_a = table[:, 0:3].clone().requires_grad_(True)
    lab_b = table[:, 3:6].clone()

    def grad_wrt_lab_a(other):
        distance = ciede2000_color_difference(lab_a, other, ε=1e-12)
        return grad(distance.sum(), lab_a, only_inputs=True)[0]

    # identical inputs: the gradient must be defined (no NaN) and exactly zero
    same_grad = grad_wrt_lab_a(lab_a)
    assert not torch.isnan(same_grad).any()
    assert torch.equal(same_grad, torch.zeros_like(same_grad))

    # different inputs: the gradient must still be free of NaN
    diff_grad = grad_wrt_lab_a(lab_b)
    assert not torch.isnan(diff_grad).any()
82 |
--------------------------------------------------------------------------------
/tests/distances/test_structural_similarity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from skimage import data
5 | from skimage.metrics import structural_similarity
6 |
7 | from adv_lib.distances.structural_similarity import compute_ssim
8 |
9 |
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_compute_ssim_gray(dtype: np.dtype) -> None:
    """compute_ssim on noisy grayscale images agrees with scikit-image's Gaussian-weighted SSIM."""
    reference_np = data.camera().astype(dtype) / 255
    reference = torch.as_tensor(reference_np)

    for noise_std in (0, 0.01, 0.03, 0.1, 0.3):
        noisy = (reference + torch.randn_like(reference) * noise_std).clamp(0, 1)

        expected = structural_similarity(noisy.numpy(), reference_np, win_size=11, sigma=1.5,
                                         use_sample_covariance=False, gaussian_weights=True, data_range=1)
        # add batch and channel dimensions: (H, W) -> (1, 1, H, W)
        actual = compute_ssim(noisy[None, None], reference[None, None])
        assert abs(expected - actual.item()) < 2e-5
28 |
29 |
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_compute_ssim_color(dtype: np.dtype) -> None:
    """compute_ssim on noisy RGB images agrees with scikit-image's per-channel SSIM."""
    # test for color images; the array is channel-last (H, W, C), hence the permute below
    np_color_img = data.astronaut().astype(dtype) / 255
    pt_color_img = torch.as_tensor(np_color_img)

    for sigma in [0, 0.01, 0.03, 0.1, 0.3]:
        noise = torch.randn_like(pt_color_img) * sigma

        noisy_pt_color_img = (pt_color_img + noise).clamp(0, 1)
        noisy_np_color_img = noisy_pt_color_img.numpy()

        # channel_axis=-1 replaces `multichannel=True`, which was deprecated in scikit-image 0.19 and removed
        # in later releases (where the RGB image would otherwise be treated as a 3-D volume)
        skimage_ssim = structural_similarity(noisy_np_color_img, np_color_img, win_size=11, sigma=1.5,
                                             channel_axis=-1, use_sample_covariance=False, gaussian_weights=True,
                                             data_range=1)
        # (H, W, C) -> (1, C, H, W)
        adv_lib_ssim = compute_ssim(noisy_pt_color_img.permute(2, 0, 1).unsqueeze(0),
                                    pt_color_img.permute(2, 0, 1).unsqueeze(0))

        abs_diff = abs(skimage_ssim - adv_lib_ssim.item())
        assert abs_diff < 1e-5
50 |
--------------------------------------------------------------------------------
/tests/utils/lagrangian_penalties/test_penalty_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.autograd import grad, gradcheck
4 |
5 | from adv_lib.utils.lagrangian_penalties import all_penalties
6 |
7 |
@pytest.mark.parametrize('penalty', list(all_penalties.values()))
def test_grad(penalty) -> None:
    """Analytical gradients of each penalty match finite-difference estimates (``gradcheck``)."""
    size, dtype = 512, torch.double
    y = torch.randn(size, dtype=dtype, requires_grad=True)
    # ρ and μ are kept strictly positive
    ρ = torch.randn(size, dtype=dtype).abs_().clamp_min_(1e-3).requires_grad_(True)
    μ = torch.randn(size, dtype=dtype).abs_().clamp_min_(1e-6).requires_grad_(True)

    assert gradcheck(penalty, inputs=(y, ρ, μ))
18 |
19 |
@pytest.mark.parametrize('penalty,value', [(all_penalties['P2'], 1), (all_penalties['P3'], 1)])
@pytest.mark.parametrize('dtype', [torch.float32, torch.float64])
def test_nan_grad(penalty, value, dtype) -> None:
    """The gradient of ``penalty`` w.r.t. ``y`` at y = ρ = μ = ``value`` must not contain NaN.

    The chosen points are presumably where a naive implementation would produce NaN gradients.
    """
    y = torch.full((1,), value, dtype=dtype, requires_grad=True)
    ρ = torch.full((1,), value, dtype=dtype)
    μ = torch.full((1,), value, dtype=dtype)

    out = penalty(y, ρ, μ)
    g = grad(out, y, only_inputs=True)[0]

    assert not torch.isnan(g).any()  # check nan in gradients of penalty
31 |
--------------------------------------------------------------------------------
/tests/utils/lagrangian_penalties/test_univariate_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.autograd import grad, gradcheck
4 |
5 | from adv_lib.utils.lagrangian_penalties import univariate_functions
6 |
7 |
@pytest.mark.parametrize('univariate', univariate_functions.__all__)
def test_grad(univariate) -> None:
    """Each univariate function's analytical gradient matches finite-difference estimates (``gradcheck``)."""
    t = torch.randn(512, dtype=torch.double, requires_grad=True)
    function = getattr(univariate_functions, univariate)()
    assert gradcheck(function, inputs=t)
13 |
14 |
@pytest.mark.parametrize('univariate,value', [('LogExp', 1), ('LogQuad_1', 1), ('HyperExp', 1), ('HyperQuad', 1),
                                              ('LogBarrierQuad', 0), ('HyperBarrierQuad', 0), ('HyperLogQuad', 0),
                                              ('HyperLogQuad', 1)])
@pytest.mark.parametrize('dtype', [torch.float32, torch.float64])
def test_nan_grad(univariate, value, dtype) -> None:
    """The gradient of each univariate function at ``t = value`` must not contain NaN.

    The listed points are presumably where the functions switch branches, which is where naive
    implementations tend to produce NaN gradients.
    """
    t = torch.full((1,), value, dtype=dtype, requires_grad=True)

    univariate_func = univariate_functions.__dict__[univariate]()
    out = univariate_func(t)
    g = grad(out, t, only_inputs=True)[0]

    assert not torch.isnan(g).any()  # check nan in gradients of penalty
27 |
--------------------------------------------------------------------------------