├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── adv_lib ├── __init__.py ├── attacks │ ├── __init__.py │ ├── augmented_lagrangian.py │ ├── auto_pgd.py │ ├── carlini_wagner │ │ ├── __init__.py │ │ ├── l2.py │ │ └── linf.py │ ├── decoupled_direction_norm.py │ ├── deepfool.py │ ├── fast_adaptive_boundary │ │ ├── __init__.py │ │ ├── fast_adaptive_boundary.py │ │ └── projections.py │ ├── fast_minimum_norm.py │ ├── perceptual_color_attacks │ │ ├── __init__.py │ │ ├── differential_color_functions.py │ │ └── perceptual_color_distance_al.py │ ├── primal_dual_gradient_descent.py │ ├── projected_gradient_descent.py │ ├── segmentation │ │ ├── __init__.py │ │ ├── alma_prox.py │ │ ├── asma.py │ │ ├── dense_adversary.py │ │ └── primal_dual_gradient_descent.py │ ├── sigma_zero.py │ ├── stochastic_sparse_attacks.py │ ├── structured_adversarial_attack.py │ ├── superdeepfool.py │ └── trust_region.py ├── distances │ ├── __init__.py │ ├── color_difference.py │ ├── lp_norms.py │ ├── lpips.py │ └── structural_similarity.py └── utils │ ├── __init__.py │ ├── attack_utils.py │ ├── color_conversions.py │ ├── image_selection.py │ ├── lagrangian_penalties │ ├── __init__.py │ ├── all_penalties.py │ ├── penalty_functions.py │ ├── scripts │ │ ├── plot_penalties.py │ │ └── plot_univariates.py │ └── univariate_functions.py │ ├── losses.py │ ├── projections.py │ ├── utils.py │ └── visdom_logger.py ├── pyproject.toml └── tests ├── distances ├── test_color_difference.py └── test_structural_similarity.py └── utils └── lagrangian_penalties ├── test_penalty_functions.py └── test_univariate_functions.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,linux,pycharm 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,pycharm 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### PyCharm ### 21 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 22 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 23 | 24 | # User-specific stuff 25 | .idea/**/workspace.xml 26 | .idea/**/tasks.xml 27 | .idea/**/usage.statistics.xml 28 | .idea/**/dictionaries 29 | .idea/**/shelf 30 | 31 | # Generated files 32 | .idea/**/contentModel.xml 33 | 34 | # Sensitive or high-churn files 35 | .idea/**/dataSources/ 36 | .idea/**/dataSources.ids 37 | .idea/**/dataSources.local.xml 38 | .idea/**/sqlDataSources.xml 39 | .idea/**/dynamic.xml 40 | .idea/**/uiDesigner.xml 41 | .idea/**/dbnavigator.xml 42 | 43 | # Gradle 44 | .idea/**/gradle.xml 45 | .idea/**/libraries 46 | 47 | # Gradle and Maven with auto-import 48 | # When using Gradle or Maven with auto-import, you should exclude module files, 49 | # since they will be recreated, and may cause churn. Uncomment if using 50 | # auto-import. 
51 | # .idea/artifacts 52 | # .idea/compiler.xml 53 | # .idea/jarRepositories.xml 54 | # .idea/modules.xml 55 | # .idea/*.iml 56 | # .idea/modules 57 | # *.iml 58 | # *.ipr 59 | 60 | # CMake 61 | cmake-build-*/ 62 | 63 | # Mongo Explorer plugin 64 | .idea/**/mongoSettings.xml 65 | 66 | # File-based project format 67 | *.iws 68 | 69 | # IntelliJ 70 | out/ 71 | 72 | # mpeltonen/sbt-idea plugin 73 | .idea_modules/ 74 | 75 | # JIRA plugin 76 | atlassian-ide-plugin.xml 77 | 78 | # Cursive Clojure plugin 79 | .idea/replstate.xml 80 | 81 | # Crashlytics plugin (for Android Studio and IntelliJ) 82 | com_crashlytics_export_strings.xml 83 | crashlytics.properties 84 | crashlytics-build.properties 85 | fabric.properties 86 | 87 | # Editor-based Rest Client 88 | .idea/httpRequests 89 | 90 | # Android studio 3.1+ serialized cache file 91 | .idea/caches/build_file_checksums.ser 92 | 93 | ### PyCharm Patch ### 94 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 95 | 96 | # *.iml 97 | # modules.xml 98 | # .idea/misc.xml 99 | # *.ipr 100 | 101 | # Sonarlint plugin 102 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 103 | .idea/**/sonarlint/ 104 | 105 | # SonarQube Plugin 106 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 107 | .idea/**/sonarIssues.xml 108 | 109 | # Markdown Navigator plugin 110 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 111 | .idea/**/markdown-navigator.xml 112 | .idea/**/markdown-navigator-enh.xml 113 | .idea/**/markdown-navigator/ 114 | 115 | # Cache file creation bug 116 | # See https://youtrack.jetbrains.com/issue/JBR-2257 117 | .idea/$CACHE_FILE$ 118 | 119 | # CodeStream plugin 120 | # https://plugins.jetbrains.com/plugin/12206-codestream 121 | .idea/codestream.xml 122 | 123 | ### Python ### 124 | # Byte-compiled / optimized / DLL files 125 | __pycache__/ 126 | *.py[cod] 127 | *$py.class 128 | 129 | # C extensions 130 | *.so 131 | 132 | # Distribution / packaging 133 | .Python 134 | build/ 135 | develop-eggs/ 136 | dist/ 137 | downloads/ 138 | eggs/ 139 | .eggs/ 140 | lib/ 141 | lib64/ 142 | parts/ 143 | sdist/ 144 | var/ 145 | wheels/ 146 | pip-wheel-metadata/ 147 | share/python-wheels/ 148 | *.egg-info/ 149 | .installed.cfg 150 | *.egg 151 | MANIFEST 152 | 153 | # PyInstaller 154 | # Usually these files are written by a python script from a template 155 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 156 | *.manifest 157 | *.spec 158 | 159 | # Installer logs 160 | pip-log.txt 161 | pip-delete-this-directory.txt 162 | 163 | # Unit tests / coverage reports 164 | htmlcov/ 165 | .tox/ 166 | .nox/ 167 | .coverage 168 | .coverage.* 169 | .cache 170 | nosetests.xml 171 | coverage.xml 172 | *.cover 173 | *.py,cover 174 | .hypothesis/ 175 | .pytest_cache/ 176 | pytestdebug.log 177 | 178 | # Translations 179 | *.mo 180 | *.pot 181 | 182 | # Django stuff: 183 | *.log 184 | local_settings.py 185 | db.sqlite3 186 | db.sqlite3-journal 187 | 188 | # Flask stuff: 189 | instance/ 190 | .webassets-cache 191 | 192 | # Scrapy stuff: 193 | .scrapy 194 | 195 | # Sphinx documentation 196 | docs/_build/ 197 | doc/_build/ 198 | 199 | # PyBuilder 200 | target/ 201 | 202 | # Jupyter Notebook 203 | .ipynb_checkpoints 204 | 205 | # IPython 206 | profile_default/ 207 | ipython_config.py 208 | 209 | # pyenv 210 | .python-version 211 | 212 | # pipenv 213 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
214 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 215 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 216 | # install all needed dependencies. 217 | #Pipfile.lock 218 | 219 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 220 | __pypackages__/ 221 | 222 | # Celery stuff 223 | celerybeat-schedule 224 | celerybeat.pid 225 | 226 | # SageMath parsed files 227 | *.sage.py 228 | 229 | # Environments 230 | .env 231 | .venv 232 | env/ 233 | venv/ 234 | ENV/ 235 | env.bak/ 236 | venv.bak/ 237 | 238 | # Spyder project settings 239 | .spyderproject 240 | .spyproject 241 | 242 | # Rope project settings 243 | .ropeproject 244 | 245 | # mkdocs documentation 246 | /site 247 | 248 | # mypy 249 | .mypy_cache/ 250 | .dmypy.json 251 | dmypy.json 252 | 253 | # Pyre type checker 254 | .pyre/ 255 | 256 | # pytype static type analyzer 257 | .pytype/ 258 | 259 | # End of https://www.toptal.com/developers/gitignore/api/python,linux,pycharm -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: Adversarial Library 6 | message: >- 7 | If you use this software, please cite it using the 8 | metadata from this file. 9 | type: software 10 | authors: 11 | - given-names: Jérôme 12 | family-names: Rony 13 | affiliation: ÉTS Montréal 14 | orcid: 'https://orcid.org/0000-0002-6359-6142' 15 | - given-names: Ismail 16 | family-names: Ben Ayed 17 | affiliation: ÉTS Montréal 18 | orcid: 'https://orcid.org/0000-0002-9668-8027' 19 | doi: 10.5281/zenodo.5815063 20 | repository-code: 'https://github.com/jeromerony/adversarial-library' 21 | keywords: 22 | - machine learning 23 | - adversarial examples 24 | - pytorch 25 | license: BSD-3-Clause 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Jérôme Rony 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![DOI](https://zenodo.org/badge/315504148.svg)](https://zenodo.org/badge/latestdoi/315504148) 3 | 4 | # Adversarial Library 5 | 6 | This library contains various resources related to adversarial attacks implemented in PyTorch. It is aimed at researchers looking for implementations of state-of-the-art attacks. 7 | 8 | The code was written to maximize efficiency (_e.g._ by preferring low-level functions from PyTorch) while retaining simplicity (_e.g._ by avoiding abstractions). As a consequence, most of the library, and especially the attacks, is implemented using **pure functions** (whenever possible). 9 | 10 | While focused on attacks, this library also provides several utilities related to adversarial attacks: distances (SSIM, CIEDE2000, LPIPS), a visdom callback, projections, losses and helper functions. Most notably, the function `run_attack` from `utils/attack_utils.py` performs an attack on a model given the inputs and labels, with a fixed batch size, and reports complexity-related metrics (run-time and forward/backward propagations). 11 | 12 | ### Dependencies 13 | 14 | The goal of this library is to stay up-to-date with newer versions of PyTorch, so the dependencies are expected to be updated regularly (possibly resulting in breaking changes). 15 | 16 | - pytorch>=1.8.0 17 | - torchvision>=0.9.0 18 | - tqdm>=4.48.0 19 | - visdom>=0.1.8 20 | 21 | ### Installation 22 | 23 | You can either install using: 24 | 25 | ```pip install git+https://github.com/jeromerony/adversarial-library``` 26 | 27 | Or you can clone the repo and run: 28 | 29 | ```pip install .``` 30 | 31 | Alternatively, you can install (after cloning) the library in editable mode: 32 | 33 | ```pip install -e .``` 34 | 35 | ### Usage 36 | Attacks are implemented as functions, so they can be called directly by providing the model, samples and labels (possibly with optional arguments): 37 | ```python 38 | from adv_lib.attacks import ddn 39 | adv_samples = ddn(model=model, inputs=inputs, labels=labels, steps=300) 40 | ``` 41 | 42 | Classification attacks all expect the following arguments: 43 | - `model`: the model that produces logits (pre-softmax activations) with inputs in $[0, 1]$ 44 | - `inputs`: the samples to attack in $[0, 1]$ 45 | - `labels`: either the ground-truth labels for the samples or the targets 46 | - `targeted`: flag indicating whether the attack is targeted -- defaults to `False` 47 | 48 | Additionally, many attacks have an optional `callback` argument which accepts an `adv_lib.utils.visdom_logger.VisdomLogger` to plot data to a visdom server for monitoring purposes.
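For example, running an attack in targeted mode only requires passing the target classes as `labels` and setting `targeted=True`. A minimal sketch, assuming `model`, `inputs` and a tensor of target classes `target_labels` are already defined:

```python
from adv_lib.attacks import ddn

# Targeted attack: `labels` now contains the classes to reach.
adv_samples = ddn(model=model, inputs=inputs, labels=target_labels, targeted=True, steps=300)
```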
49 | 50 | For a more detailed example on how to use this library, you can look at this repo: https://github.com/jeromerony/augmented_lagrangian_adversarial_attacks 51 | 52 | ## Contents 53 | 54 | ### Attacks 55 | 56 | #### Classification 57 | 58 | Currently the following classification attacks are implemented in the `adv_lib.attacks` module: 59 | 60 | | Name | Knowledge | Type | Distance(s) | ArXiv Link | 61 | |-----------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------------------------------------------------------------| 62 | | DeepFool (DF) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1511.04599](https://arxiv.org/abs/1511.04599) | 63 | | Carlini and Wagner (C&W) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1608.04644](https://arxiv.org/abs/1608.04644) | 64 | | Projected Gradient Descent (PGD) | White-box | Budget | $\ell_\infty$ | [1706.06083](https://arxiv.org/abs/1706.06083) | 65 | | Structured Adversarial Attack (StrAttack) | White-box | Minimal | $\ell_2$ + group-sparsity | [1808.01664](https://arxiv.org/abs/1808.01664) | 66 | | **Decoupled Direction and Norm (DDN)** | White-box | Minimal | $\ell_2$ | [1811.09600](https://arxiv.org/abs/1811.09600) | 67 | | Trust Region (TR) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1812.06371](https://arxiv.org/abs/1812.06371) | 68 | | Fast Adaptive Boundary (FAB) | White-box | Minimal | $\ell_1$, $\ell_2$, $\ell_\infty$ | [1907.02044](https://arxiv.org/abs/1907.02044) | 69 | | Perceptual Color distance Alternating Loss (PerC-AL) | White-box | Minimal | CIEDE2000 | [1911.02466](https://arxiv.org/abs/1911.02466) | 70 | | Auto-PGD (APGD) | White-box | Budget | $\ell_1$, $\ell_2$, $\ell_\infty$ | [2003.01690](https://arxiv.org/abs/2003.01690)
[2103.01208](https://arxiv.org/abs/2103.01208) | 71 | | **Augmented Lagrangian Method for Adversarial (ALMA)** | White-box | Minimal | $\ell_1$, $\ell_2$, SSIM, CIEDE2000, LPIPS, ... | [2011.11857](https://arxiv.org/abs/2011.11857) | 72 | | Folded Gaussian Attack (FGA)
Voting Folded Gaussian Attack (VFGA) | White-box | Minimal | $\ell_0$ | [2011.12423](https://arxiv.org/abs/2011.12423) | 73 | | Fast Minimum-Norm (FMN) | White-box | Minimal | $\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2102.12827](https://arxiv.org/abs/2102.12827) | 74 | | Primal-Dual Gradient Descent (PDGD)
Primal-Dual Proximal Gradient Descent (PDPGD) | White-box | Minimal | $\ell_2$
$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) | 75 | | SuperDeepFool (SDF) | White-box | Minimal | $\ell_2$ | [2303.12481](https://arxiv.org/abs/2303.12481) | 76 | | σ-zero | White-box | Minimal | $\ell_0$ | [2402.01879](https://arxiv.org/abs/2402.01879) | 77 | 78 | **Bold** means that this repository contains the official implementation. 79 | 80 | _Type_ refers to the goal of the attack: 81 | - _Minimal_ attacks aim to find the smallest adversarial perturbation w.r.t. a given distance; 82 | - _Budget_ attacks aim to find an adversarial perturbation within a distance budget (and often to maximize a loss as well). 83 | 84 | #### Segmentation 85 | 86 | The library also includes segmentation attacks in the `adv_lib.attacks.segmentation` module. These require the following arguments: 87 | - `model`: the model that produces logits (pre-softmax activations) with inputs in $[0, 1]$ 88 | - `inputs`: the images to attack in $[0, 1]$. Shape: $b\times c\times h\times w$ with $b$ the batch size, $c$ the number of color channels and $h$ and $w$ the height and width of the images. 89 | - `labels`: either the ground-truth labels for the samples or the targets. Shape: $b\times h\times w$. 90 | - `masks`: binary mask indicating which pixels to attack, to account for unlabeled pixels (e.g. void in Pascal VOC). Shape: $b\times h\times w$. 91 | - `targeted`: flag indicating whether the attack is targeted -- defaults to `False` 92 | - `adv_threshold`: fraction of pixels that must be adversarial for the attack to be considered successful -- defaults to `0.99` 93 | 94 | The following segmentation attacks are implemented: 95 | 96 | | Name | Knowledge | Type | Distance(s) | ArXiv Link | 97 | |-------------------------------------------------------------------------------------------|-----------|---------|-----------------------------------------------------------|------------------------------------------------| 98 | | Dense Adversary Generation (DAG) | White-box | Minimal | $\ell_2$, $\ell_\infty$ | [1703.08603](https://arxiv.org/abs/1703.08603) | 99 | | Adaptive Segmentation Mask Attack (ASMA) | White-box | Minimal | $\ell_2$ | [1907.13124](https://arxiv.org/abs/1907.13124) | 100 | | _Primal-Dual Gradient Descent (PDGD)
Primal-Dual Proximal Gradient Descent (PDPGD)_ | White-box | Minimal | $\ell_2$
$\ell_0$, $\ell_1$, $\ell_2$, $\ell_\infty$ | [2106.01538](https://arxiv.org/abs/2106.01538) | 101 | | **ALMA prox** | White-box | Minimal | $\ell_\infty$ | [2206.07179](https://arxiv.org/abs/2206.07179) | 102 | 103 | _Italic_ indicates that the attack is unofficially adapted from the classification variant. 104 | 105 | ### Distances 106 | 107 | The following distances are available in the utils `adv_lib.distances` module: 108 | - Lp-norms 109 | - SSIM https://ece.uwaterloo.ca/~z70wang/research/ssim/ 110 | - MS-SSIM https://ece.uwaterloo.ca/~z70wang/publications/msssim.html 111 | - CIEDE2000 color difference http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf 112 | - LPIPS https://arxiv.org/abs/1801.03924 113 | 114 | ## Contributions 115 | 116 | Suggestions and contributions are welcome :) 117 | 118 | ## Citation 119 | 120 | If this library has been useful for your research, you can cite it using the "Cite this repository" button in the "About" section. 121 | -------------------------------------------------------------------------------- /adv_lib/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.3" 2 | -------------------------------------------------------------------------------- /adv_lib/attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmented_lagrangian import alma 2 | from .auto_pgd import apgd, apgd_targeted 3 | from .carlini_wagner import carlini_wagner_l2, carlini_wagner_linf 4 | from .decoupled_direction_norm import ddn 5 | from .deepfool import df 6 | from .fast_adaptive_boundary import fab 7 | from .fast_minimum_norm import fmn 8 | from .perceptual_color_attacks import perc_al 9 | from .primal_dual_gradient_descent import pdgd, pdpgd 10 | from .projected_gradient_descent import pgd_linf 11 | from .sigma_zero import sigma_zero 12 | from .stochastic_sparse_attacks import fga, vfga 13 | from .structured_adversarial_attack import str_attack 14 | from .superdeepfool import sdf 15 | from .trust_region import tr 16 | -------------------------------------------------------------------------------- /adv_lib/attacks/augmented_lagrangian.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.autograd import grad 7 | 8 | from adv_lib.distances.color_difference import ciede2000_loss 9 | from adv_lib.distances.lp_norms import l1_distances, l2_distances 10 | from adv_lib.distances.lpips import LPIPS 11 | from adv_lib.distances.structural_similarity import ms_ssim_loss, ssim_loss 12 | from adv_lib.utils.lagrangian_penalties import all_penalties 13 | from adv_lib.utils.losses import difference_of_logits_ratio 14 | from adv_lib.utils.visdom_logger import VisdomLogger 15 | 16 | 17 | def init_lr_finder(inputs: Tensor, grad: Tensor, distance_function: Callable, target_distance: float) -> Tensor: 18 | """ 19 | Performs a line search and a binary search to find the learning rate η for each sample such that: 20 | distance_function(inputs - η * grad) = target_distance. 21 | 22 | Parameters 23 | ---------- 24 | inputs : Tensor 25 | Reference to compute the distance from. 26 | grad : Tensor 27 | Direction to step in. 28 | distance_function : Callable 29 | target_distance : float 30 | Target distance that inputs - η * grad should reach. 
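The learning rate is found by doubling η until the resulting distance exceeds this target, followed by 20 bisection steps.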
31 | 32 | Returns 33 | ------- 34 | η : Tensor 35 | Learning rate for each sample. 36 | 37 | """ 38 | batch_size = len(inputs) 39 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 40 | lr = torch.ones(batch_size, device=inputs.device) 41 | lower = torch.zeros_like(lr) 42 | 43 | found_upper = distance_function((inputs - grad).clamp_(min=0, max=1)) > target_distance 44 | while (~found_upper).any(): 45 | lower = torch.where(found_upper, lower, lr) 46 | lr = torch.where(found_upper, lr, lr * 2) 47 | found_upper = distance_function((inputs - batch_view(lr) * grad).clamp_(min=0, max=1)) > target_distance 48 | 49 | for i in range(20): 50 | new_lr = (lower + lr) / 2 51 | larger = distance_function((inputs - batch_view(new_lr) * grad).clamp_(min=0, max=1)) > target_distance 52 | lower, lr = torch.where(larger, lower, new_lr), torch.where(larger, new_lr, lr) 53 | 54 | return (lr + lower) / 2 55 | 56 | 57 | _distances = { 58 | 'ssim': ssim_loss, 59 | 'msssim': ms_ssim_loss, 60 | 'ciede2000': partial(ciede2000_loss, ε=1e-12), 61 | 'lpips': LPIPS, 62 | 'l2': l2_distances, 63 | 'l1': l1_distances, 64 | } 65 | 66 | 67 | def alma(model: nn.Module, 68 | inputs: Tensor, 69 | labels: Tensor, 70 | penalty: Callable = all_penalties['P2'], 71 | targeted: bool = False, 72 | num_steps: int = 1000, 73 | lr_init: float = 0.1, 74 | lr_reduction: float = 0.01, 75 | distance: str = 'l2', 76 | init_lr_distance: Optional[float] = None, 77 | μ_init: float = 1, 78 | ρ_init: float = 1, 79 | check_steps: int = 10, 80 | τ: float = 0.95, 81 | γ: float = 1.2, 82 | α: float = 0.9, 83 | α_rms: Optional[float] = None, 84 | momentum: Optional[float] = None, 85 | logit_tolerance: float = 1e-4, 86 | levels: Optional[int] = None, 87 | callback: Optional[VisdomLogger] = None) -> Tensor: 88 | """ 89 | Augmented Lagrangian Method for Adversarial (ALMA) attack from https://arxiv.org/abs/2011.11857. 90 | 91 | Parameters 92 | ---------- 93 | model : nn.Module 94 | Model to attack. 95 | inputs : Tensor 96 | Inputs to attack. Should be in [0, 1]. 97 | labels : Tensor 98 | Labels corresponding to the inputs if untargeted, else target labels. 99 | penalty : Callable 100 | Penalty-Lagrangian function to use. A good default choice is P2 (see the original article). 101 | targeted : bool 102 | Whether to perform a targeted attack or not. 103 | num_steps : int 104 | Number of optimization steps. Corresponds to the number of forward and backward propagations. 105 | lr_init : float 106 | Initial learning rate. 107 | lr_reduction : float 108 | Reduction factor for the learning rate. The final learning rate is lr_init * lr_reduction 109 | distance : str 110 | Distance to use. 111 | init_lr_distance : float 112 | If a float is given, the initial learning rate will be calculated such that the first step results in an 113 | increase of init_lr_distance of the distance to minimize. This corresponds to ε in the original article. 114 | μ_init : float 115 | Initial value of the penalty multiplier. 116 | ρ_init : float 117 | Initial value of the penalty parameter. 118 | check_steps : int 119 | Number of steps between checks for the improvement of the constraint. This corresponds to M in the original 120 | article. 121 | τ : float 122 | Constraint improvement rate. 123 | γ : float 124 | Penalty parameter increase rate. 125 | α : float 126 | Weight for the exponential moving average. 127 | α_rms : float 128 | Smoothing constant for RMSProp. If none is provided, defaults to α. 129 | momentum : float 130 | Momentum for the RMSProp. 
If none is provided, defaults to α. 131 | logit_tolerance : float 132 | Small quantity added to the difference of logits to avoid solutions where the difference of logits is 0, which 133 | can results in inconsistent class prediction (using argmax) on GPU. This can also be used as a confidence 134 | parameter κ as in https://arxiv.org/abs/1608.04644, however, a confidence parameter on logits is not robust to 135 | scaling of the logits. 136 | levels : int 137 | Number of levels for quantization. The attack will perform quantization only if the number of levels is 138 | provided. 139 | callback : VisdomLogger 140 | Callback to visualize the progress of the algorithm. 141 | 142 | Returns 143 | ------- 144 | best_adv : Tensor 145 | Perturbed inputs (inputs + perturbation) that are adversarial and have smallest distance with the original 146 | inputs. 147 | 148 | """ 149 | device = inputs.device 150 | batch_size = len(inputs) 151 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 152 | multiplier = -1 if targeted else 1 153 | 154 | # Setup variables 155 | δ = torch.zeros_like(inputs, requires_grad=True) 156 | square_avg = torch.ones_like(inputs) 157 | momentum_buffer = torch.zeros_like(inputs) 158 | lr = torch.full((batch_size,), lr_init, device=device, dtype=torch.float) 159 | α_rms, momentum = α if α_rms is None else α_rms, α if momentum is None else momentum 160 | 161 | # Init rho and mu 162 | μ = torch.full((batch_size,), μ_init, device=device, dtype=torch.float) 163 | ρ = torch.full((batch_size,), ρ_init, device=device, dtype=torch.float) 164 | 165 | # Init similarity metric 166 | if distance in ['lpips']: 167 | dist_func = _distances[distance](target=inputs) 168 | else: 169 | dist_func = partial(_distances[distance], inputs) 170 | 171 | # Init trackers 172 | best_dist = torch.full((batch_size,), float('inf'), device=device) 173 | best_adv = inputs.clone() 174 | adv_found = torch.zeros_like(best_dist, dtype=torch.bool) 175 | step_found = torch.full_like(best_dist, num_steps + 1) 176 | 177 | for i in range(num_steps): 178 | 179 | adv_inputs = inputs + δ 180 | logits = model(adv_inputs) 181 | dist = dist_func(adv_inputs) 182 | 183 | if i == 0: 184 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 185 | dlr_func = partial(difference_of_logits_ratio, labels=labels, labels_infhot=labels_infhot, 186 | targeted=targeted, ε=logit_tolerance) 187 | 188 | dlr = multiplier * dlr_func(logits) 189 | 190 | if i == 0: 191 | prev_dlr = dlr.detach() 192 | elif (i + 1) % check_steps == 0: 193 | improved_dlr = (dlr.detach() < τ * prev_dlr) 194 | ρ = torch.where(~(adv_found | improved_dlr), γ * ρ, ρ) 195 | prev_dlr = dlr.detach() 196 | 197 | if i: 198 | new_μ = grad(penalty(dlr, ρ, μ).sum(), dlr, only_inputs=True)[0] 199 | μ.lerp_(new_μ, weight=1 - α).clamp_(min=1e-6, max=1e12) 200 | 201 | is_adv = dlr < 0 202 | is_smaller = dist < best_dist 203 | is_both = is_adv & is_smaller 204 | step_found.masked_fill_((~adv_found) & is_adv, i) 205 | adv_found.logical_or_(is_adv) 206 | best_dist = torch.where(is_both, dist.detach(), best_dist) 207 | best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv) 208 | 209 | if i == 0: 210 | loss = penalty(dlr, ρ, μ) 211 | else: 212 | loss = dist + penalty(dlr, ρ, μ) 213 | δ_grad = grad(loss.sum(), δ, only_inputs=True)[0] 214 | 215 | grad_norm = δ_grad.flatten(1).norm(p=2, dim=1) 216 | if init_lr_distance is not None and i == 0: 217 | randn_grad = torch.randn_like(δ_grad).renorm(dim=0, p=2, 
maxnorm=1) 218 | δ_grad = torch.where(batch_view(grad_norm <= 1e-6), randn_grad, δ_grad) 219 | lr = init_lr_finder(inputs, δ_grad, dist_func, target_distance=init_lr_distance) 220 | 221 | exp_decay = lr_reduction ** ((i - step_found).clamp_(min=0) / (num_steps - step_found)) 222 | step_lr = lr * exp_decay 223 | square_avg.mul_(α_rms).addcmul_(δ_grad, δ_grad, value=1 - α_rms) 224 | momentum_buffer.mul_(momentum).addcdiv_(δ_grad, square_avg.sqrt().add_(1e-8)) 225 | δ.data.addcmul_(momentum_buffer, batch_view(step_lr), value=-1) 226 | 227 | δ.data.add_(inputs).clamp_(min=0, max=1) 228 | if levels is not None: 229 | δ.data.mul_(levels - 1).round_().div_(levels - 1) 230 | δ.data.sub_(inputs) 231 | 232 | if callback: 233 | cb_best_dist = best_dist.masked_select(adv_found).mean() 234 | callback.accumulate_line([distance, 'dlr'], i, [dist.mean(), dlr.mean()]) 235 | callback.accumulate_line(['μ_c', 'ρ_c'], i, [μ.mean(), ρ.mean()]) 236 | callback.accumulate_line('grad_norm', i, grad_norm.mean()) 237 | callback.accumulate_line(['best_{}'.format(distance), 'success', 'lr'], i, 238 | [cb_best_dist, adv_found.float().mean(), step_lr.mean()]) 239 | 240 | if (i + 1) % (num_steps // 20) == 0 or (i + 1) == num_steps: 241 | callback.update_lines() 242 | 243 | return best_adv 244 | -------------------------------------------------------------------------------- /adv_lib/attacks/carlini_wagner/__init__.py: -------------------------------------------------------------------------------- 1 | from .l2 import carlini_wagner_l2 2 | from .linf import carlini_wagner_linf -------------------------------------------------------------------------------- /adv_lib/attacks/carlini_wagner/l2.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/carlini/nn_robust_attacks 2 | 3 | from typing import Tuple, Optional 4 | 5 | import torch 6 | from torch import nn, optim, Tensor 7 | from torch.autograd import grad 8 | 9 | from adv_lib.utils.losses import difference_of_logits 10 | from adv_lib.utils.visdom_logger import VisdomLogger 11 | 12 | 13 | def carlini_wagner_l2(model: nn.Module, 14 | inputs: Tensor, 15 | labels: Tensor, 16 | targeted: bool = False, 17 | confidence: float = 0, 18 | learning_rate: float = 0.01, 19 | initial_const: float = 0.001, 20 | binary_search_steps: int = 9, 21 | max_iterations: int = 10000, 22 | abort_early: bool = True, 23 | callback: Optional[VisdomLogger] = None) -> Tensor: 24 | """ 25 | Carlini and Wagner L2 attack from https://arxiv.org/abs/1608.04644. 26 | 27 | Parameters 28 | ---------- 29 | model : nn.Module 30 | Model to attack. 31 | inputs : Tensor 32 | Inputs to attack. Should be in [0, 1]. 33 | labels : Tensor 34 | Labels corresponding to the inputs if untargeted, else target labels. 35 | targeted : bool 36 | Whether to perform a targeted attack or not. 37 | confidence : float 38 | Confidence of adversarial examples: higher produces examples that are farther away, but more strongly classified 39 | as adversarial. 40 | learning_rate: float 41 | The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge. 42 | initial_const : float 43 | The initial tradeoff-constant to use to tune the relative importance of distance and confidence. If 44 | binary_search_steps is large, the initial constant is not important. 45 | binary_search_steps : int 46 | The number of times we perform binary search to find the optimal tradeoff-constant between distance and 47 | confidence. 
48 | max_iterations : int 49 | The maximum number of iterations. Larger values are more accurate; setting too small will require a large 50 | learning rate and will produce poor results. 51 | abort_early : bool 52 | If true, allows early aborts if gradient descent gets stuck. 53 | callback : Optional 54 | 55 | Returns 56 | ------- 57 | adv_inputs : Tensor 58 | Modified inputs to be adversarial to the model. 59 | 60 | """ 61 | device = inputs.device 62 | batch_size = len(inputs) 63 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 64 | t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_() 65 | multiplier = -1 if targeted else 1 66 | 67 | # set the lower and upper bounds accordingly 68 | c = torch.full((batch_size,), initial_const, device=device) 69 | lower_bound = torch.zeros_like(c) 70 | upper_bound = torch.full_like(c, 1e10) 71 | 72 | o_best_l2 = torch.full_like(c, float('inf')) 73 | o_best_adv = inputs.clone() 74 | o_adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool) 75 | 76 | i_total = 0 77 | for outer_step in range(binary_search_steps): 78 | 79 | # setup the modifier variable and the optimizer 80 | modifier = torch.zeros_like(inputs, requires_grad=True) 81 | optimizer = optim.Adam([modifier], lr=learning_rate) 82 | best_l2 = torch.full_like(c, float('inf')) 83 | adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool) 84 | 85 | # The last iteration (if we run many steps) repeat the search once. 86 | if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1): 87 | c = upper_bound 88 | 89 | prev = float('inf') 90 | for i in range(max_iterations): 91 | 92 | adv_inputs = (torch.tanh(t_inputs + modifier) + 1) / 2 93 | l2_squared = (adv_inputs - inputs).flatten(1).square().sum(1) 94 | l2 = l2_squared.detach().sqrt() 95 | logits = model(adv_inputs) 96 | 97 | if outer_step == 0 and i == 0: 98 | # setup the target variable, we need it to be in one-hot form for the loss function 99 | labels_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1) 100 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 101 | 102 | # adjust the best result found so far 103 | predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \ 104 | (logits + labels_onehot * confidence).argmax(1) 105 | 106 | is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels) 107 | is_smaller = l2 < best_l2 108 | o_is_smaller = l2 < o_best_l2 109 | is_both = is_adv & is_smaller 110 | o_is_both = is_adv & o_is_smaller 111 | 112 | best_l2 = torch.where(is_both, l2, best_l2) 113 | adv_found.logical_or_(is_both) 114 | o_best_l2 = torch.where(o_is_both, l2, o_best_l2) 115 | o_adv_found.logical_or_(is_both) 116 | o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv) 117 | 118 | logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot) 119 | loss = l2_squared + c * (logit_dists + confidence).clamp_(min=0) 120 | 121 | # check if we should abort search if we're getting nowhere. 
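# Specifically, every max_iterations // 10 steps, the loss is compared to its value at the previous
# check: if no sample has improved by at least 0.01%, the run is considered stuck and stops early.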
122 | if abort_early and i % (max_iterations // 10) == 0: 123 | if (loss > prev * 0.9999).all(): 124 | break 125 | prev = loss.detach() 126 | 127 | optimizer.zero_grad(set_to_none=True) 128 | modifier.grad = grad(loss.sum(), modifier, only_inputs=True)[0] 129 | optimizer.step() 130 | 131 | if callback: 132 | i_total += 1 133 | callback.accumulate_line('logit_dist', i_total, logit_dists.mean()) 134 | callback.accumulate_line('l2_norm', i_total, l2.mean()) 135 | if i_total % (max_iterations // 20) == 0: 136 | callback.update_lines() 137 | 138 | if callback: 139 | best_l2 = o_best_l2.masked_select(o_adv_found).mean() 140 | callback.line(['success', 'best_l2', 'c'], outer_step, [o_adv_found.float().mean(), best_l2, c.mean()]) 141 | 142 | # adjust the constant as needed 143 | upper_bound[adv_found] = torch.min(upper_bound[adv_found], c[adv_found]) 144 | adv_not_found = ~adv_found 145 | lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], c[adv_not_found]) 146 | is_smaller = upper_bound < 1e9 147 | c[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2 148 | c[(~is_smaller) & adv_not_found] *= 10 149 | 150 | # return the best solution found 151 | return o_best_adv 152 | -------------------------------------------------------------------------------- /adv_lib/attacks/carlini_wagner/linf.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/carlini/nn_robust_attacks 2 | 3 | from typing import Tuple, Optional 4 | 5 | import torch 6 | from torch import nn, optim, Tensor 7 | 8 | from adv_lib.utils.losses import difference_of_logits 9 | from adv_lib.utils.visdom_logger import VisdomLogger 10 | 11 | 12 | def carlini_wagner_linf(model: nn.Module, 13 | inputs: Tensor, 14 | labels: Tensor, 15 | targeted: bool = False, 16 | learning_rate: float = 0.01, 17 | max_iterations: int = 1000, 18 | initial_const: float = 1e-5, 19 | largest_const: float = 2e+1, 20 | const_factor: float = 2, 21 | reduce_const: bool = False, 22 | decrease_factor: float = 0.9, 23 | abort_early: bool = True, 24 | callback: Optional[VisdomLogger] = None) -> Tensor: 25 | """ 26 | Carlini and Wagner Linf attack from https://arxiv.org/abs/1608.04644. 27 | 28 | Parameters 29 | ---------- 30 | model : nn.Module 31 | Model to attack. 32 | inputs : Tensor 33 | Inputs to attack. Should be in [0, 1]. 34 | labels : Tensor 35 | Labels corresponding to the inputs if untargeted, else target labels. 36 | targeted : bool 37 | Whether to perform a targeted attack or not. 38 | learning_rate: float 39 | The learning rate for the attack algorithm. Smaller values produce better results but are slower to converge. 40 | max_iterations : int 41 | The maximum number of iterations. Larger values are more accurate; setting too small will require a large 42 | learning rate and will produce poor results. 43 | initial_const : float 44 | The initial tradeoff-constant to use to tune the relative importance of distance and classification objective. 45 | largest_const : float 46 | The maximum tradeoff-constant to use to tune the relative importance of distance and classification objective. 47 | const_factor : float 48 | The multiplicative factor by which the constant is increased if the search failed. 49 | reduce_const : float 50 | If true, after each successful attack, make the constant smaller. 51 | decrease_factor : float 52 | Rate at which τ is decreased. Larger produces better quality results. 
53 | abort_early : bool 54 | If true, allows early aborts if gradient descent gets stuck. 55 | 56 | 57 | callback : Optional 58 | 59 | Returns 60 | ------- 61 | adv_inputs : Tensor 62 | Modified inputs to be adversarial to the model. 63 | 64 | """ 65 | device = inputs.device 66 | batch_size = len(inputs) 67 | t_inputs = (inputs * 2).sub_(1).mul_(1 - 1e-6).atanh_() 68 | multiplier = -1 if targeted else 1 69 | 70 | # set modifier and the parameters used in the optimization 71 | modifier = torch.zeros_like(inputs) 72 | c = torch.full((batch_size,), initial_const, device=device, dtype=torch.float) 73 | τ = torch.ones(batch_size, device=device) 74 | 75 | o_adv_found = torch.zeros_like(c, dtype=torch.bool) 76 | o_best_linf = torch.ones_like(c) 77 | o_best_adv = inputs.clone() 78 | 79 | outer_loops = 0 80 | total_iters = 0 81 | while (to_optimize := (τ > 1 / 255) & (c < largest_const)).any(): 82 | 83 | inputs_, t_inputs_, labels_ = inputs[to_optimize], t_inputs[to_optimize], labels[to_optimize] 84 | batch_view = lambda tensor: tensor.view(len(inputs_), *[1] * (inputs_.ndim - 1)) 85 | 86 | 87 | 88 | 89 | 90 | # setup the optimizer 91 | modifier_ = modifier[to_optimize].requires_grad_(True) 92 | optimizer = optim.Adam([modifier_], lr=learning_rate) 93 | c_, τ_ = c[to_optimize], τ[to_optimize] 94 | 95 | adv_found = torch.zeros(len(modifier_), device=device, dtype=torch.bool) 96 | best_linf = o_best_linf[to_optimize] 97 | best_adv = inputs_.clone() 98 | 99 | if callback: 100 | callback.line(['const', 'τ'], outer_loops, [c_.mean(), τ_.mean()]) 101 | callback.line(['success', 'best_linf'], outer_loops, [o_adv_found.float().mean(), o_best_linf.mean()]) 102 | 103 | for i in range(max_iterations): 104 | 105 | adv_inputs = (torch.tanh(t_inputs_ + modifier_) + 1) / 2 106 | linf = (adv_inputs.detach() - inputs_).flatten(1).norm(p=float('inf'), dim=1) 107 | logits = model(adv_inputs) 108 | 109 | if i == 0: 110 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels[to_optimize].unsqueeze(1), float('inf')) 111 | 112 | # adjust the best result found so far 113 | predicted_classes = logits.argmax(1) 114 | 115 | is_adv = (predicted_classes == labels_) if targeted else (predicted_classes != labels_) 116 | is_smaller = linf < best_linf 117 | is_both = is_adv & is_smaller 118 | adv_found.logical_or_(is_both) 119 | best_linf = torch.where(is_both, linf, best_linf) 120 | best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv) 121 | 122 | logit_dists = multiplier * difference_of_logits(logits, labels_, labels_infhot=labels_infhot) 123 | linf_loss = (adv_inputs - inputs_).abs_().sub_(batch_view(τ_)).clamp_(min=0).flatten(1).sum(1) 124 | loss = linf_loss + c_ * logit_dists.clamp_(min=0) 125 | 126 | # check if we should abort search 127 | if abort_early and (loss < 0.0001 * c_).all(): 128 | break 129 | 130 | optimizer.zero_grad() 131 | loss.sum().backward() 132 | optimizer.step() 133 | 134 | if callback: 135 | callback.accumulate_line('logit_dist', total_iters, logit_dists.mean()) 136 | callback.accumulate_line('linf_norm', total_iters, linf.mean()) 137 | 138 | if (i + 1) % (max_iterations // 10) == 0 or (i + 1) == max_iterations: 139 | callback.update_lines() 140 | 141 | total_iters += 1 142 | 143 | 
o_adv_found[to_optimize] = adv_found | o_adv_found[to_optimize] 144 | o_best_linf[to_optimize] = torch.where(adv_found, best_linf, o_best_linf[to_optimize]) 145 | o_best_adv[to_optimize] = torch.where(batch_view(adv_found), best_adv, o_best_adv[to_optimize]) 146 | modifier[to_optimize] = modifier_.detach() 147 | 148 | smaller_τ_ = adv_found & (best_linf < τ_) 149 | τ_ = torch.where(smaller_τ_, best_linf, τ_) 150 | τ[to_optimize] = torch.where(adv_found, decrease_factor * τ_, τ_) 151 | c[to_optimize] = torch.where(~adv_found, const_factor * c_, c_) 152 | if reduce_const: 153 | c[to_optimize] = torch.where(adv_found, c[to_optimize] / 2, c[to_optimize]) 154 | 155 | outer_loops += 1 156 | 157 | # return the best solution found 158 | return o_best_adv 159 | -------------------------------------------------------------------------------- /adv_lib/attacks/decoupled_direction_norm.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.autograd import grad 7 | from torch.nn import functional as F 8 | 9 | from adv_lib.utils.visdom_logger import VisdomLogger 10 | 11 | 12 | def ddn(model: nn.Module, 13 | inputs: Tensor, 14 | labels: Tensor, 15 | targeted: bool = False, 16 | steps: int = 100, 17 | γ: float = 0.05, 18 | init_norm: float = 1., 19 | levels: Optional[int] = 256, 20 | callback: Optional[VisdomLogger] = None) -> Tensor: 21 | """ 22 | Decoupled Direction and Norm attack from https://arxiv.org/abs/1811.09600. 23 | 24 | Parameters 25 | ---------- 26 | model : nn.Module 27 | Model to attack. 28 | inputs : Tensor 29 | Inputs to attack. Should be in [0, 1]. 30 | labels : Tensor 31 | Labels corresponding to the inputs if untargeted, else target labels. 32 | targeted : bool 33 | Whether to perform a targeted attack or not. 34 | steps : int 35 | Number of optimization steps. 36 | γ : float 37 | Factor by which the norm will be modified. new_norm = norm * (1 + or - γ). 38 | init_norm : float 39 | Initial value for the norm of the attack. 40 | levels : int 41 | If not None, the returned adversarials will have quantized values to the specified number of levels. 42 | callback : Optional 43 | 44 | Returns 45 | ------- 46 | adv_inputs : Tensor 47 | Modified inputs to be adversarial to the model. 
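The attack anneals the step size α from 1 to 0.01 with a cosine schedule, multiplies the norm budget ε by 1 - γ after an adversarial step and by 1 + γ otherwise, then projects δ onto the sphere of radius ε around the inputs and clamps the result to [0, 1].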
48 | 49 | """ 50 | if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.') 51 | device = inputs.device 52 | batch_size = len(inputs) 53 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 54 | 55 | # Init variables 56 | multiplier = -1 if targeted else 1 57 | δ = torch.zeros_like(inputs, requires_grad=True) 58 | ε = torch.full((batch_size,), init_norm, device=device, dtype=torch.float) 59 | worst_norm = torch.max(inputs, 1 - inputs).flatten(1).norm(p=2, dim=1) 60 | 61 | # Init trackers 62 | best_l2 = worst_norm.clone() 63 | best_adv = inputs.clone() 64 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 65 | 66 | for i in range(steps): 67 | l2 = δ.data.flatten(1).norm(p=2, dim=1) 68 | adv_inputs = inputs + δ 69 | logits = model(adv_inputs) 70 | pred_labels = logits.argmax(1) 71 | ce_loss = F.cross_entropy(logits, labels, reduction='none') 72 | loss = multiplier * ce_loss 73 | 74 | is_adv = (pred_labels == labels) if targeted else (pred_labels != labels) 75 | is_smaller = l2 < best_l2 76 | is_both = is_adv & is_smaller 77 | adv_found.logical_or_(is_adv) 78 | best_l2 = torch.where(is_both, l2, best_l2) 79 | best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv) 80 | 81 | δ_grad = grad(loss.sum(), δ, only_inputs=True)[0] 82 | # renorming gradient 83 | grad_norms = δ_grad.flatten(1).norm(p=2, dim=1) 84 | δ_grad.div_(batch_view(grad_norms)) 85 | # avoid nan or inf if gradient is 0 86 | if (zero_grad := (grad_norms < 1e-12)).any(): 87 | δ_grad[zero_grad] = torch.randn_like(δ_grad[zero_grad]) 88 | 89 | α = 0.01 + (1 - 0.01) * (1 + math.cos(math.pi * i / steps)) / 2 90 | 91 | if callback is not None: 92 | cosine = F.cosine_similarity(δ_grad.flatten(1), δ.data.flatten(1), dim=1).mean() 93 | callback.accumulate_line('ce', i, ce_loss.mean()) 94 | callback_best = best_l2.masked_select(adv_found).mean() 95 | callback.accumulate_line(['ε', 'l2', 'best_l2'], i, [ε.mean(), l2.mean(), callback_best]) 96 | callback.accumulate_line(['cosine', 'α', 'success'], i, 97 | [cosine, torch.tensor(α, device=device), adv_found.float().mean()]) 98 | 99 | if (i + 1) % (steps // 20) == 0 or (i + 1) == steps: 100 | callback.update_lines() 101 | 102 | δ.data.add_(δ_grad, alpha=α) 103 | 104 | ε = torch.where(is_adv, (1 - γ) * ε, (1 + γ) * ε) 105 | ε = torch.minimum(ε, worst_norm) 106 | 107 | δ.data.mul_(batch_view(ε / δ.data.flatten(1).norm(p=2, dim=1))) 108 | δ.data.add_(inputs).clamp_(min=0, max=1) 109 | if levels is not None: 110 | δ.data.mul_(levels - 1).round_().div_(levels - 1) 111 | δ.data.sub_(inputs) 112 | 113 | return best_adv 114 | -------------------------------------------------------------------------------- /adv_lib/attacks/deepfool.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.autograd import grad 6 | 7 | from adv_lib.utils.attack_utils import get_all_targets 8 | 9 | 10 | def df(model: nn.Module, 11 | inputs: Tensor, 12 | labels: Tensor, 13 | targeted: bool = False, 14 | steps: int = 100, 15 | overshoot: float = 0.02, 16 | norm: float = 2, 17 | return_unsuccessful: bool = False, 18 | return_targets: bool = False) -> Tensor: 19 | """ 20 | DeepFool attack from https://arxiv.org/abs/1511.04599. Properly implement parallel sample-wise early-stopping. 21 | 22 | Parameters 23 | ---------- 24 | model : nn.Module 25 | Model to attack. 
26 | inputs : Tensor 27 | Inputs to attack. Should be in [0, 1]. 28 | labels : Tensor 29 | Labels corresponding to the inputs if untargeted, else target labels. 30 | targeted : bool 31 | Whether to perform a targeted attack or not. 32 | steps : int 33 | Maximum number of attack steps. 34 | overshoot : float 35 | Ratio by which to overshoot the boundary estimated from the linear model. 36 | norm : float 37 | Norm to minimize in {2, float('inf')}. 38 | return_unsuccessful : bool 39 | Whether to return unsuccessful adversarial inputs; used by SuperDeepFool. 40 | return_targets : bool 41 | Whether to return the last target labels; used by SuperDeepFool. 42 | 43 | Returns 44 | ------- 45 | adv_inputs : Tensor 46 | Modified inputs to be adversarial to the model. 47 | 48 | """ 49 | if targeted: 50 | warnings.warn('DeepFool attack is untargeted only. Returning inputs.') 51 | return inputs 52 | 53 | if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.') 54 | device = inputs.device 55 | batch_size = len(inputs) 56 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 57 | 58 | # Setup variables 59 | adv_inputs = inputs.clone() 60 | adv_inputs.requires_grad_(True) 61 | 62 | adv_out = inputs.clone() 63 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 64 | if return_targets: 65 | targets = labels.clone() 66 | 67 | arange = torch.arange(batch_size, device=device) 68 | for i in range(steps): 69 | 70 | logits = model(adv_inputs) 71 | 72 | if i == 0: 73 | other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1]) 74 | 75 | pred_labels = logits.argmax(dim=1) 76 | is_adv = (pred_labels == labels) if targeted else (pred_labels != labels) 77 | 78 | if is_adv.any(): 79 | adv_not_found = ~adv_found 80 | adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs.detach(), adv_out[adv_not_found]) 81 | adv_found.masked_scatter_(adv_not_found, is_adv) 82 | if is_adv.all(): 83 | break 84 | 85 | not_adv = ~is_adv 86 | logits, labels, other_labels = logits[not_adv], labels[not_adv], other_labels[not_adv] 87 | arange = torch.arange(not_adv.sum(), device=device) 88 | 89 | f_prime = logits.gather(dim=1, index=other_labels) - logits.gather(dim=1, index=labels.unsqueeze(1)) 90 | w_prime = [] 91 | for j, f_prime_k in enumerate(f_prime.unbind(dim=1)): 92 | w_prime_k = grad(f_prime_k.sum(), inputs=adv_inputs, retain_graph=(j + 1) < f_prime.shape[1], 93 | only_inputs=True)[0] 94 | w_prime.append(w_prime_k) 95 | w_prime = torch.stack(w_prime, dim=1) # batch_size × num_classes × 
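# Under this linearization, the distance to the boundary of class k is |f'_k| / ||w'_k|| (dual norm);
# the closest class l_hat is selected and the iterate moves (1 + overshoot) times that distance
# past the corresponding hyperplane.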
96 | 97 | if is_adv.any(): 98 | not_adv = ~is_adv 99 | adv_inputs, w_prime = adv_inputs[not_adv], w_prime[not_adv] 100 | 101 | if norm == 2: 102 | w_prime_norms = w_prime.flatten(2).norm(p=2, dim=2).clamp_(min=1e-6) 103 | elif norm == float('inf'): 104 | w_prime_norms = w_prime.flatten(2).norm(p=1, dim=2).clamp_(min=1e-6) 105 | 106 | distance = f_prime.detach().abs_().div_(w_prime_norms).add_(1e-4) 107 | l_hat = distance.argmin(dim=1) 108 | 109 | if return_targets: 110 | targets[~adv_found] = torch.where(l_hat >= labels, l_hat + 1, l_hat) 111 | 112 | if norm == 2: 113 | # 1e-4 added in original implementation 114 | scale = distance[arange, l_hat] / w_prime_norms[arange, l_hat] 115 | adv_inputs.data.addcmul_(batch_view(scale), w_prime[arange, l_hat], value=1 + overshoot) 116 | elif norm == float('inf'): 117 | adv_inputs.data.addcmul_(batch_view(distance[arange, l_hat]), w_prime[arange, l_hat].sign(), 118 | value=1 + overshoot) 119 | adv_inputs.data.clamp_(min=0, max=1) 120 | 121 | if return_unsuccessful and not adv_found.all(): 122 | adv_out[~adv_found] = adv_inputs.detach() 123 | 124 | if return_targets: 125 | return adv_out, targets 126 | 127 | return adv_out 128 | -------------------------------------------------------------------------------- /adv_lib/attacks/fast_adaptive_boundary/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_adaptive_boundary import fab 2 | -------------------------------------------------------------------------------- /adv_lib/attacks/fast_adaptive_boundary/fast_adaptive_boundary.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/fra31/auto-attack 2 | 3 | import warnings 4 | from functools import partial 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | from torch.autograd import grad 10 | 11 | from adv_lib.utils.attack_utils import get_all_targets 12 | from .projections import projection_l1, projection_l2, projection_linf 13 | 14 | 15 | def fab(model: nn.Module, 16 | inputs: Tensor, 17 | labels: Tensor, 18 | norm: float, 19 | n_iter: int = 100, 20 | ε: Optional[float] = None, 21 | α_max: float = 0.1, 22 | β: float = 0.9, 23 | η: float = 1.05, 24 | restarts: Optional[int] = None, 25 | targeted_restarts: bool = False, 26 | targeted: bool = False) -> Tensor: 27 | """ 28 | Fast Adaptive Boundary (FAB) attack from https://arxiv.org/abs/1907.02044 29 | 30 | Parameters 31 | ---------- 32 | model : nn.Module 33 | Model to attack. 34 | inputs : Tensor 35 | Inputs to attack. Should be in [0, 1]. 36 | labels : Tensor 37 | Labels corresponding to the inputs if untargeted, else target labels. 38 | norm : float 39 | Norm to minimize in {1, 2 ,float('inf')}. 40 | n_iter : int 41 | Number of optimization steps. This does not correspond to the number of forward / backward propagations for this 42 | attack. For a more comprehensive discussion on complexity, see section 4 of https://arxiv.org/abs/1907.02044 and 43 | for a comparison of complexities, see https://arxiv.org/abs/2011.11857. 44 | TL;DR: FAB performs 2 forwards and K - 1 (i.e. number of classes - 1) backwards per step in default mode. If 45 | `targeted_restarts` is `True`, performs `2 * restarts` forwards and `restarts` or (K - 1) backwards per step. 46 | ε : float 47 | Maximum norm of the random initialization for restarts. 48 | α_max : float 49 | Maximum weight for the biased gradient step. 
α = 0 corresponds to taking the projection of the `adv_inputs` on 50 | the decision hyperplane, while α = 1 corresponds to taking the projection of the `inputs` on the decision 51 | hyperplane. 52 | β : float 53 | Weight for the biased backward step, i.e. a linear interpolation between `inputs` and `adv_inputs` at step i. 54 | β = 0 corresponds to taking the original `inputs` and β = 1 corresponds to taking the `adv_inputs`. 55 | η : float 56 | Extrapolation for the optimization step. η = 1 corresponds to projecting the `adv_inputs` on the decision 57 | hyperplane. η > 1 corresponds to overshooting to increase the probability of crossing the decision hyperplane. 58 | restarts : int 59 | Number of random restarts in default mode; starts from the inputs in the first run and then add random noise for 60 | the consecutive restarts. Number of classes to attack if `targeted_restarts` is `True`. 61 | targeted_restarts : bool 62 | If `True`, performs targeted attack towards the most likely classes for the unperturbed `inputs`. If `restarts` 63 | is not given, this will attack each class (except the original class). If `restarts` is given, the `restarts` 64 | most likely classes will be attacked. If `restarts` is larger than K - 1, this will re-attack the most likely 65 | classes with random noise. 66 | targeted : bool 67 | Placeholder argument for library. FAB is only for untargeted attacks, so setting this to True will raise a 68 | warning and return the inputs. 69 | 70 | Returns 71 | ------- 72 | adv_inputs : Tensor 73 | Modified inputs to be adversarial to the model. 74 | 75 | """ 76 | if targeted: 77 | warnings.warn('FAB attack is untargeted only. Returning inputs.') 78 | return inputs 79 | 80 | best_adv = inputs.clone() 81 | best_norm = torch.full_like(labels, float('inf'), dtype=torch.float) 82 | 83 | fab_attack = partial(_fab, model=model, norm=norm, n_iter=n_iter, ε=ε, α_max=α_max, β=β, η=η) 84 | 85 | if targeted_restarts: 86 | logits = model(inputs) 87 | n_target_classes = logits.size(1) - 1 88 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 89 | k = min(restarts or n_target_classes, n_target_classes) 90 | topk_labels = (logits - labels_infhot).topk(k=k, dim=1).indices 91 | 92 | n_restarts = restarts or (n_target_classes if targeted_restarts else 1) 93 | for i in range(n_restarts): 94 | 95 | if targeted_restarts: 96 | target_labels = topk_labels[:, i % n_target_classes] 97 | adv_inputs_run, adv_found_run, norm_run = fab_attack( 98 | inputs=inputs, labels=labels, random_start=i >= n_target_classes, targets=target_labels, u=best_norm) 99 | else: 100 | adv_inputs_run, adv_found_run, norm_run = fab_attack(inputs=inputs, labels=labels, random_start=i != 0, 101 | u=best_norm) 102 | 103 | is_better_adv = adv_found_run & (norm_run < best_norm) 104 | best_norm[is_better_adv] = norm_run[is_better_adv] 105 | best_adv[is_better_adv] = adv_inputs_run[is_better_adv] 106 | 107 | return best_adv 108 | 109 | 110 | def get_best_diff_logits_grads(model: nn.Module, 111 | inputs: Tensor, 112 | labels: Tensor, 113 | other_labels: Tensor, 114 | q: float) -> Tuple[Tensor, Tensor]: 115 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 116 | min_ratio = torch.full_like(labels, float('inf'), dtype=torch.float) 117 | best_logit_diff, best_grad_diff = torch.zeros_like(labels, dtype=torch.float), torch.zeros_like(inputs) 118 | 119 | inputs.requires_grad_(True) 120 | logits = model(inputs) 121 | class_logits = logits.gather(1, 
labels.unsqueeze(1)).squeeze(1) 122 | 123 | n_other_labels = other_labels.size(1) 124 | for i, o_labels in enumerate(other_labels.transpose(0, 1)): 125 | other_logits = logits.gather(1, o_labels.unsqueeze(1)).squeeze(1) 126 | logits_diff = other_logits - class_logits 127 | grad_diff = grad(logits_diff.sum(), inputs, only_inputs=True, retain_graph=i + 1 != n_other_labels)[0] 128 | ratio = logits_diff.abs().div_(grad_diff.flatten(1).norm(p=q, dim=1).clamp_(min=1e-12)) 129 | 130 | smaller_ratio = ratio < min_ratio 131 | min_ratio = torch.min(ratio, min_ratio) 132 | best_logit_diff = torch.where(smaller_ratio, logits_diff.detach(), best_logit_diff) 133 | best_grad_diff = torch.where(batch_view(smaller_ratio), grad_diff.detach(), best_grad_diff) 134 | 135 | inputs.detach_() 136 | return best_logit_diff, best_grad_diff 137 | 138 | 139 | def _fab(model: nn.Module, 140 | inputs: Tensor, 141 | labels: Tensor, 142 | norm: float, 143 | n_iter: int = 100, 144 | ε: Optional[float] = None, 145 | α_max: float = 0.1, 146 | β: float = 0.9, 147 | η: float = 1.05, 148 | random_start: bool = False, 149 | u: Optional[Tensor] = None, 150 | targets: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: 151 | _projection_dual_default_ε = { 152 | 1: (projection_l1, float('inf'), 5), 153 | 2: (projection_l2, 2, 1), 154 | float('inf'): (projection_linf, 1, 0.3) 155 | } 156 | 157 | device = inputs.device 158 | batch_size = len(inputs) 159 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 160 | projection, dual_norm, default_ε = _projection_dual_default_ε[norm] 161 | ε = default_ε if ε is None else ε 162 | 163 | logits = model(inputs) 164 | if targets is not None: 165 | other_labels = targets.unsqueeze(1) 166 | else: 167 | other_labels = get_all_targets(labels=labels, num_classes=logits.shape[1]) 168 | 169 | get_df_dg = partial(get_best_diff_logits_grads, model=model, labels=labels, other_labels=other_labels, q=dual_norm) 170 | 171 | adv_inputs = inputs.clone() 172 | adv_found = logits.argmax(dim=1) != labels 173 | best_norm = torch.full((batch_size,), float('inf'), device=device, dtype=torch.float) if u is None else u 174 | best_norm[adv_found] = 0 175 | best_adv = inputs.clone() 176 | 177 | if random_start: 178 | if norm == float('inf'): 179 | t = torch.rand_like(inputs).mul_(2).sub_(1) 180 | elif norm in [1, 2]: 181 | t = torch.randn_like(inputs) 182 | 183 | adv_inputs.add_(t.mul_(batch_view(best_norm.clamp(max=ε) / t.flatten(1).norm(p=norm, dim=1).mul_(2)))) 184 | adv_inputs.clamp_(min=0.0, max=1.0) 185 | 186 | for i in range(n_iter): 187 | df, dg = get_df_dg(inputs=adv_inputs) 188 | b = (dg * adv_inputs).flatten(1).sum(dim=1).sub_(df) 189 | w = dg.flatten(1) 190 | 191 | d3 = projection(torch.cat((adv_inputs.flatten(1), inputs.flatten(1)), 0), w.repeat(2, 1), b.repeat(2)) 192 | d1, d2 = map(lambda t: t.view_as(adv_inputs), torch.chunk(d3, 2, dim=0)) 193 | 194 | a0 = batch_view(d3.flatten(1).norm(p=norm, dim=1).clamp_(min=1e-8)) 195 | a1, a2 = torch.chunk(a0, 2, dim=0) 196 | 197 | α = a1.div_(a2.add_(a1)).clamp_(min=0, max=α_max) 198 | adv_inputs.add_(d1, alpha=η).mul_(1 - α).add_(inputs.add(d2, alpha=η).mul_(α)).clamp_(min=0, max=1) 199 | 200 | is_adv = model(adv_inputs).argmax(1) != labels 201 | adv_found.logical_or_(is_adv) 202 | adv_norm = (adv_inputs - inputs).flatten(1).norm(p=norm, dim=1) 203 | is_smaller = adv_norm < best_norm 204 | is_both = is_adv & is_smaller 205 | best_norm = torch.where(is_both, adv_norm, best_norm) 206 | best_adv = torch.where(batch_view(is_both), 
adv_inputs, best_adv) 207 | 208 | adv_inputs = torch.where(batch_view(is_adv), inputs + (adv_inputs - inputs) * β, adv_inputs) 209 | 210 | return best_adv, adv_found, best_norm 211 | -------------------------------------------------------------------------------- /adv_lib/attacks/fast_adaptive_boundary/projections.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import functional as F 6 | 7 | 8 | def projection_l1(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor: 9 | device = points_to_project.device 10 | t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane 11 | 12 | c = (w * t).sum(1).sub_(b) 13 | ind2 = (c >= 0).float().mul_(2).sub_(1) 14 | w.mul_(ind2.unsqueeze(1)) 15 | c.mul_(ind2) 16 | 17 | w_abs = w.abs() 18 | r = (1 / w_abs).clamp_(max=1e12) 19 | indr = torch.argsort(r, dim=1) 20 | indr_rev = torch.argsort(indr) 21 | 22 | d = (w < 0).float().sub_(t).mul_(w != 0) 23 | ds = torch.min(-w * t, (1 - t).mul_(w)).gather(1, indr) 24 | ds2 = torch.cat((c.unsqueeze(-1), ds), 1) 25 | s = torch.cumsum(ds2, dim=1) 26 | 27 | c2 = s[:, -1] < 0 28 | 29 | lb = torch.zeros(c2.sum(), device=device) 30 | ub = torch.full_like(lb, s.shape[1]) 31 | nitermax = math.ceil(math.log2(w.shape[1])) 32 | 33 | s_ = s[c2] 34 | for counter in range(nitermax): 35 | counter4 = (lb + ub).mul_(0.5).floor_() 36 | counter2 = counter4.long().unsqueeze(1) 37 | c3 = s_.gather(1, counter2).squeeze(1) > 0 38 | lb = torch.where(c3, counter4, lb) 39 | ub = torch.where(c3, ub, counter4) 40 | 41 | lb2 = lb.long() 42 | 43 | if c2.any(): 44 | indr = indr[c2].gather(1, lb2.unsqueeze(1)).squeeze(1) 45 | u = torch.arange(0, w.shape[0], device=device).unsqueeze(1) 46 | u2 = torch.arange(0, w.shape[1], device=device, dtype=torch.float).unsqueeze(0) 47 | alpha = s[c2, lb2].neg().div_(w[c2, indr]) 48 | c5 = u2 < lb.unsqueeze(-1) 49 | u3 = c5[u[:c5.shape[0]], indr_rev[c2]] 50 | d[c2] *= u3 51 | d[c2, indr] = alpha 52 | 53 | return d.mul_(w_abs > 1e-8) 54 | 55 | 56 | def projection_l2(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor: 57 | device = points_to_project.device 58 | t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane 59 | 60 | c = (w * t).sum(1).sub_(b) 61 | ind2 = (c >= 0).float().mul_(2).sub_(1) 62 | w.mul_(ind2.unsqueeze(1)) 63 | w_nonzero = w.abs() > 1e-8 64 | c.mul_(ind2) 65 | 66 | r = torch.maximum(t / w, (t - 1).div_(w)).clamp_(min=-1e12, max=1e12) 67 | r.masked_fill_(~w_nonzero, 1e12) 68 | r[r == -1e12] *= -1 69 | rs, indr = torch.sort(r, dim=1) 70 | rs2 = F.pad(rs[:, 1:], (0, 1)) 71 | rs.masked_fill_(rs == 1e12, 0) 72 | rs2.masked_fill_(rs2 == 1e12, 0) 73 | 74 | w3s = w.square().gather(1, indr) 75 | w5 = w3s.sum(dim=1, keepdim=True) 76 | ws = w5 - torch.cumsum(w3s, dim=1) 77 | d = (r * w).neg_() 78 | d.mul_(w_nonzero) 79 | s = torch.cat((-w5 * rs[:, 0:1], torch.cumsum((rs - rs2).mul_(ws), dim=1).sub_(w5 * rs[:, 0:1])), 1) 80 | 81 | c4 = s[:, 0] + c < 0 82 | c3 = (d * w).sum(dim=1).add_(c) > 0 83 | c2 = ~(c4 | c3) 84 | 85 | lb = torch.zeros(c2.sum(), device=device) 86 | ub = torch.full_like(lb, w.shape[1] - 1) 87 | nitermax = math.ceil(math.log2(w.shape[1])) 88 | 89 | s_, c_ = s[c2], c[c2] 90 | for counter in range(nitermax): 91 | counter4 = (lb + ub).mul_(0.5).floor_() 92 | counter2 = counter4.long().unsqueeze(1) 93 | c3 = s_.gather(1, counter2).squeeze(1).add_(c_) > 0 94 | lb = torch.where(c3, counter4, lb) 95 | ub = 
torch.where(c3, ub, counter4) 96 | 97 | lb = lb.long() 98 | 99 | if c4.any(): 100 | alpha = c[c4] / w5[c4].squeeze(-1) 101 | d[c4] = -alpha.unsqueeze(-1) * w[c4] 102 | 103 | if c2.any(): 104 | alpha = (s[c2, lb] + c[c2]).div_(ws[c2, lb]).add_(rs[c2, lb]) 105 | alpha[ws[c2, lb] == 0] = 0 106 | c5 = alpha.unsqueeze(-1) > r[c2] 107 | d[c2] = (d[c2] * c5).sub_((~c5).float().mul_(alpha.unsqueeze(-1)).mul_(w[c2])) 108 | 109 | return d.mul_(w_nonzero) 110 | 111 | 112 | def projection_linf(points_to_project: Tensor, w_hyperplane: Tensor, b_hyperplane: Tensor) -> Tensor: 113 | device = points_to_project.device 114 | t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane.clone() 115 | 116 | sign = ((w * t).sum(1).sub_(b) >= 0).float().mul_(2).sub_(1) 117 | w.mul_(sign.unsqueeze(1)) 118 | b.mul_(sign) 119 | 120 | a = (w < 0).float() 121 | d = (a - t).mul_(w != 0) 122 | 123 | p = (2 * a).sub_(1).mul_(t).neg_().add_(a) 124 | indp = torch.argsort(p, dim=1) 125 | 126 | b.sub_((w * t).sum(1)) 127 | b0 = (w * d).sum(1) 128 | 129 | indp2 = indp.flip((1,)) 130 | ws = w.gather(1, indp2) 131 | bs2 = -ws * d.gather(1, indp2) 132 | 133 | s = torch.cumsum(ws.abs_(), dim=1) 134 | sb = torch.cumsum(bs2, dim=1).add_(b0.unsqueeze(1)) 135 | 136 | b2 = sb[:, -1] - s[:, -1] * p.gather(1, indp[:, 0:1]).squeeze(1) 137 | c_l = b - b2 > 0 138 | c2 = (b - b0 > 0) & (~c_l) 139 | lb = torch.zeros(c2.sum(), device=device) 140 | ub = torch.full_like(lb, w.shape[1] - 1) 141 | nitermax = math.ceil(math.log2(w.shape[1])) 142 | 143 | indp_, sb_, s_, p_, b_ = indp[c2], sb[c2], s[c2], p[c2], b[c2] 144 | for counter in range(nitermax): 145 | counter4 = (lb + ub).mul_(0.5).floor_() 146 | 147 | counter2 = counter4.long().unsqueeze(1) 148 | indcurr = indp_.gather(1, indp_.size(1) - 1 - counter2) 149 | b2 = sb_.gather(1, counter2).sub_(s_.gather(1, counter2).mul_(p_.gather(1, indcurr))).squeeze(1) 150 | c = b_ - b2 > 0 151 | 152 | lb = torch.where(c, counter4, lb) 153 | ub = torch.where(c, ub, counter4) 154 | 155 | lb = lb.long() 156 | 157 | if c_l.any(): 158 | lmbd_opt = (b[c_l] - sb[c_l, -1]).div_(-s[c_l, -1]).clamp_(min=0).unsqueeze_(-1) 159 | d[c_l] = (2 * a[c_l]).sub_(1).mul_(lmbd_opt) 160 | 161 | lmbd_opt = (b[c2] - sb[c2, lb]).div_(-s[c2, lb]).clamp_(min=0).unsqueeze_(-1) 162 | d[c2] = torch.minimum(lmbd_opt, d[c2]).mul_(a[c2]).add_(torch.maximum(-lmbd_opt, d[c2]).mul_(1 - a[c2])) 163 | 164 | return d.mul_(w != 0) 165 | -------------------------------------------------------------------------------- /adv_lib/attacks/fast_minimum_norm.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/maurapintor/Fast-Minimum-Norm-FMN-Attack 2 | 3 | import math 4 | from functools import partial 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | from torch.autograd import grad 10 | 11 | from adv_lib.utils.losses import difference_of_logits 12 | from adv_lib.utils.projections import clamp_, l1_ball_euclidean_projection 13 | 14 | 15 | def l0_projection_(δ: Tensor, ε: Tensor) -> Tensor: 16 | """In-place l0 projection""" 17 | δ = δ.flatten(1) 18 | δ_abs = δ.abs() 19 | thresholds = δ_abs.topk(k=ε.long().max(), dim=1).values.gather(1, (ε.long().unsqueeze(1) - 1).clamp_(min=0)) 20 | δ[δ_abs < thresholds] = 0 21 | return δ 22 | 23 | 24 | def l1_projection_(δ: Tensor, ε: Tensor) -> Tensor: 25 | """In-place l1 projection""" 26 | δ = l1_ball_euclidean_projection(x=δ.flatten(1), ε=ε, inplace=True) 27 | return δ 28 | 29 | 30 | def 
l2_projection_(δ: Tensor, ε: Tensor) -> Tensor: 31 | """In-place l2 projection""" 32 | δ = δ.flatten(1) 33 | l2_norms = δ.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-12) 34 | δ.mul_((ε.unsqueeze(1) / l2_norms).clamp_(max=1))  # scaling factor capped at 1: the projection only shrinks δ 35 | return δ 36 | 37 | 38 | def linf_projection_(δ: Tensor, ε: Tensor) -> Tensor: 39 | """In-place linf projection""" 40 | δ, ε = δ.flatten(1), ε.unsqueeze(1) 41 | δ = clamp_(δ, lower=-ε, upper=ε) 42 | return δ 43 | 44 | 45 | def l0_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor: 46 | n_features = x0[0].numel() 47 | δ = l0_projection_(δ=x1 - x0, ε=n_features * ε) 48 | return δ.view_as(x0).add_(x0) 49 | 50 | 51 | def l1_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor: 52 | threshold = (1 - ε).unsqueeze(1) 53 | δ = (x1 - x0).flatten(1) 54 | δ_abs = δ.abs() 55 | mask = δ_abs <= threshold 56 | mid_points = δ_abs.sub_(threshold).copysign_(δ) 57 | mid_points[mask] = 0 58 | return mid_points.view_as(x0).add_(x0) 59 | 60 | 61 | def l2_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor: 62 | return torch.lerp(x0.flatten(1), x1.flatten(1), weight=ε.unsqueeze(1)).view_as(x0) 63 | 64 | 65 | def linf_mid_points(x0: Tensor, x1: Tensor, ε: Tensor) -> Tensor: 66 | ε = ε.unsqueeze(1) 67 | δ = clamp_((x1 - x0).flatten(1), lower=-ε, upper=ε) 68 | return δ.view_as(x0).add_(x0) 69 | 70 | 71 | def fmn(model: nn.Module, 72 | inputs: Tensor, 73 | labels: Tensor, 74 | norm: float, 75 | targeted: bool = False, 76 | steps: int = 10, 77 | α_init: float = 1.0, 78 | α_final: Optional[float] = None, 79 | γ_init: float = 0.05, 80 | γ_final: float = 0.001, 81 | starting_points: Optional[Tensor] = None, 82 | binary_search_steps: int = 10) -> Tensor: 83 | """ 84 | Fast Minimum-Norm attack from https://arxiv.org/abs/2102.12827. 85 | 86 | Parameters 87 | ---------- 88 | model : nn.Module 89 | Model to attack. 90 | inputs : Tensor 91 | Inputs to attack. Should be in [0, 1]. 92 | labels : Tensor 93 | Labels corresponding to the inputs if untargeted, else target labels. 94 | norm : float 95 | Norm to minimize in {0, 1, 2, float('inf')}. 96 | targeted : bool 97 | Whether to perform a targeted attack or not. 98 | steps : int 99 | Number of optimization steps. 100 | α_init : float 101 | Initial step size. 102 | α_final : float 103 | Final step size after cosine annealing. 104 | γ_init : float 105 | Initial factor by which ε is modified: ε = ε * (1 ± γ). 106 | γ_final : float 107 | Final factor, after cosine annealing, by which ε is modified. 108 | starting_points : Tensor 109 | Optional warm-start for the attack. 110 | binary_search_steps : int 111 | Number of binary search steps to find the decision boundary between inputs and starting_points. 112 | 113 | Returns 114 | ------- 115 | adv_inputs : Tensor 116 | Modified inputs to be adversarial to the model.
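Example ------- Illustrative usage sketch (not part of the original file; `model`, the input shapes and the 10-class setup are hypothetical): >>> images, labels = torch.rand(8, 3, 32, 32), torch.randint(10, (8,)) >>> adv_images = fmn(model, images, labels, norm=2, steps=100)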
117 | 118 | """ 119 | _dual_projection_mid_points = { 120 | 0: (None, l0_projection_, l0_mid_points), 121 | 1: (float('inf'), l1_projection_, l1_mid_points), 122 | 2: (2, l2_projection_, l2_mid_points), 123 | float('inf'): (1, linf_projection_, linf_mid_points), 124 | } 125 | if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.') 126 | device = inputs.device 127 | batch_size = len(inputs) 128 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 129 | dual, projection, mid_point = _dual_projection_mid_points[norm] 130 | α_final = α_init / 100 if α_final is None else α_final 131 | multiplier = 1 if targeted else -1 132 | 133 | # If starting_points is provided, search for the boundary 134 | if starting_points is not None: 135 | start_preds = model(starting_points).argmax(dim=1) 136 | is_adv = (start_preds == labels) if targeted else (start_preds != labels) 137 | if not is_adv.all(): 138 | raise ValueError('Starting points are not all adversarial.') 139 | lower_bound = torch.zeros(batch_size, device=device) 140 | upper_bound = torch.ones(batch_size, device=device) 141 | for _ in range(binary_search_steps): 142 | ε = (lower_bound + upper_bound) / 2 143 | mid_points = mid_point(x0=inputs, x1=starting_points, ε=ε) 144 | pred_labels = model(mid_points).argmax(dim=1) 145 | is_adv = (pred_labels == labels) if targeted else (pred_labels != labels) 146 | lower_bound = torch.where(is_adv, lower_bound, ε) 147 | upper_bound = torch.where(is_adv, ε, upper_bound) 148 | 149 | δ = mid_point(x0=inputs, x1=starting_points, ε=upper_bound).sub_(inputs) 150 | else: 151 | δ = torch.zeros_like(inputs) 152 | δ.requires_grad_(True) 153 | 154 | if norm == 0: 155 | ε = torch.ones(batch_size, device=device) if starting_points is None else δ.flatten(1).norm(p=0, dim=1) 156 | else: 157 | ε = torch.full((batch_size,), float('inf'), device=device) 158 | 159 | # Init trackers 160 | worst_norm = torch.maximum(inputs, 1 - inputs).flatten(1).norm(p=norm, dim=1) 161 | best_norm = worst_norm.clone() 162 | best_adv = inputs.clone() 163 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 164 | 165 | for i in range(steps): 166 | cosine = (1 + math.cos(math.pi * i / steps)) / 2 167 | α = α_final + (α_init - α_final) * cosine 168 | γ = γ_final + (γ_init - γ_final) * cosine 169 | 170 | δ_norm = δ.data.flatten(1).norm(p=norm, dim=1) 171 | adv_inputs = inputs + δ 172 | logits = model(adv_inputs) 173 | pred_labels = logits.argmax(dim=1) 174 | 175 | if i == 0: 176 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 177 | logit_diff_func = partial(difference_of_logits, labels=labels, labels_infhot=labels_infhot) 178 | 179 | logit_diffs = logit_diff_func(logits=logits) 180 | loss = multiplier * logit_diffs 181 | δ_grad = grad(loss.sum(), δ, only_inputs=True)[0] 182 | 183 | is_adv = (pred_labels == labels) if targeted else (pred_labels != labels) 184 | is_smaller = δ_norm < best_norm 185 | is_both = is_adv & is_smaller 186 | adv_found.logical_or_(is_adv) 187 | best_norm = torch.where(is_both, δ_norm, best_norm) 188 | best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv) 189 | 190 | if norm == 0: 191 | ε = torch.where(is_adv, 192 | torch.minimum(torch.minimum(ε - 1, (ε * (1 - γ)).floor_()), best_norm), 193 | torch.maximum(ε + 1, (ε * (1 + γ)).floor_())) 194 | ε.clamp_(min=0) 195 | else: 196 | distance_to_boundary = loss.detach().abs_().div_(δ_grad.flatten(1).norm(p=dual, 
dim=1).clamp_(min=1e-12)) 197 | ε = torch.where(is_adv, 198 | torch.minimum(ε * (1 - γ), best_norm), 199 | torch.where(adv_found, ε * (1 + γ), δ_norm + distance_to_boundary)) 200 | 201 | # clip ε 202 | ε = torch.minimum(ε, worst_norm) 203 | # gradient ascent step with normalized gradient 204 | grad_l2_norms = δ_grad.flatten(1).norm(p=2, dim=1).clamp_(min=1e-12) 205 | δ.data.addcdiv_(δ_grad, batch_view(grad_l2_norms), value=α) 206 | # project in place 207 | projection(δ=δ.data, ε=ε) 208 | # clamp 209 | δ.data.add_(inputs).clamp_(min=0, max=1).sub_(inputs) 210 | 211 | return best_adv 212 | -------------------------------------------------------------------------------- /adv_lib/attacks/perceptual_color_attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual_color_distance_al import perc_al -------------------------------------------------------------------------------- /adv_lib/attacks/perceptual_color_attacks/differential_color_functions.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def rgb2xyz(rgb_image): 8 | device = rgb_image.device 9 | mt = torch.tensor([[0.4124, 0.3576, 0.1805], 10 | [0.2126, 0.7152, 0.0722], 11 | [0.0193, 0.1192, 0.9504]], device=device) 12 | mask1 = (rgb_image > 0.0405).float() 13 | mask1_no = 1 - mask1 14 | temp_img = mask1 * (((rgb_image + 0.055) / 1.055) ** 2.4) 15 | temp_img = temp_img + mask1_no * (rgb_image / 12.92) 16 | temp_img = 100 * temp_img 17 | 18 | res = torch.matmul(mt, temp_img.permute(1, 0, 2, 3).contiguous().view(3, -1)).view( 19 | 3, rgb_image.size(0), rgb_image.size(2), rgb_image.size(3)).permute(1, 0, 2, 3) 20 | return res 21 | 22 | 23 | def xyz_lab(xyz_image): 24 | mask_value_0 = (xyz_image == 0).float() 25 | mask_value_0_no = 1 - mask_value_0 26 | xyz_image = xyz_image + 0.0001 * mask_value_0 27 | mask1 = (xyz_image > 0.008856).float() 28 | mask1_no = 1 - mask1 29 | res = mask1 * (xyz_image) ** (1 / 3) 30 | res = res + mask1_no * ((7.787 * xyz_image) + (16 / 116)) 31 | res = res * mask_value_0_no 32 | return res 33 | 34 | 35 | def rgb2lab_diff(rgb_image): 36 | """ 37 | Function to convert a batch of image tensors from RGB space to CIELAB space. 38 | parameters: xn, yn, zn are the CIE XYZ tristimulus values of the reference white point. 39 | Here use the standard Illuminant D65 with normalization Y = 100. 40 | """ 41 | rgb_image = rgb_image 42 | res = torch.zeros_like(rgb_image) 43 | xyz_image = rgb2xyz(rgb_image) 44 | 45 | xn = 95.0489 46 | yn = 100 47 | zn = 108.8840 48 | 49 | x = xyz_image[:, 0, :, :] 50 | y = xyz_image[:, 1, :, :] 51 | z = xyz_image[:, 2, :, :] 52 | 53 | L = 116 * xyz_lab(y / yn) - 16 54 | a = 500 * (xyz_lab(x / xn) - xyz_lab(y / yn)) 55 | b = 200 * (xyz_lab(y / yn) - xyz_lab(z / zn)) 56 | res[:, 0, :, :] = L 57 | res[:, 1, :, :] = a 58 | res[:, 2, :, :] = b 59 | 60 | return res 61 | 62 | 63 | def degrees(n): return n * (180. / np.pi) 64 | 65 | 66 | def radians(n): return n * (np.pi / 180.) 
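# Illustrative usage sketch (not part of the original file; shapes are hypothetical): the helpers above convert batches of RGB images in [0, 1] to CIELAB, e.g. rgb2lab_diff(torch.rand(4, 3, 32, 32)) returns a (4, 3, 32, 32) tensor whose channels are L*, a* and b*.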
67 | 68 | 69 | def hpf_diff(x, y): 70 | mask1 = ((x == 0) * (y == 0)).float() 71 | mask1_no = 1 - mask1 72 | 73 | tmphp = degrees(torch.atan2(x * mask1_no, y * mask1_no)) 74 | tmphp1 = tmphp * (tmphp >= 0).float() 75 | tmphp2 = (360 + tmphp) * (tmphp < 0).float() 76 | 77 | return tmphp1 + tmphp2 78 | 79 | 80 | def dhpf_diff(c1, c2, h1p, h2p): 81 | mask1 = ((c1 * c2) == 0).float() 82 | mask1_no = 1 - mask1 83 | res1 = (h2p - h1p) * mask1_no * (torch.abs(h2p - h1p) <= 180).float() 84 | res2 = ((h2p - h1p) - 360) * ((h2p - h1p) > 180).float() * mask1_no 85 | res3 = ((h2p - h1p) + 360) * ((h2p - h1p) < -180).float() * mask1_no 86 | 87 | return res1 + res2 + res3 88 | 89 | 90 | def ahpf_diff(c1, c2, h1p, h2p): 91 | mask1 = ((c1 * c2) == 0).float() 92 | mask1_no = 1 - mask1 93 | mask2 = (torch.abs(h2p - h1p) <= 180).float() 94 | mask2_no = 1 - mask2 95 | mask3 = (torch.abs(h2p + h1p) < 360).float() 96 | mask3_no = 1 - mask3 97 | 98 | res1 = (h1p + h2p) * mask1_no * mask2 99 | res2 = (h1p + h2p + 360.) * mask1_no * mask2_no * mask3 100 | res3 = (h1p + h2p - 360.) * mask1_no * mask2_no * mask3_no 101 | res = (res1 + res2 + res3) + (res1 + res2 + res3) * mask1 102 | return res * 0.5 103 | 104 | 105 | def ciede2000_diff(lab1, lab2): 106 | """ 107 | CIEDE2000 metric to calculate the color distance map for a batch of image tensors defined in CIELAB space 108 | 109 | This version contains errors: 110 | - Typo in the formula for T: "- 39" should be "- 30" 111 | - The 0.0001 terms added for numerical stability change the conditional values 112 | 113 | """ 114 | L1 = lab1[:, 0, :, :] 115 | A1 = lab1[:, 1, :, :] 116 | B1 = lab1[:, 2, :, :] 117 | L2 = lab2[:, 0, :, :] 118 | A2 = lab2[:, 1, :, :] 119 | B2 = lab2[:, 2, :, :] 120 | kL = 1 121 | kC = 1 122 | kH = 1 123 | 124 | mask_value_0_input1 = ((A1 == 0) * (B1 == 0)).float() 125 | mask_value_0_input2 = ((A2 == 0) * (B2 == 0)).float() 126 | mask_value_0_input1_no = 1 - mask_value_0_input1 127 | mask_value_0_input2_no = 1 - mask_value_0_input2 128 | B1 = B1 + 0.0001 * mask_value_0_input1 129 | B2 = B2 + 0.0001 * mask_value_0_input2 130 | 131 | C1 = torch.sqrt((A1 ** 2.) + (B1 ** 2.)) 132 | C2 = torch.sqrt((A2 ** 2.) + (B2 ** 2.)) 133 | 134 | aC1C2 = (C1 + C2) / 2. 135 | G = 0.5 * (1. - torch.sqrt((aC1C2 ** 7.) / ((aC1C2 ** 7.) + (25 ** 7.)))) 136 | a1P = (1. + G) * A1 137 | a2P = (1. + G) * A2 138 | c1P = torch.sqrt((a1P ** 2.) + (B1 ** 2.)) 139 | c2P = torch.sqrt((a2P ** 2.) + (B2 ** 2.)) 140 | 141 | h1P = hpf_diff(B1, a1P) 142 | h2P = hpf_diff(B2, a2P) 143 | h1P = h1P * mask_value_0_input1_no 144 | h2P = h2P * mask_value_0_input2_no 145 | 146 | dLP = L2 - L1 147 | dCP = c2P - c1P 148 | dhP = dhpf_diff(C1, C2, h1P, h2P) 149 | dHP = 2. * torch.sqrt(c1P * c2P) * torch.sin(radians(dhP) / 2.) 150 | mask_0_no = 1 - torch.max(mask_value_0_input1, mask_value_0_input2) 151 | dHP = dHP * mask_0_no 152 | 153 | aL = (L1 + L2) / 2. 154 | aCP = (c1P + c2P) / 2. 155 | aHP = ahpf_diff(C1, C2, h1P, h2P) 156 | T = 1. - 0.17 * torch.cos(radians(aHP - 39)) + 0.24 * torch.cos(radians(2. * aHP)) + 0.32 * torch.cos( 157 | radians(3. * aHP + 6.)) - 0.2 * torch.cos(radians(4. * aHP - 63.)) 158 | dRO = 30. * torch.exp(-1. * (((aHP - 275.) / 25.) ** 2.)) 159 | rC = torch.sqrt((aCP ** 7.) / ((aCP ** 7.) + (25. ** 7.))) 160 | sL = 1. + ((0.015 * ((aL - 50.) ** 2.)) / torch.sqrt(20. + ((aL - 50.) ** 2.))) 161 | 162 | sC = 1. + 0.045 * aCP 163 | sH = 1. + 0.015 * aCP * T 164 | rT = -2. * rC * torch.sin(radians(2. * dRO)) 165 | 166 | # res_square=((dLP / (sL * kL)) ** 2.) 
+ ((dCP / (sC * kC)) ** 2.) + ((dHP / (sH * kH)) ** 2.) + rT * (dCP / (sC * kC)) * (dHP / (sH * kH)) 167 | 168 | res_square = ((dLP / (sL * kL)) ** 2.) + ((dCP / (sC * kC)) ** 2.) * mask_0_no + ( 169 | (dHP / (sH * kH)) ** 2.) * mask_0_no + rT * (dCP / (sC * kC)) * (dHP / (sH * kH)) * mask_0_no 170 | mask_0 = (res_square <= 0).float() 171 | mask_0_no = 1 - mask_0 172 | res_square = res_square + 0.0001 * mask_0 173 | res = torch.sqrt(res_square) 174 | res = res * mask_0_no 175 | 176 | return res 177 | 178 | 179 | def ciede2000_loss(input: Tensor, target: Tensor) -> Tensor: 180 | Lab_input, Lab_target = map(rgb2lab_diff, (input, target)) 181 | return ciede2000_diff(Lab_input, Lab_target).flatten(1).norm(p=2, dim=1) 182 | -------------------------------------------------------------------------------- /adv_lib/attacks/perceptual_color_attacks/perceptual_color_distance_al.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial 2 | 3 | from math import pi, cos 4 | 5 | import torch 6 | from torch import nn, Tensor 7 | from torch.autograd import grad 8 | 9 | from .differential_color_functions import rgb2lab_diff, ciede2000_diff 10 | 11 | 12 | def quantization(x): 13 | """quantize the continuous image tensors into 256 levels (8 bit encoding)""" 14 | x_quan = torch.round(x * 255) / 255 15 | return x_quan 16 | 17 | 18 | def perc_al(model: nn.Module, 19 | images: Tensor, 20 | labels: Tensor, 21 | num_classes: int, 22 | targeted: bool = False, 23 | max_iterations: int = 1000, 24 | alpha_l_init: float = 1., 25 | alpha_c_init: float = 0.5, 26 | confidence: float = 0, **kwargs) -> Tensor: 27 | """ 28 | PerC_AL: Alternating Loss of Classification and Color Differences to achieve imperceptible perturbations with few 29 | iterations. Adapted from https://github.com/ZhengyuZhao/PerC-Adversarial. 30 | 31 | Parameters 32 | ---------- 33 | model : nn.Module 34 | Model to fool. 35 | images : Tensor 36 | Batch of image examples in the range of [0,1]. 37 | labels : Tensor 38 | Original labels if untargeted, else labels of targets. 39 | targeted : bool, optional 40 | Whether to perform a targeted attack or not. 41 | max_iterations : int 42 | Number of iterations for the optimization. 43 | alpha_l_init: float 44 | Step size for updating perturbations with respect to the classification loss. 45 | alpha_c_init: float 46 | Step size for updating perturbations with respect to perceptual color differences. For the relatively easy 47 | untargeted case, alpha_c_init is adjusted to a smaller value (e.g., 0.1 is used in the paper). 48 | confidence : float, optional 49 | Confidence of the adversary for Carlini's loss, in terms of distance between logits.
50 | Note that this approach only supports confidence setting in an untargeted case 51 | 52 | Returns 53 | ------- 54 | Tensor 55 | Batch of image samples modified to be adversarial 56 | """ 57 | 58 | if images.min() < 0 or images.max() > 1: raise ValueError('Input values should be in the [0, 1] range.') 59 | device = images.device 60 | 61 | alpha_l_min = alpha_l_init / 100 62 | alpha_c_min = alpha_c_init / 10 63 | multiplier = -1 if targeted else 1 64 | 65 | X_adv_round_best = images.clone() 66 | inputs_LAB = rgb2lab_diff(images) 67 | batch_size = images.shape[0] 68 | delta = torch.zeros_like(images, requires_grad=True) 69 | mask_isadv = torch.zeros(batch_size, dtype=torch.bool, device=device) 70 | color_l2_delta_bound_best = torch.full((batch_size,), 100000, dtype=torch.float, device=device) 71 | 72 | if (targeted == False) and confidence != 0: 73 | labels_onehot = torch.zeros(labels.size(0), num_classes, device=device) 74 | labels_onehot.scatter_(1, labels.unsqueeze(1), 1) 75 | labels_infhot = torch.zeros_like(labels_onehot).scatter_(1, labels.unsqueeze(1), float('inf')) 76 | if (targeted == True) and confidence != 0: 77 | print('Only support setting confidence in untargeted case!') 78 | return 79 | 80 | # check if some images are already adversarial 81 | if (targeted == False) and confidence != 0: 82 | logits = model(images) 83 | real = logits.gather(1, labels.unsqueeze(1)).squeeze(1) 84 | other = (logits - labels_infhot).amax(dim=1) 85 | mask_isadv = (real - other) <= -40 86 | elif confidence == 0: 87 | if targeted: 88 | mask_isadv = model(images).argmax(1) == labels 89 | else: 90 | mask_isadv = model(images).argmax(1) != labels 91 | color_l2_delta_bound_best[mask_isadv] = 0 92 | X_adv_round_best[mask_isadv] = images[mask_isadv] 93 | 94 | for i in range(max_iterations): 95 | # cosine annealing for alpha_l_init and alpha_c_init 96 | alpha_c = alpha_c_min + 0.5 * (alpha_c_init - alpha_c_min) * (1 + cos(i / max_iterations * pi)) 97 | alpha_l = alpha_l_min + 0.5 * (alpha_l_init - alpha_l_min) * (1 + cos(i / max_iterations * pi)) 98 | 99 | loss = multiplier * nn.CrossEntropyLoss(reduction='sum')(model(images + delta), labels) 100 | grad_a = grad(loss, delta, only_inputs=True)[0] 101 | delta.data[~mask_isadv] = delta.data[~mask_isadv] + alpha_l * (grad_a.permute(1, 2, 3, 0) / torch.norm( 102 | grad_a.flatten(1), dim=1)).permute(3, 0, 1, 2)[~mask_isadv] 103 | 104 | d_map = ciede2000_diff(inputs_LAB, rgb2lab_diff(images + delta)).unsqueeze(1) 105 | color_dis = torch.norm(d_map.flatten(1), dim=1) 106 | grad_color = grad(color_dis.sum(), delta, only_inputs=True)[0] 107 | delta.data[mask_isadv] = delta.data[mask_isadv] - alpha_c * (grad_color.permute(1, 2, 3, 0) / torch.norm( 108 | grad_color.flatten(1), dim=1)).permute(3, 0, 1, 2)[mask_isadv] 109 | 110 | delta.data = (images + delta.data).clamp_(min=0, max=1) - images 111 | X_adv_round = quantization(images + delta.data) 112 | 113 | if (targeted == False) and confidence != 0: 114 | logits = model(X_adv_round) 115 | real = logits.gather(1, labels.unsqueeze(1)).squeeze(1) 116 | other = (logits - labels_infhot).amax(dim=1) 117 | mask_isadv = (real - other) <= -40 118 | elif confidence == 0: 119 | if targeted: 120 | mask_isadv = model(X_adv_round).argmax(1) == labels 121 | else: 122 | mask_isadv = model(X_adv_round).argmax(1) != labels 123 | mask_best = (color_dis.data < color_l2_delta_bound_best) 124 | mask = mask_best * mask_isadv 125 | color_l2_delta_bound_best[mask] = color_dis.data[mask] 126 | X_adv_round_best[mask] = X_adv_round[mask] 127 | 
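# return the quantized adversarial examples achieving the smallest perceptual color distance found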
128 | return X_adv_round_best 129 | -------------------------------------------------------------------------------- /adv_lib/attacks/projected_gradient_descent.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from functools import partial 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from torch.autograd import grad 8 | from torch.nn import functional as F 9 | 10 | from adv_lib.utils.losses import difference_of_logits, difference_of_logits_ratio 11 | from adv_lib.utils.projections import clamp_ 12 | from adv_lib.utils.visdom_logger import VisdomLogger 13 | 14 | 15 | def pgd_linf(model: nn.Module, 16 | inputs: Tensor, 17 | labels: Tensor, 18 | ε: Union[float, Tensor], 19 | targeted: bool = False, 20 | steps: int = 40, 21 | random_init: bool = True, 22 | restarts: int = 1, 23 | loss_function: str = 'ce', 24 | relative_step_size: float = 0.01 / 0.3, 25 | absolute_step_size: Optional[float] = None, 26 | callback: Optional[VisdomLogger] = None) -> Tensor: 27 | device = inputs.device 28 | batch_size = len(inputs) 29 | 30 | adv_inputs = inputs.clone() 31 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 32 | 33 | if isinstance(ε, numbers.Real): 34 | ε = torch.full_like(adv_found, ε, dtype=inputs.dtype) 35 | 36 | pgd_attack = partial(_pgd_linf, model=model, targeted=targeted, steps=steps, random_init=random_init, 37 | loss_function=loss_function, relative_step_size=relative_step_size, 38 | absolute_step_size=absolute_step_size) 39 | 40 | for i in range(restarts): 41 | 42 | adv_found_run, adv_inputs_run = pgd_attack(inputs=inputs[~adv_found], labels=labels[~adv_found], 43 | ε=ε[~adv_found]) 44 | adv_inputs[~adv_found] = adv_inputs_run 45 | adv_found[~adv_found] = adv_found_run 46 | 47 | if callback: 48 | callback.line('success', i + 1, adv_found.float().mean()) 49 | 50 | if adv_found.all(): 51 | break 52 | 53 | return adv_inputs 54 | 55 | 56 | def _pgd_linf(model: nn.Module, 57 | inputs: Tensor, 58 | labels: Tensor, 59 | ε: Tensor, 60 | targeted: bool = False, 61 | steps: int = 40, 62 | random_init: bool = True, 63 | loss_function: str = 'ce', 64 | relative_step_size: float = 0.01 / 0.3, 65 | absolute_step_size: Optional[float] = None) -> Tuple[Tensor, Tensor]: 66 | _loss_functions = { 67 | 'ce': (partial(F.cross_entropy, reduction='none'), 1), 68 | 'dl': (difference_of_logits, -1), 69 | 'dlr': (partial(difference_of_logits_ratio, targeted=targeted), -1), 70 | } 71 | 72 | device = inputs.device 73 | batch_size = len(inputs) 74 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 75 | lower, upper = torch.maximum(-inputs, -batch_view(ε)), torch.minimum(1 - inputs, batch_view(ε)) 76 | 77 | loss_func, multiplier = _loss_functions[loss_function.lower()] 78 | 79 | step_size: Tensor = ε * relative_step_size if absolute_step_size is None else torch.full_like(ε, absolute_step_size) 80 | if targeted: 81 | step_size *= -1 82 | 83 | δ = torch.zeros_like(inputs, requires_grad=True) 84 | best_adv = inputs.clone() 85 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 86 | 87 | if random_init: 88 | δ.data.uniform_(-1, 1).mul_(batch_view(ε)) 89 | clamp_(δ, lower=lower, upper=upper) 90 | 91 | for i in range(steps): 92 | adv_inputs = inputs + δ 93 | logits = model(adv_inputs) 94 | 95 | if i == 0 and loss_function.lower() in ['dl', 'dlr']: 96 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 97 | 
loss_func = partial(loss_func, labels_infhot=labels_infhot) 98 | 99 | loss = multiplier * loss_func(logits, labels) 100 | δ_grad = grad(loss.sum(), δ, only_inputs=True)[0].sign_() 101 | 102 | is_adv = (logits.argmax(1) == labels) if targeted else (logits.argmax(1) != labels) 103 | best_adv = torch.where(batch_view(is_adv), adv_inputs.detach(), best_adv) 104 | adv_found.logical_or_(is_adv) 105 | 106 | δ.data.addcmul_(batch_view(step_size), δ_grad) 107 | clamp_(δ, lower=lower, upper=upper) 108 | 109 | return adv_found, best_adv 110 | -------------------------------------------------------------------------------- /adv_lib/attacks/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .alma_prox import alma_prox 2 | from .asma import asma 3 | from .dense_adversary import dag 4 | from .primal_dual_gradient_descent import pdgd, pdpgd 5 | -------------------------------------------------------------------------------- /adv_lib/attacks/segmentation/asma.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from adv_lib.utils.visdom_logger import VisdomLogger 5 | from torch import Tensor, nn 6 | from torch.autograd import grad 7 | 8 | 9 | def iou_masks(mask1: Tensor, mask2: Tensor, n: int): 10 | k = (mask1 >= 0) & (mask1 < n) 11 | inds = n * mask1[k].to(torch.int64) + mask2[k] 12 | mat = torch.bincount(inds, minlength=n ** 2).reshape(n, n) 13 | iu = torch.diag(mat) / (mat.sum(1) + mat.sum(0) - torch.diag(mat) + 1e-6) 14 | return iu.mean().item() 15 | 16 | 17 | def asma(model: nn.Module, 18 | inputs: Tensor, 19 | labels: Tensor, 20 | masks: Tensor = None, 21 | targeted: bool = False, 22 | adv_threshold: float = 0.99, 23 | num_steps: int = 1000, 24 | τ: float = 1e-7, 25 | β: float = 1e-6, 26 | callback: Optional[VisdomLogger] = None) -> Tensor: 27 | "ASMA attack from https://arxiv.org/abs/1907.13124" 28 | attack_name = 'ASMA' 29 | device = inputs.device 30 | batch_size = len(inputs) 31 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 32 | 33 | # Setup variables 34 | δ = torch.zeros_like(inputs, requires_grad=True) 35 | lower, upper = -inputs, 1 - inputs 36 | pert_mul = τ 37 | 38 | # Init trackers 39 | best_dist = torch.full((batch_size,), float('inf'), device=device) 40 | best_adv_percent = torch.zeros_like(best_dist) 41 | adv_found = torch.zeros_like(best_dist, dtype=torch.bool) 42 | best_adv = inputs.clone() 43 | 44 | for i in range(num_steps): 45 | 46 | adv_inputs = inputs + δ 47 | logits = model(adv_inputs) 48 | l2_squared = δ.flatten(1).square().sum(dim=1) 49 | 50 | if i == 0: 51 | # initialize variables based on model's output 52 | num_classes = logits.size(1) 53 | if masks is None: 54 | masks = labels < num_classes 55 | masks_sum = masks.flatten(1).sum(dim=1) 56 | labels_ = labels * masks 57 | 58 | # track progress 59 | pred = logits.argmax(dim=1) 60 | pixel_is_adv = (pred == labels) if targeted else (pred != labels) 61 | adv_percent = (pixel_is_adv & masks).flatten(1).sum(dim=1) / masks_sum 62 | is_adv = adv_percent >= adv_threshold 63 | is_smaller = l2_squared <= best_dist 64 | improves_constraints = adv_percent >= best_adv_percent.clamp_max(adv_threshold) 65 | is_better_adv = (is_smaller & is_adv) | (~adv_found & improves_constraints) 66 | adv_found.logical_or_(is_adv) 67 | best_dist = torch.where(is_better_adv, l2_squared.detach(), best_dist) 68 | best_adv_percent = torch.where(is_better_adv, adv_percent, 
best_adv_percent) 69 | best_adv = torch.where(batch_view(is_better_adv), adv_inputs.detach(), best_adv) 70 | 71 | iou = iou_masks(labels, pred, n=num_classes) 72 | if i: 73 | pert_mul = β * iou + τ 74 | 75 | logit_loss = logits.gather(1, labels_.unsqueeze(1)).squeeze(1).mul(masks & (pred != labels_)).sum() 76 | loss = logit_loss - l2_squared 77 | δ_grad = grad(loss.sum(), δ, only_inputs=True)[0] 78 | 79 | δ.data.add_(δ_grad, alpha=pert_mul).clamp_(min=lower, max=upper) 80 | 81 | if callback: 82 | callback.accumulate_line('logit_loss', i, logit_loss.mean(), title=attack_name + ' - Logit loss') 83 | callback.accumulate_line(['adv%', 'best_adv%'], i, [adv_percent.mean(), best_adv_percent.mean()], 84 | title=attack_name + ' - APSR') 85 | callback.accumulate_line(['ℓ2', 'best ℓ2'], i, [l2_squared.detach().sqrt().mean(), best_dist.sqrt().mean()], 86 | title=attack_name + ' - L2 Norms') 87 | callback.accumulate_line('lr', i, pert_mul, title=attack_name + ' - Step size') 88 | callback.accumulate_line('IoU', i, iou, title=attack_name + ' - IoU') 89 | if (i + 1) % (num_steps // 20) == 0 or (i + 1) == num_steps: 90 | callback.update_lines() 91 | 92 | return best_adv 93 | -------------------------------------------------------------------------------- /adv_lib/attacks/segmentation/dense_adversary.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from adv_lib.utils.losses import difference_of_logits 5 | from adv_lib.utils.visdom_logger import VisdomLogger 6 | from torch import Tensor, nn 7 | from torch.autograd import grad 8 | 9 | 10 | def dag(model: nn.Module, 11 | inputs: Tensor, 12 | labels: Tensor, 13 | masks: Tensor = None, 14 | targeted: bool = False, 15 | adv_threshold: float = 0.99, 16 | max_iter: int = 200, 17 | γ: float = 0.5, 18 | p: float = float('inf'), 19 | callback: Optional[VisdomLogger] = None) -> Tensor: 20 | """DAG attack from https://arxiv.org/abs/1703.08603""" 21 | device = inputs.device 22 | batch_size = len(inputs) 23 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 24 | multiplier = -1 if targeted else 1 25 | 26 | # Setup variables 27 | r = torch.zeros_like(inputs) 28 | 29 | # Init trackers 30 | best_adv_percent = torch.zeros(batch_size, device=device) 31 | adv_found = torch.zeros_like(best_adv_percent, dtype=torch.bool) 32 | best_adv = inputs.clone() 33 | 34 | for i in range(max_iter): 35 | 36 | active_inputs = ~adv_found 37 | inputs_ = inputs[active_inputs] 38 | r_ = r[active_inputs] 39 | r_.requires_grad_(True) 40 | 41 | adv_inputs_ = (inputs_ + r_).clamp(0, 1) 42 | logits = model(adv_inputs_) 43 | 44 | if i == 0: 45 | num_classes = logits.size(1) 46 | if masks is None: 47 | masks = labels < num_classes 48 | masks_sum = masks.flatten(1).sum(dim=1) 49 | masked_labels = labels * masks 50 | labels_infhot = torch.zeros_like(logits.detach()).scatter(1, masked_labels.unsqueeze(1), float('inf')) 51 | 52 | dl = multiplier * difference_of_logits(logits, labels=masked_labels[active_inputs], 53 | labels_infhot=labels_infhot[active_inputs]) 54 | pixel_is_adv = dl < 0 55 | 56 | active_masks = masks[active_inputs] 57 | adv_percent = (pixel_is_adv & active_masks).flatten(1).sum(dim=1) / masks_sum[active_inputs] 58 | is_adv = adv_percent >= adv_threshold 59 | adv_found[active_inputs] = is_adv 60 | best_adv[active_inputs] = torch.where(batch_view(is_adv), adv_inputs_.detach(), best_adv[active_inputs]) 61 | 62 | if callback: 63 | callback.accumulate_line('dl', i, 
dl[active_masks].mean(), title=f'DAG (p={p}, γ={γ}) - DL') 64 | callback.accumulate_line(f'L{p}', i, r.flatten(1).norm(p=p, dim=1).mean(), title=f'DAG (p={p}, γ={γ}) - Norm') 65 | callback.accumulate_line('adv%', i, adv_percent.mean(), title=f'DAG (p={p}, γ={γ}) - Adv percent') 66 | 67 | if (i + 1) % (max_iter // 20) == 0 or (i + 1) == max_iter: 68 | callback.update_lines() 69 | 70 | if is_adv.all(): 71 | break 72 | 73 | loss = (dl[~is_adv] * active_masks[~is_adv]).relu() 74 | r_grad = grad(loss.sum(), r_, only_inputs=True)[0] 75 | r_grad.div_(batch_view(r_grad.flatten(1).norm(p=p, dim=1).clamp_min_(1e-8))) 76 | r_.data.sub_(r_grad, alpha=γ) 77 | 78 | r[active_inputs] = r_ 79 | 80 | if callback: 81 | callback.update_lines() 82 | 83 | return best_adv 84 | -------------------------------------------------------------------------------- /adv_lib/attacks/sigma_zero.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Cinofix/sigma-zero-adversarial-attack 2 | import math 3 | import warnings 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from torch.autograd import grad 8 | 9 | from adv_lib.utils.losses import difference_of_logits 10 | 11 | 12 | def sigma_zero(model: nn.Module, 13 | inputs: Tensor, 14 | labels: Tensor, 15 | num_steps: int = 1000, 16 | η_0: float = 1.0, 17 | σ: float = 0.001, 18 | τ_0: float = 0.3, 19 | τ_factor: float = 0.01, 20 | grad_norm: float = float('inf'), 21 | targeted: bool = False) -> Tensor: 22 | """ 23 | σ-zero attack from https://arxiv.org/abs/2402.01879. 24 | 25 | Parameters 26 | ---------- 27 | model : nn.Module 28 | Model to attack. 29 | inputs : Tensor 30 | Inputs to attack. Should be in [0, 1]. 31 | labels : Tensor 32 | Labels corresponding to the inputs if untargeted, else target labels. 33 | num_steps : int 34 | Number of optimization steps. Corresponds to the number of forward and backward propagations. 35 | η_0 : float 36 | Initial step size. 37 | σ : float 38 | \ell_0 approximation parameter: smaller values produce sharper approximations while larger values produce a 39 | smoother approximation. 40 | τ_0 : float 41 | Initial sparsity threshold. 42 | τ_factor : float 43 | Threshold adjustment factor w.r.t. step size η. 44 | grad_norm: float 45 | Norm to use for gradient normalization. 46 | targeted : bool 47 | Attack is untargeted only: will raise a warning and return inputs if targeted is True. 48 | 49 | Returns 50 | ------- 51 | best_adv : Tensor 52 | Perturbed inputs (inputs + perturbation) that are adversarial and have smallest distance with the original 53 | inputs. 54 | 55 | """ 56 | if targeted: 57 | warnings.warn('σ-zero attack is untargeted only. 
Returning inputs.') 58 | return inputs 59 | 60 | batch_size, numel = len(inputs), inputs[0].numel() 61 | batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1)) 62 | 63 | δ = torch.zeros_like(inputs, requires_grad=True) 64 | # Adam variables 65 | exp_avg = torch.zeros_like(inputs) 66 | exp_avg_sq = torch.zeros_like(inputs) 67 | β_1, β_2 = 0.9, 0.999 68 | 69 | best_l0 = inputs.new_full((batch_size,), numel) 70 | best_adv = inputs.clone() 71 | τ = torch.full_like(best_l0, τ_0) 72 | 73 | η = η_0 74 | for i in range(num_steps): 75 | adv_inputs = inputs + δ 76 | 77 | # compute loss 78 | logits = model(adv_inputs) 79 | dl_loss = difference_of_logits(logits, labels).clamp_(min=0) 80 | δ_square = δ.square() 81 | l0_approx_normalized = (δ_square / (δ_square + σ)).flatten(1).mean(dim=1) 82 | 83 | # keep best solutions 84 | predicted_classes = logits.argmax(dim=1) 85 | l0_norm = δ.data.flatten(1).norm(p=0, dim=1) 86 | is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels) 87 | is_smaller = l0_norm < best_l0 88 | is_both = is_adv & is_smaller 89 | best_l0 = torch.where(is_both, l0_norm, best_l0) 90 | best_adv = torch.where(batch_view(is_both), adv_inputs.detach(), best_adv) 91 | 92 | # compute loss and gradient 93 | adv_loss = (dl_loss + l0_approx_normalized).sum() 94 | δ_grad = grad(adv_loss, inputs=δ, only_inputs=True)[0] 95 | 96 | # normalize gradient based on grad_norm type 97 | δ_inf_norm = δ_grad.flatten(1).norm(p=grad_norm, dim=1).clamp_(min=1e-12) 98 | δ_grad.div_(batch_view(δ_inf_norm)) 99 | 100 | # adam computations 101 | exp_avg.lerp_(δ_grad, weight=1 - β_1) 102 | exp_avg_sq.mul_(β_2).addcmul_(δ_grad, δ_grad, value=1 - β_2) 103 | bias_correction1 = 1 - β_1 ** (i + 1) 104 | bias_correction2 = 1 - β_2 ** (i + 1) 105 | denom = exp_avg_sq.sqrt().div_(bias_correction2 ** 0.5).add_(1e-8) 106 | 107 | # step and clamp 108 | δ.data.addcdiv_(exp_avg, denom, value=-η / bias_correction1) 109 | δ.data.add_(inputs).clamp_(min=0, max=1).sub_(inputs) 110 | 111 | # update step size with cosine annealing 112 | η = 0.1 * η_0 + 0.9 * η_0 * (1 + math.cos(math.pi * i / num_steps)) / 2 113 | # dynamic thresholding 114 | τ.add_(torch.where(is_adv, τ_factor * η, -τ_factor * η)).clamp_(min=0, max=1) 115 | 116 | # filter components 117 | δ.data[δ.data.abs() < batch_view(τ)] = 0 118 | 119 | return best_adv 120 | -------------------------------------------------------------------------------- /adv_lib/attacks/stochastic_sparse_attacks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.autograd import grad 6 | 7 | 8 | def fga(model: nn.Module, 9 | inputs: Tensor, 10 | labels: Tensor, 11 | targeted: bool, 12 | increasing: bool = True, 13 | max_iter: Optional[int] = None, 14 | n_samples: int = 10, 15 | large_memory: bool = False) -> Tensor: 16 | """Folded Gaussian Attack (FGA) attack from https://arxiv.org/abs/2011.12423. 17 | 18 | Parameters 19 | ---------- 20 | model : nn.Module 21 | Model to attack. 22 | inputs : Tensor 23 | Inputs to attack. Should be in [0, 1]. 24 | labels : Tensor 25 | Labels corresponding to the inputs if untargeted, else target labels. 26 | targeted : bool 27 | Whether to perform a targeted attack or not. 28 | increasing : bool 29 | Whether to add positive or negative perturbations. 30 | max_iter : int 31 | Maximum number of iterations for the attack. 
If None is provided, the attack will run as long as adversarial 32 | examples are not found and non-modified pixels are left. 33 | n_samples : int 34 | Number of random samples to draw in each iteration. 35 | large_memory : bool 36 | If True, performs forward propagations on all randomly perturbed inputs in one batch. This is slightly faster 37 | for small models but also uses `n_samples` times more memory. This should only be used when working on small 38 | models. For larger models, the speed gain is negligible, so this option should be left to False. 39 | 40 | Returns 41 | ------- 42 | adv_inputs : Tensor 43 | Modified inputs to be adversarial to the model. 44 | 45 | """ 46 | batch_size, *input_shape = inputs.shape 47 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 48 | input_view = lambda tensor: tensor.view(-1, *input_shape) 49 | device = inputs.device 50 | model_ = lambda t: model(t).softmax(dim=1) 51 | multiplier = 1 if targeted else -1 52 | 53 | adv_inputs = inputs.clone() 54 | best_adv = inputs.clone() 55 | adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool) 56 | 57 | max_iter = inputs[0].numel() if max_iter is None else max_iter 58 | for i in range(max_iter): 59 | Γ = adv_inputs == inputs 60 | Γ_empty = ~Γ.flatten(1).any(1) 61 | if (Γ_empty | adv_found).all(): 62 | break 63 | to_attack = ~(Γ_empty | adv_found) 64 | 65 | inputs_, labels_, Γ_ = adv_inputs[to_attack], labels[to_attack], Γ[to_attack] 66 | batch_size_ = len(inputs_) 67 | inputs_.requires_grad_(True) 68 | Γ_inf = torch.zeros_like(Γ_, dtype=torch.float).masked_fill_(~Γ_, float('inf')) 69 | 70 | probs = model_(inputs_) 71 | label_probs = probs.gather(1, labels_.unsqueeze(1)).squeeze(1) 72 | grad_label_probs = grad(multiplier * label_probs.sum(), inputs=inputs_, only_inputs=True)[0] 73 | inputs_.detach_() 74 | 75 | # find index of most relevant feature 76 | if increasing: 77 | i_0 = grad_label_probs.mul_(1 - inputs_).sub_(Γ_inf).flatten(1).argmax(dim=1, keepdim=True) 78 | else: 79 | i_0 = grad_label_probs.mul_(inputs_).add_(Γ_inf).flatten(1).argmin(dim=1, keepdim=True) 80 | 81 | # compute variance of gaussian noise 82 | θ = inputs_.flatten(1).gather(1, i_0).neg_() 83 | if increasing: 84 | θ.add_(1) 85 | # generate random perturbation from folded Gaussian noise 86 | S = torch.randn(batch_size_, n_samples, device=device).abs_().mul_(θ) 87 | 88 | # add perturbation to inputs 89 | perturbed_inputs = inputs_.flatten(1).unsqueeze(1).repeat(1, n_samples, 1) 90 | perturbed_inputs.scatter_add_(2, i_0.repeat_interleave(n_samples, dim=1).unsqueeze(2), S.unsqueeze(2)) 91 | perturbed_inputs.clamp_(min=0, max=1) 92 | 93 | # get probabilities for perturbed inputs 94 | if large_memory: 95 | new_probs = model_(input_view(perturbed_inputs)) 96 | else: 97 | new_probs = [] 98 | for chunk in torch.chunk(input_view(perturbed_inputs), chunks=n_samples): 99 | new_probs.append(model_(chunk)) 100 | new_probs = torch.cat(new_probs, dim=0) 101 | new_probs = new_probs.view(batch_size_, n_samples, -1) 102 | new_preds = new_probs.argmax(dim=2) 103 | 104 | new_label_probs = new_probs.gather(2, labels_.view(-1, 1, 1).expand(-1, n_samples, 1)).squeeze(2) 105 | if targeted: 106 | # finding the index of max probability for target class. If a sample is adv, it will be prioritized. If 107 | # several are adversarial, taking the index of the adv sample with max probability. 
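# e.g. with target-class probabilities [0.2, 0.9, 0.4] for the candidate samples of one input and adversarial flags [1, 0, 0], the argmax over [1.2, 0.9, 0.4] picks index 0: the adversarial sample wins although its target-class probability is not the largest.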
108 | adv_found_ = new_preds == labels_.unsqueeze(1) 109 | best_sample_index = (new_label_probs + adv_found_.float()).argmax(dim=1) 110 | else: 111 | # finding the index of min probability for original class. If a sample is adv, it will be prioritized. If 112 | # several are adversarial, taking the index of the adv sample with min probability. 113 | adv_found_ = new_preds != labels_.unsqueeze(1) 114 | best_sample_index = (new_label_probs - adv_found_.float()).argmin(dim=1) 115 | 116 | # update trackers 117 | adv_inputs[to_attack] = input_view(perturbed_inputs[range(batch_size_), best_sample_index]) 118 | preds = new_preds.gather(1, best_sample_index.unsqueeze(1)).squeeze(1) 119 | is_adv = (preds == labels_) if targeted else (preds != labels_) 120 | adv_found[to_attack] = is_adv 121 | best_adv[to_attack] = torch.where(batch_view(is_adv), adv_inputs[to_attack], best_adv[to_attack]) 122 | 123 | return best_adv 124 | 125 | 126 | def vfga(model: nn.Module, 127 | inputs: Tensor, 128 | labels: Tensor, 129 | targeted: bool, 130 | max_iter: Optional[int] = None, 131 | n_samples: int = 10, 132 | large_memory: bool = False) -> Tensor: 133 | """Voting Folded Gaussian Attack (VFGA) attack from https://arxiv.org/abs/2011.12423. 134 | 135 | Parameters 136 | ---------- 137 | model : nn.Module 138 | Model to attack. 139 | inputs : Tensor 140 | Inputs to attack. Should be in [0, 1]. 141 | labels : Tensor 142 | Labels corresponding to the inputs if untargeted, else target labels. 143 | targeted : bool 144 | Whether to perform a targeted attack or not. 145 | max_iter : int 146 | Maximum number of iterations for the attack. If None is provided, the attack will run as long as adversarial 147 | examples are not found and non-modified pixels are left. 148 | n_samples : int 149 | Number of random samples to draw in each iteration. 150 | large_memory : bool 151 | If True, performs forward propagations on all randomly perturbed inputs in one batch. This is slightly faster 152 | for small models but also uses `n_samples` times more memory. This should only be used when working on small 153 | models. For larger models, the speed gain is negligible, so this option should be left to False. 154 | 155 | Returns 156 | ------- 157 | adv_inputs : Tensor 158 | Modified inputs to be adversarial to the model. 
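Example ------- Illustrative usage sketch (not part of the original file; `model`, the input shapes and the 10-class setup are hypothetical): >>> images, labels = torch.rand(8, 3, 32, 32), torch.randint(10, (8,)) >>> adv_images = vfga(model, images, labels, targeted=False, n_samples=10)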
159 | 160 | """ 161 | batch_size, *input_shape = inputs.shape 162 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 163 | input_view = lambda tensor: tensor.view(-1, *input_shape) 164 | device = inputs.device 165 | model_ = lambda t: model(t).softmax(dim=1) 166 | multiplier = 1 if targeted else -1 167 | 168 | adv_inputs = inputs.clone() 169 | best_adv = inputs.clone() 170 | adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool) 171 | 172 | max_iter = inputs[0].numel() if max_iter is None else max_iter 173 | for i in range(max_iter): 174 | Γ = adv_inputs == inputs 175 | Γ_empty = ~Γ.flatten(1).any(1) 176 | if (Γ_empty | adv_found).all(): 177 | break 178 | to_attack = ~(Γ_empty | adv_found) 179 | 180 | inputs_, labels_, Γ_ = adv_inputs[to_attack], labels[to_attack], Γ[to_attack] 181 | batch_size_ = len(inputs_) 182 | inputs_.requires_grad_(True) 183 | Γ_inf = torch.zeros_like(Γ_, dtype=torch.float).masked_fill_(~Γ_, float('inf')) 184 | 185 | probs = model_(inputs_) 186 | label_probs = probs.gather(1, labels_.unsqueeze(1)).squeeze(1) 187 | grad_label_probs = grad(multiplier * label_probs.sum(), inputs=inputs_, only_inputs=True)[0] 188 | inputs_.detach_() 189 | 190 | # find index of most relevant feature 191 | i_plus = (1 - inputs_).mul_(grad_label_probs).sub_(Γ_inf).flatten(1).argmax(dim=1, keepdim=True) 192 | i_minus = grad_label_probs.mul_(inputs_).add_(Γ_inf).flatten(1).argmin(dim=1, keepdim=True) 193 | # compute variance of gaussian noise 194 | θ_plus = 1 - inputs_.flatten(1).gather(1, i_plus) 195 | θ_minus = inputs_.flatten(1).gather(1, i_minus) 196 | # generate random perturbation from folded Gaussian noise 197 | S_plus = torch.randn(batch_size_, n_samples, device=device).abs_().mul_(θ_plus) 198 | S_minus = torch.randn(batch_size_, n_samples, device=device).abs_().neg_().mul_(θ_minus) 199 | 200 | # add perturbation to inputs 201 | perturbed_inputs = inputs_.flatten(1).unsqueeze(1).repeat(1, 2 * n_samples, 1) 202 | i_plus_minus = torch.cat([i_plus, i_minus], dim=1).repeat_interleave(n_samples, dim=1) 203 | S_plus_minus = torch.cat([S_plus, S_minus], dim=1) 204 | perturbed_inputs.scatter_add_(2, i_plus_minus.unsqueeze(2), S_plus_minus.unsqueeze(2)) 205 | perturbed_inputs.clamp_(min=0, max=1) 206 | 207 | # get probabilities for perturbed inputs 208 | if large_memory: 209 | new_probs = model_(input_view(perturbed_inputs)) 210 | else: 211 | new_probs = [] 212 | for chunk in torch.chunk(input_view(perturbed_inputs), chunks=n_samples): 213 | new_probs.append(model_(chunk)) 214 | new_probs = torch.cat(new_probs, dim=0) 215 | new_probs = new_probs.view(batch_size_, 2 * n_samples, -1) 216 | new_preds = new_probs.argmax(dim=2) 217 | 218 | new_label_probs = new_probs.gather(2, labels_.view(-1, 1, 1).expand(-1, 2 * n_samples, 1)).squeeze(2) 219 | if targeted: 220 | # finding the index of max probability for target class. If a sample is adv, it will be prioritized. If 221 | # several are adversarial, taking the index of the adv sample with max probability. 222 | adv_found_ = new_preds == labels_.unsqueeze(1) 223 | best_sample_index = (new_label_probs + adv_found_.float()).argmax(dim=1) 224 | else: 225 | # finding the index of min probability for original class. If a sample is adv, it will be prioritized. If 226 | # several are adversarial, taking the index of the adv sample with min probability. 
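# e.g. with original-class probabilities [0.2, 0.9, 0.4] for the candidate samples of one input and adversarial flags [0, 1, 0], the argmin over [0.2, -0.1, 0.4] picks index 1: the adversarial sample wins although its original-class probability is the largest.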
227 | adv_found_ = new_preds != labels_.unsqueeze(1) 228 | best_sample_index = (new_label_probs - adv_found_.float()).argmin(dim=1) 229 | 230 | # update trackers 231 | adv_inputs[to_attack] = input_view(perturbed_inputs[range(batch_size_), best_sample_index]) 232 | preds = new_preds.gather(1, best_sample_index.unsqueeze(1)).squeeze(1) 233 | is_adv = (preds == labels_) if targeted else (preds != labels_) 234 | adv_found[to_attack] = is_adv 235 | best_adv[to_attack] = torch.where(batch_view(is_adv), adv_inputs[to_attack], best_adv[to_attack]) 236 | 237 | return best_adv 238 | -------------------------------------------------------------------------------- /adv_lib/attacks/structured_adversarial_attack.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/KaidiXu/StrAttack 2 | import math 3 | from itertools import product 4 | from typing import Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor, nn 9 | from torch.autograd import grad 10 | from torch.nn import functional as F 11 | 12 | from adv_lib.utils.losses import difference_of_logits 13 | from adv_lib.utils.visdom_logger import VisdomLogger 14 | 15 | 16 | def str_attack(model: nn.Module, 17 | inputs: Tensor, 18 | labels: Tensor, 19 | targeted: bool = False, 20 | confidence: float = 0, 21 | initial_const: float = 1, 22 | binary_search_steps: int = 6, 23 | max_iterations: int = 2000, 24 | ρ: float = 1, 25 | α: float = 5, 26 | τ: float = 2, 27 | γ: float = 1, 28 | group_size: int = 2, 29 | stride: int = 2, 30 | retrain: bool = True, 31 | σ: float = 3, 32 | fix_y_step: bool = False, 33 | callback: Optional[VisdomLogger] = None) -> Tensor: 34 | """ 35 | StrAttack from https://arxiv.org/abs/1808.01664. 36 | 37 | Parameters 38 | ---------- 39 | model : nn.Module 40 | Model to attack. 41 | inputs : Tensor 42 | Inputs to attack. Should be in [0, 1]. 43 | labels : Tensor 44 | Labels corresponding to the inputs if untargeted, else target labels. 45 | targeted : bool 46 | Whether to perform a targeted attack or not. 47 | confidence : float 48 | Confidence of adversarial examples: higher produces examples that are farther away, but more strongly classified 49 | as adversarial. 50 | initial_const : float 51 | The initial tradeoff-constant to use to tune the relative importance of distance and confidence. If 52 | binary_search_steps is large, the initial constant is not important. 53 | binary_search_steps : int 54 | The number of times we perform binary search to find the optimal tradeoff-constant between distance and 55 | confidence. 56 | max_iterations : int 57 | The maximum number of iterations. Larger values are more accurate; setting too small will require a large 58 | learning rate and will produce poor results. 59 | ρ : float 60 | Penalty parameter adjusting the trade-off between convergence speed and value. Larger ρ leads to faster 61 | convergence but larger perturbations. 62 | α : float 63 | Initial learning rate (η_1 in section F of the paper). 64 | τ : float 65 | Weight of the group sparsity penalty. 66 | γ : float 67 | Weight of the l2-norm penalty. 68 | group_size : int 69 | Size of the groups for the sparsity penalty. 70 | stride : int 71 | Stride of the groups for the sparsity penalty. If stride < group_size, then the groups are overlapping. 72 | retrain : bool 73 | If True, refines the perturbation by constraining it to σ% of the pixels. 74 | σ : float 75 | Percentage of pixels allowed to be perturbed in the refinement procedure. 
76 |     fix_y_step : bool
77 |         Fix the group folding of the original implementation by correctly summing over groups for pixels in the
78 |         overlapping regions. Typically results in smaller perturbations and is faster.
79 |     callback : Optional[VisdomLogger]
80 |         Callback to visualize the progress of the attack.
81 |     Returns
82 |     -------
83 |     adv_inputs : Tensor
84 |         Modified inputs to be adversarial to the model.
85 |
86 |     """
87 |     device = inputs.device
88 |     batch_size, C, H, W = inputs.shape
89 |     batch_view = lambda tensor: tensor.view(batch_size, *[1] * (inputs.ndim - 1))
90 |     multiplier = -1 if targeted else 1
91 |     zeros = torch.zeros_like(inputs)
92 |
93 |     # set the lower and upper bounds accordingly
94 |     const = torch.full((batch_size,), initial_const, device=device, dtype=torch.float)
95 |     lower_bound = torch.zeros_like(const)
96 |     upper_bound = torch.full_like(const, 1e10)
97 |
98 |     # bounds for the perturbations to get valid inputs
99 |     sup_bound = 1 - inputs
100 |     inf_bound = -inputs
101 |
102 |     # number of groups per row and column
103 |     P, Q = math.floor((W - group_size) / stride) + 1, math.floor((H - group_size) / stride) + 1
104 |     overlap = group_size > stride
105 |
106 |     z = torch.zeros_like(inputs)
107 |     v = torch.zeros_like(inputs)
108 |     u = torch.zeros_like(inputs)
109 |     s = torch.zeros_like(inputs)
110 |
111 |     o_best_l2 = torch.full_like(const, float('inf'))
112 |     o_best_adv = inputs.clone()
113 |     o_adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)
114 |
115 |     i_total = 0
116 |     for outer_step in range(binary_search_steps):
117 |
118 |         best_l2 = torch.full_like(const, float('inf'))
119 |         adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool)
120 |
121 |         # The last binary search step (if we run many steps) repeats the search once with const = upper_bound.
122 |         if (binary_search_steps >= 10) and outer_step == (binary_search_steps - 1):
123 |             const = upper_bound
124 |
125 |         for i in range(max_iterations):  # max_iterations + outer_step * 1000 in the original implementation
126 |
127 |             z.requires_grad_(True)
128 |             adv_inputs = z + inputs
129 |             logits = model(adv_inputs)
130 |
131 |             if outer_step == 0 and i == 0:
132 |                 # setup the target variable, we need it to be in one-hot form for the loss function
133 |                 labels_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
134 |                 labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf'))
135 |
136 |             logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot)
137 |             loss = const * (logit_dists + confidence).clamp(min=0)
138 |             z_grad = grad(loss.sum(), inputs=z, only_inputs=True)[0]
139 |             z.detach_()
140 |
141 |             # δ step (equation 10)
142 |             a = z - u
143 |             δ = ρ / (ρ + 2 * γ) * a
144 |
145 |             # w step (equation 11)
146 |             b = z - s
147 |             w = torch.minimum(torch.maximum(b, inf_bound), sup_bound)
148 |
149 |             # y step (equation 17)
150 |             c = z - v
151 |             groups = F.unfold(c, kernel_size=group_size, stride=stride)
152 |             group_norms = groups.norm(dim=1, p=2, keepdim=True)
153 |             temp = torch.where(group_norms != 0, 1 - τ / (ρ * group_norms), torch.zeros_like(group_norms))
154 |             temp_ = groups * temp.clamp(min=0)
155 |
156 |             if overlap and not fix_y_step:  # match original implementation when overlapping
157 |                 y = c
158 |                 for j, (p, q) in enumerate(product(range(P), range(Q))):  # j must not shadow the counter i used for η
159 |                     p_start, p_end = p * stride, p * stride + group_size
160 |                     q_start, q_end = q * stride, q * stride + group_size
161 |                     y[:, :, p_start:p_end, q_start:q_end] = temp_[:, :, j].view(-1, C, group_size, group_size)
162 |             else:  # faster folding (matches original
implementation when groups are not overlapping) 163 | y = F.fold(temp_, output_size=(H, W), kernel_size=group_size, stride=stride) 164 | 165 | # MODIFIED: add projection for valid perturbation 166 | y = torch.minimum(torch.maximum(y, inf_bound), sup_bound) 167 | 168 | # z step (equation 18) 169 | a_prime = δ + u 170 | b_prime = w + s 171 | c_prime = y + v 172 | η = α * (i + 1) ** 0.5 173 | z = (z * η + ρ * (2 * a_prime + b_prime + c_prime) - z_grad) / (η + 4 * ρ) 174 | 175 | # MODIFIED: add projection for valid perturbation 176 | z = torch.minimum(torch.maximum(z, inf_bound), sup_bound) 177 | 178 | # update steps 179 | u.add_(δ - z) 180 | v.add_(y - z) 181 | s.add_(w - z) 182 | 183 | # new predictions 184 | adv_inputs = y + inputs 185 | l2 = (adv_inputs - inputs).flatten(1).norm(p=2, dim=1) 186 | logits = model(adv_inputs) 187 | logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot) 188 | 189 | # adjust the best result found so far 190 | predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \ 191 | (logits + labels_onehot * confidence).argmax(1) 192 | 193 | is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels) 194 | is_smaller = l2 < best_l2 195 | o_is_smaller = l2 < o_best_l2 196 | is_both = is_adv & is_smaller 197 | o_is_both = is_adv & o_is_smaller 198 | 199 | best_l2 = torch.where(is_both, l2, best_l2) 200 | adv_found.logical_or_(is_both) 201 | o_best_l2 = torch.where(o_is_both, l2, o_best_l2) 202 | o_adv_found.logical_or_(is_both) 203 | o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv) 204 | 205 | if callback: 206 | i_total += 1 207 | callback.accumulate_line('logit_dist', i_total, logit_dists.mean()) 208 | callback.accumulate_line('l2_norm', i_total, l2.mean()) 209 | if i_total % (max_iterations // 20) == 0: 210 | callback.update_lines() 211 | 212 | if callback: 213 | best_l2 = o_best_l2[o_adv_found].mean() if o_adv_found.any() else torch.tensor(float('nan'), device=device) 214 | callback.line(['success', 'best_l2', 'c'], outer_step, [o_adv_found.float().mean(), best_l2, c.mean()]) 215 | 216 | # adjust the constant as needed 217 | upper_bound[adv_found] = torch.min(upper_bound[adv_found], const[adv_found]) 218 | adv_not_found = ~adv_found 219 | lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], const[adv_not_found]) 220 | is_smaller = upper_bound < 1e9 221 | const[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2 222 | const[(~is_smaller) & adv_not_found] *= 5 223 | 224 | if retrain: 225 | lower_bound = torch.zeros_like(const) 226 | const = torch.full_like(const, initial_const) 227 | upper_bound = torch.full_like(const, 1e10) 228 | 229 | for i in range(8): # taken from the original implementation 230 | 231 | best_l2 = torch.full_like(const, float('inf')) 232 | adv_found = torch.zeros(batch_size, device=device, dtype=torch.bool) 233 | 234 | o_best_y = o_best_adv - inputs 235 | np_o_best_y = o_best_y.cpu().numpy() 236 | Nz = np.abs(np_o_best_y[np.nonzero(np_o_best_y)]) 237 | threshold = np.percentile(Nz, σ) 238 | 239 | S_σ = o_best_y.abs() <= threshold 240 | z = o_best_y.clone() 241 | u = torch.zeros_like(inputs) 242 | tmpC = ρ / (ρ + γ / 100) 243 | 244 | for outer_step in range(400): # taken from the original implementation 245 | # δ step (equation 21) 246 | temp_a = (z - u) * tmpC 247 | δ = torch.where(S_σ, zeros, torch.minimum(torch.maximum(temp_a, inf_bound), sup_bound)) 248 | 249 | # new predictions 250 | 
δ.requires_grad_(True) 251 | adv_inputs = δ + inputs 252 | logits = model(adv_inputs) 253 | l2 = (adv_inputs.detach() - inputs).flatten(1).norm(p=2, dim=1) 254 | logit_dists = multiplier * difference_of_logits(logits, labels, labels_infhot=labels_infhot) 255 | loss = const * (logit_dists + confidence).clamp(min=0) 256 | z_grad = grad(loss.sum(), inputs=δ, only_inputs=True)[0] 257 | δ.detach_() 258 | 259 | # z step (equation 22) 260 | a_prime = δ + u 261 | z = torch.where(S_σ, zeros, (α * z + ρ * a_prime - z_grad) / (α + 2 * ρ)) 262 | u.add_(δ - z) 263 | 264 | # adjust the best result found so far 265 | predicted_classes = (logits - labels_onehot * confidence).argmax(1) if targeted else \ 266 | (logits + labels_onehot * confidence).argmax(1) 267 | 268 | is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels) 269 | is_smaller = l2 < best_l2 270 | o_is_smaller = l2 < o_best_l2 271 | is_both = is_adv & is_smaller 272 | o_is_both = is_adv & o_is_smaller 273 | 274 | best_l2 = torch.where(is_both, l2, best_l2) 275 | adv_found.logical_or_(is_both) 276 | o_best_l2 = torch.where(o_is_both, l2, o_best_l2) 277 | o_adv_found.logical_or_(is_both) 278 | o_best_adv = torch.where(batch_view(o_is_both), adv_inputs.detach(), o_best_adv) 279 | 280 | # adjust the constant as needed 281 | upper_bound[adv_found] = torch.min(upper_bound[adv_found], const[adv_found]) 282 | adv_not_found = ~adv_found 283 | lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], const[adv_not_found]) 284 | is_smaller = upper_bound < 1e9 285 | const[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2 286 | const[(~is_smaller) & adv_not_found] *= 5 287 | 288 | # return the best solution found 289 | return o_best_adv 290 | -------------------------------------------------------------------------------- /adv_lib/attacks/superdeepfool.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.autograd import grad 6 | 7 | from .deepfool import df 8 | 9 | 10 | def sdf(model: nn.Module, 11 | inputs: Tensor, 12 | labels: Tensor, 13 | targeted: bool = False, 14 | steps: int = 100, 15 | df_steps: int = 100, 16 | overshoot: float = 0.02, 17 | search_iter: int = 10) -> Tensor: 18 | """ 19 | SuperDeepFool attack from https://arxiv.org/abs/2303.12481. 20 | 21 | Parameters 22 | ---------- 23 | model : nn.Module 24 | Model to attack. 25 | inputs : Tensor 26 | Inputs to attack. Should be in [0, 1]. 27 | labels : Tensor 28 | Labels corresponding to the inputs if untargeted, else target labels. 29 | targeted : bool 30 | Whether to perform a targeted attack or not. 31 | steps : int 32 | Number of steps. 33 | df_steps : int 34 | Maximum number of steps for DeepFool attack at each iteration of SuperDeepFool. 35 | overshoot : float 36 | overshoot parameter in DeepFool. 37 | search_iter : int 38 | Number of binary search steps at the end of the attack. 39 | 40 | Returns 41 | ------- 42 | adv_inputs : Tensor 43 | Modified inputs to be adversarial to the model. 44 | 45 | """ 46 | if targeted: 47 | warnings.warn('DeepFool attack is untargeted only. 
Returning inputs.') 48 | return inputs 49 | 50 | if inputs.min() < 0 or inputs.max() > 1: raise ValueError('Input values should be in the [0, 1] range.') 51 | device = inputs.device 52 | batch_size = len(inputs) 53 | batch_view = lambda tensor: tensor.view(-1, *[1] * (inputs.ndim - 1)) 54 | 55 | # Setup variables 56 | adv_inputs = inputs_ = inputs 57 | labels_ = labels 58 | adv_out = inputs.clone() 59 | adv_found = torch.zeros(batch_size, dtype=torch.bool, device=device) 60 | 61 | for i in range(steps): 62 | logits = model(adv_inputs) 63 | pred_labels = logits.argmax(dim=1) 64 | 65 | is_adv = pred_labels != labels_ 66 | if is_adv.any(): 67 | adv_not_found = ~adv_found 68 | adv_out[adv_not_found] = torch.where(batch_view(is_adv), adv_inputs, adv_out[adv_not_found]) 69 | adv_found.masked_scatter_(adv_not_found, is_adv) 70 | if is_adv.all(): 71 | break 72 | 73 | not_adv = ~is_adv 74 | inputs_, adv_inputs, labels_ = inputs_[not_adv], adv_inputs[not_adv], labels_[not_adv] 75 | 76 | # start by doing deepfool -> need to return adv_inputs even for unsuccessful attacks 77 | df_adv_inputs, df_targets = df(model=model, inputs=adv_inputs, labels=labels_, steps=df_steps, norm=2, 78 | overshoot=overshoot, return_unsuccessful=True, return_targets=True) 79 | 80 | r_df = df_adv_inputs - inputs_ 81 | df_adv_inputs.requires_grad_(True) 82 | logits = model(df_adv_inputs) 83 | pred_labels = logits.argmax(dim=1) 84 | pred_labels = torch.where(pred_labels != labels_, pred_labels, df_targets) 85 | 86 | logit_diff = logits.gather(1, pred_labels.unsqueeze(1)) - logits.gather(1, labels_.unsqueeze(1)) 87 | w = grad(logit_diff.sum(), inputs=df_adv_inputs, only_inputs=True)[0] 88 | w.div_(batch_view(w.flatten(1).norm(p=2, dim=1).clamp_(min=1e-6))) # w / ||w||_2 89 | scale = torch.linalg.vecdot(r_df.flatten(1), w.flatten(1), dim=1) # (\tilde{x} - x_0)^T w / ||w||_2 90 | 91 | adv_inputs = adv_inputs.addcmul(batch_view(scale), w) 92 | adv_inputs.clamp_(min=0, max=1) # added compared to original implementation to produce valid adv 93 | 94 | if search_iter: # binary search to bring perturbation as close to the decision boundary as possible 95 | low, high = torch.zeros(batch_size, device=device), torch.ones(batch_size, device=device) 96 | for i in range(search_iter): 97 | mid = (low + high) / 2 98 | logits = torch.lerp(inputs, adv_out, weight=batch_view(mid)) 99 | pred_labels = model(logits).argmax(dim=1) 100 | is_adv = pred_labels != labels 101 | high = torch.where(is_adv, mid, high) 102 | low = torch.where(is_adv, low, mid) 103 | adv_out = torch.lerp(inputs, adv_out, weight=batch_view(high)) 104 | 105 | return adv_out 106 | -------------------------------------------------------------------------------- /adv_lib/attacks/trust_region.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amirgholami/trattack 2 | 3 | import warnings 4 | from functools import partial 5 | from typing import Tuple 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | from torch.autograd import grad 10 | 11 | 12 | def select_index(model: nn.Module, 13 | inputs: Tensor, 14 | c: int = 9, 15 | p: float = 2, 16 | worst_case: bool = False) -> Tensor: 17 | """ 18 | Select the attack target class. 
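    For each input, the distance to the decision boundary of each of the c highest non-predicted classes is
    estimated to first order by (z_top - z_other) / ||∇z_top - ∇z_other||_q, where z are the logits and q is the
    dual norm of p (q = 2 for p = 2 and q = 1 for p = inf). With worst_case=False, the class with the smallest
    non-negative estimated distance (i.e. the easiest class to reach) is returned; with worst_case=True, the
    class with the largest estimated distance is returned.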
19 | """ 20 | _duals = {2: 2, float('inf'): 1} 21 | dual = _duals[p] 22 | 23 | inputs.requires_grad_(True) 24 | logits = model(inputs) 25 | logits, indices = torch.sort(logits, descending=True) 26 | 27 | top_logits = logits[:, 0] 28 | top_grad = grad(top_logits.sum(), inputs, only_inputs=True, retain_graph=True)[0] 29 | pers = [] 30 | 31 | c = min(logits.size(1) - 1, c) 32 | for i in range(c): 33 | other_logits = logits[:, i + 1] 34 | other_grad = grad(other_logits.sum(), inputs, only_inputs=True, retain_graph=i + 1 != c)[0] 35 | grad_dual_norm = (top_grad - other_grad).flatten(1).norm(p=dual, dim=1) 36 | pers.append((top_logits.detach() - other_logits.detach()).div_(grad_dual_norm)) 37 | 38 | pers = torch.stack(pers, dim=1) 39 | inputs.detach_() 40 | 41 | if worst_case: 42 | index = pers.argmax(dim=1, keepdim=True) 43 | else: 44 | index = pers.clamp_(min=0).argmin(dim=1, keepdim=True) 45 | 46 | return indices.gather(1, index + 1).squeeze(1) 47 | 48 | 49 | def _step(model: nn.Module, 50 | inputs: Tensor, 51 | labels: Tensor, 52 | target_labels: Tensor, 53 | eps: Tensor, 54 | p: float = 2) -> Tuple[Tensor, Tensor]: 55 | _duals = {2: 2, float('inf'): 1} 56 | dual = _duals[p] 57 | 58 | inputs.requires_grad_(True) 59 | logits = model(inputs) 60 | 61 | logit_diff = (logits.gather(1, target_labels.unsqueeze(1)) - logits.gather(1, labels.unsqueeze(1))).squeeze(1) 62 | 63 | grad_inputs = grad(logit_diff.sum(), inputs, only_inputs=True)[0].flatten(1) 64 | inputs.detach_() 65 | per = logit_diff.detach().neg_().div_(grad_inputs.norm(p=dual, dim=1).clamp_(min=1e-6)) 66 | 67 | if p == float('inf'): 68 | grad_inputs.sign_() 69 | elif p == 2: 70 | grad_inputs.div_(grad_inputs.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-6)) 71 | 72 | per = torch.min(per, eps) 73 | adv_inputs = grad_inputs.mul_(per.add_(1e-4).mul_(1.02).unsqueeze(1)).view_as(inputs).add_(inputs) 74 | adv_inputs.clamp_(min=0, max=1) 75 | return adv_inputs, eps 76 | 77 | 78 | def _adaptive_step(model: nn.Module, 79 | inputs: Tensor, 80 | labels: Tensor, 81 | target_labels: Tensor, 82 | eps: Tensor, 83 | p: float = 2) -> Tuple[Tensor, Tensor]: 84 | _duals = {2: 2, float('inf'): 1} 85 | dual = _duals[p] 86 | 87 | inputs.requires_grad_(True) 88 | logits = model(inputs) 89 | 90 | class_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1) 91 | target_logits = logits.gather(1, target_labels.unsqueeze(1)).squeeze(1) 92 | logit_diff = target_logits - class_logits 93 | 94 | grad_inputs = grad(logit_diff.sum(), inputs, only_inputs=True)[0].flatten(1) 95 | inputs.detach_() 96 | per = logit_diff.detach().neg_().div_(grad_inputs.norm(p=dual, dim=1).clamp_(min=1e-6)) 97 | 98 | if p == float('inf'): 99 | grad_inputs.sign_() 100 | elif p == 2: 101 | grad_inputs.div_(grad_inputs.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-6)) 102 | 103 | new_eps = torch.min(per, eps) 104 | 105 | adv_inputs = grad_inputs.mul_(new_eps.add(1e-4).mul_(1.02).unsqueeze_(1)).view_as(inputs).add_(inputs) 106 | adv_inputs.clamp_(min=0, max=1) 107 | 108 | adv_logits = model(adv_inputs) 109 | class_adv_logits = adv_logits.gather(1, labels.unsqueeze(1)).squeeze(1) 110 | 111 | obj_diff = (class_logits - class_adv_logits).div_(new_eps) 112 | increase = obj_diff > 0.9 113 | decrease = obj_diff < 0.5 114 | new_eps = torch.where(increase, new_eps * 1.2, torch.where(decrease, new_eps / 1.2, new_eps)) 115 | if p == 2: 116 | new_eps.clamp_(min=0.0005, max=0.05) 117 | elif p == float('inf'): 118 | new_eps.clamp_(min=0.0001, max=0.01) 119 | 120 | return adv_inputs, new_eps 121 | 122 | 123 
| def tr(model: nn.Module,
124 |        inputs: Tensor,
125 |        labels: Tensor,
126 |        iter: int = 100,
127 |        adaptive: bool = False,
128 |        p: float = 2,
129 |        eps: float = 0.001,
130 |        c: int = 9,
131 |        worst_case: bool = False,
132 |        targeted: bool = False) -> Tensor:
133 |     if targeted:
134 |         warnings.warn('TR attack is untargeted only. Returning inputs.')
135 |         return inputs
136 |
137 |     adv_inputs = inputs.clone()
138 |     target_labels = select_index(model, inputs, c=c, p=p, worst_case=worst_case)
139 |     attack_step = partial(_adaptive_step if adaptive else _step, model=model, p=p)
140 |
141 |     to_attack = torch.ones(len(inputs), dtype=torch.bool, device=inputs.device)
142 |     eps = torch.full_like(to_attack, eps, dtype=torch.float, device=inputs.device)
143 |
144 |     for _ in range(iter):
145 |
146 |         logits = model(adv_inputs[to_attack])
147 |         to_attack.masked_scatter_(to_attack, logits.argmax(dim=1) == labels[to_attack])
148 |         if (~to_attack).all():
149 |             break
150 |         adv_inputs[to_attack], eps[to_attack] = attack_step(inputs=adv_inputs[to_attack], labels=labels[to_attack],
151 |                                                             target_labels=target_labels[to_attack], eps=eps[to_attack])
152 |
153 |     return adv_inputs
154 |
--------------------------------------------------------------------------------
/adv_lib/distances/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeromerony/adversarial-library/4ca9b77bb6c909d47e161ced0e257c6003fb4116/adv_lib/distances/__init__.py
--------------------------------------------------------------------------------
/adv_lib/distances/color_difference.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from adv_lib.utils.color_conversions import rgb_to_cielab
7 |
8 |
9 | def cie94_color_difference(Lab_1: Tensor, Lab_2: Tensor, k_L: float = 1, k_C: float = 1, k_H: float = 1,
10 |                            K_1: float = 0.045, K_2: float = 0.015, squared: bool = False, ε: float = 0) -> Tensor:
11 |     """
12 |     Inputs should be L*, a*, b*. Stars from the formulas are omitted for conciseness. K_1 and K_2 are the
application-dependent weighting parameters of S_C and S_H (the defaults correspond to graphic arts).
13 |
14 |     Parameters
15 |     ----------
16 |     Lab_1 : Tensor
17 |         First image in L*a*b* space. First image is intended to be the reference image.
18 |     Lab_2 : Tensor
19 |         Second image in L*a*b* space. Second image is intended to be the modified one.
20 |     k_L : float
21 |         Weighting factor for S_L.
22 |     k_C : float
23 |         Weighting factor for S_C.
24 |     k_H : float
25 |         Weighting factor for S_H.
26 |     squared : bool
27 |         Return the squared ΔE_94.
28 |     ε : float
29 |         Small value for numerical stability when computing gradients. Defaults to 0 for most accurate evaluation.
30 |
31 |     Returns
32 |     -------
33 |     ΔE_94 : Tensor
34 |         The CIE94 color difference for each pixel.
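    Examples
    --------
    Illustrative usage; the random tensors below stand in for real L*a*b* images, so only the shapes are
    meaningful:

    >>> import torch
    >>> Lab_1 = torch.randn(2, 3, 8, 8)
    >>> Lab_2 = Lab_1 + 0.1 * torch.randn(2, 3, 8, 8)
    >>> cie94_color_difference(Lab_1, Lab_2).shape
    torch.Size([2, 1, 8, 8])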
35 |
36 |     """
37 |     ΔL = Lab_1.narrow(1, 0, 1) - Lab_2.narrow(1, 0, 1)
38 |     C_1 = torch.norm(Lab_1.narrow(1, 1, 2), p=2, dim=1, keepdim=True)
39 |     C_2 = torch.norm(Lab_2.narrow(1, 1, 2), p=2, dim=1, keepdim=True)
40 |     ΔC = C_1 - C_2
41 |     Δa = Lab_1.narrow(1, 1, 1) - Lab_2.narrow(1, 1, 1)
42 |     Δb = Lab_1.narrow(1, 2, 1) - Lab_2.narrow(1, 2, 1)
43 |     ΔH_squared = Δa ** 2 + Δb ** 2 - ΔC ** 2  # squared hue difference ΔH²
44 |     S_L = 1
45 |     S_C = 1 + K_1 * C_1
46 |     S_H = 1 + K_2 * C_1
47 |     ΔE_94_squared = (ΔL / (k_L * S_L)) ** 2 + (ΔC / (k_C * S_C)) ** 2 + ΔH_squared / ((k_H * S_H) ** 2)
48 |     if squared:
49 |         return ΔE_94_squared
50 |     return ΔE_94_squared.clamp(min=ε).sqrt()
51 |
52 |
53 | def rgb_cie94_color_difference(input: Tensor, target: Tensor, **kwargs) -> Tensor:
54 |     """Computes the CIE94 Color-Difference from RGB inputs."""
55 |     return cie94_color_difference(*map(rgb_to_cielab, (input, target)), **kwargs)
56 |
57 |
58 | def cie94_loss(x1: Tensor, x2: Tensor, squared: bool = False, **kwargs) -> Tensor:
59 |     """
60 |     Computes the L2-norm over all pixels of the CIE94 Color-Difference for two RGB inputs.
61 |
62 |     Parameters
63 |     ----------
64 |     x1 : Tensor
65 |         First input.
66 |     x2 : Tensor
67 |         Second input (of size matching x1).
68 |     squared : bool
69 |         Returns the squared L2-norm.
70 |
71 |     Returns
72 |     -------
73 |     ΔE_94_l2 : Tensor
74 |         The L2-norm over all pixels of the CIE94 Color-Difference.
75 |
76 |     """
77 |     ΔE_94_squared = rgb_cie94_color_difference(x1, x2, squared=True, **kwargs).flatten(1)
78 |     ε = kwargs.get('ε', 0)
79 |     if squared:
80 |         return ΔE_94_squared.sum(1)
81 |     return ΔE_94_squared.sum(1).clamp(min=ε).sqrt()
82 |
83 |
84 | def ciede2000_color_difference(Lab_1: Tensor, Lab_2: Tensor, k_L: float = 1, k_C: float = 1, k_H: float = 1,
85 |                                squared: bool = False, ε: float = 0) -> Tensor:
86 |     """
87 |     Inputs should be L*, a*, b*. Primes from formulas in
88 |     http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf are omitted for conciseness.
89 |     This version is based on the matlab implementation from Gaurav Sharma
90 |     http://www2.ece.rochester.edu/~gsharma/ciede2000/dataNprograms/deltaE2000.m modified to have non-NaN gradients.
91 |
92 |     Parameters
93 |     ----------
94 |     Lab_1 : Tensor
95 |         First image in L*a*b* space. First image is intended to be the reference image.
96 |     Lab_2 : Tensor
97 |         Second image in L*a*b* space. Second image is intended to be the modified one.
98 |     k_L : float
99 |         Weighting factor for S_L.
100 |     k_C : float
101 |         Weighting factor for S_C.
102 |     k_H : float
103 |         Weighting factor for S_H.
104 |     squared : bool
105 |         Return the squared ΔE_00.
106 |     ε : float
107 |         Small value for numerical stability when computing gradients. Defaults to 0 for most accurate evaluation.
108 |
109 |     Returns
110 |     -------
111 |     ΔE_00 : Tensor
112 |         The CIEDE2000 color difference for each pixel.
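    Examples
    --------
    A sketch using the first pair of Gaurav Sharma's CIEDE2000 test data, for which the reference color
    difference is approximately 2.0425:

    >>> import torch
    >>> Lab_1 = torch.tensor([50.0, 2.6772, -79.7751]).view(1, 3, 1, 1)
    >>> Lab_2 = torch.tensor([50.0, 0.0, -82.7485]).view(1, 3, 1, 1)
    >>> ciede2000_color_difference(Lab_1, Lab_2).item()  # ≈ 2.0425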
113 | 114 | """ 115 | assert Lab_1.size(1) == 3 and Lab_2.size(1) == 3 116 | assert Lab_1.dtype == Lab_2.dtype 117 | dtype = Lab_1.dtype 118 | π = torch.tensor(math.pi, dtype=dtype, device=Lab_1.device) 119 | π_compare = π if dtype == torch.float64 else torch.tensor(math.pi, dtype=torch.float64, device=Lab_1.device) 120 | 121 | L_star_1, a_star_1, b_star_1 = Lab_1.unbind(dim=1) 122 | L_star_2, a_star_2, b_star_2 = Lab_2.unbind(dim=1) 123 | 124 | C_star_1 = torch.norm(torch.stack((a_star_1, b_star_1), dim=1), p=2, dim=1) 125 | C_star_2 = torch.norm(torch.stack((a_star_2, b_star_2), dim=1), p=2, dim=1) 126 | C_star_bar = (C_star_1 + C_star_2) / 2 127 | C7 = C_star_bar ** 7 128 | G = 0.5 * (1 - (C7 / (C7 + 25 ** 7)).clamp(min=ε).sqrt()) 129 | 130 | scale = 1 + G 131 | a_1 = scale * a_star_1 132 | a_2 = scale * a_star_2 133 | C_1 = torch.norm(torch.stack((a_1, b_star_1), dim=1), p=2, dim=1) 134 | C_2 = torch.norm(torch.stack((a_2, b_star_2), dim=1), p=2, dim=1) 135 | C_1_C_2_zero = (C_1 == 0) | (C_2 == 0) 136 | h_1 = torch.atan2(b_star_1, a_1 + ε * (a_1 == 0)) 137 | h_2 = torch.atan2(b_star_2, a_2 + ε * (a_2 == 0)) 138 | 139 | # required to match the test data 140 | h_abs_diff_compare = (torch.atan2(b_star_1.to(dtype=torch.float64), 141 | a_1.to(dtype=torch.float64)).remainder(2 * π_compare) - 142 | torch.atan2(b_star_2.to(dtype=torch.float64), 143 | a_2.to(dtype=torch.float64)).remainder(2 * π_compare)).abs() <= π_compare 144 | 145 | h_1 = h_1.remainder(2 * π) 146 | h_2 = h_2.remainder(2 * π) 147 | h_diff = h_2 - h_1 148 | h_sum = h_1 + h_2 149 | 150 | ΔL = L_star_2 - L_star_1 151 | ΔC = C_2 - C_1 152 | Δh = torch.where(C_1_C_2_zero, torch.zeros_like(h_1), 153 | torch.where(h_abs_diff_compare, h_diff, 154 | torch.where(h_diff > π, h_diff - 2 * π, h_diff + 2 * π))) 155 | 156 | ΔH = 2 * (C_1 * C_2).clamp(min=ε).sqrt() * torch.sin(Δh / 2) 157 | ΔH_squared = 4 * C_1 * C_2 * torch.sin(Δh / 2) ** 2 158 | 159 | L_bar = (L_star_1 + L_star_2) / 2 160 | C_bar = (C_1 + C_2) / 2 161 | 162 | h_bar = torch.where(C_1_C_2_zero, h_sum, 163 | torch.where(h_abs_diff_compare, h_sum / 2, 164 | torch.where(h_sum < 2 * π, h_sum / 2 + π, h_sum / 2 - π))) 165 | 166 | T = 1 - 0.17 * (h_bar - π / 6).cos() + 0.24 * (2 * h_bar).cos() + \ 167 | 0.32 * (3 * h_bar + π / 30).cos() - 0.20 * (4 * h_bar - 63 * π / 180).cos() 168 | 169 | Δθ = π / 6 * (torch.exp(-((180 / π * h_bar - 275) / 25) ** 2)) 170 | C7 = C_bar ** 7 171 | R_C = 2 * (C7 / (C7 + 25 ** 7)).clamp(min=ε).sqrt() 172 | S_L = 1 + 0.015 * (L_bar - 50) ** 2 / torch.sqrt(20 + (L_bar - 50) ** 2) 173 | S_C = 1 + 0.045 * C_bar 174 | S_H = 1 + 0.015 * C_bar * T 175 | R_T = -torch.sin(2 * Δθ) * R_C 176 | 177 | ΔE_00 = (ΔL / (k_L * S_L)) ** 2 + (ΔC / (k_C * S_C)) ** 2 + ΔH_squared / (k_H * S_H) ** 2 + \ 178 | R_T * (ΔC / (k_C * S_C)) * (ΔH / (k_H * S_H)) 179 | if squared: 180 | return ΔE_00 181 | return ΔE_00.clamp(min=ε).sqrt() 182 | 183 | 184 | def rgb_ciede2000_color_difference(input: Tensor, target: Tensor, **kwargs) -> Tensor: 185 | """Computes the CIEDE2000 Color-Difference from RGB inputs.""" 186 | return ciede2000_color_difference(*map(rgb_to_cielab, (input, target)), **kwargs) 187 | 188 | 189 | def ciede2000_loss(x1: Tensor, x2: Tensor, squared: bool = False, **kwargs) -> Tensor: 190 | """ 191 | Computes the L2-norm over all pixels of the CIEDE2000 Color-Difference for two RGB inputs. 192 | 193 | Parameters 194 | ---------- 195 | x1 : Tensor: 196 | First input. 197 | x2 : Tensor: 198 | Second input (of size matching x1). 
199 | squared : bool 200 | Returns the squared L2-norm. 201 | 202 | Returns 203 | ------- 204 | ΔE_00_l2 : Tensor 205 | The L2-norm over all pixels of the CIEDE2000 Color-Difference. 206 | 207 | """ 208 | ΔE_00 = rgb_ciede2000_color_difference(x1, x2, squared=True, **kwargs).flatten(1) 209 | ε = kwargs.get('ε', 0) 210 | if squared: 211 | return ΔE_00.sum(1) 212 | return ΔE_00.sum(1).clamp(min=ε).sqrt() 213 | -------------------------------------------------------------------------------- /adv_lib/distances/lp_norms.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Union 3 | 4 | from torch import Tensor 5 | 6 | 7 | def lp_distances(x1: Tensor, x2: Tensor, p: Union[float, int] = 2, dim: int = 1) -> Tensor: 8 | return (x1 - x2).flatten(dim).norm(p=p, dim=dim) 9 | 10 | 11 | l0_distances = partial(lp_distances, p=0) 12 | l1_distances = partial(lp_distances, p=1) 13 | l2_distances = partial(lp_distances, p=2) 14 | linf_distances = partial(lp_distances, p=float('inf')) 15 | 16 | 17 | def squared_l2_distances(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor: 18 | return (x1 - x2).square().flatten(dim).sum(dim) 19 | -------------------------------------------------------------------------------- /adv_lib/distances/lpips.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | from torchvision import models 6 | 7 | from adv_lib.utils import ImageNormalizer, requires_grad_ 8 | 9 | 10 | class AlexNetFeatures(nn.Module): 11 | def __init__(self) -> None: 12 | super(AlexNetFeatures, self).__init__() 13 | self.normalize = ImageNormalizer(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 14 | self.model = models.alexnet(pretrained=True) 15 | self.model.eval() 16 | 17 | self.features_layers = nn.ModuleList([ 18 | self.model.features[:2], 19 | self.model.features[2:5], 20 | self.model.features[5:8], 21 | self.model.features[8:10], 22 | self.model.features[10:12], 23 | ]) 24 | 25 | requires_grad_(self, False) 26 | 27 | def forward(self, x: Tensor) -> Tuple[Tensor, ...]: 28 | return self.features(x) 29 | 30 | def features(self, x: Tensor) -> Tuple[Tensor, ...]: 31 | x = self.normalize(x) 32 | 33 | features = [x] 34 | for i, layer in enumerate(self.features_layers): 35 | features.append(layer(features[i])) 36 | 37 | return tuple(features[1:]) 38 | 39 | 40 | def _normalize_features(x: Tensor, ε: float = 1e-12) -> Tensor: 41 | """Normalize by norm and sqrt of spatial size.""" 42 | norm = torch.norm(x, dim=1, p=2, keepdim=True) 43 | return x / (norm[0].numel() ** 0.5 * norm.clamp(min=ε)) 44 | 45 | 46 | def _feature_difference(features_1: Tensor, features_2: Tensor, linear_mapping: Optional[nn.Module] = None) -> Tensor: 47 | features = [map(_normalize_features, feature) for feature in [features_1, features_2]] # Normalize features 48 | if linear_mapping is not None: # Perform linear scaling 49 | features = [[module(f) for module, f in zip(linear_mapping, feature)] for feature in features] 50 | features = [torch.cat([f.flatten(1) for f in feature], dim=1) for feature in features] # Concatenate 51 | return features[0] - features[1] 52 | 53 | 54 | class LPIPS(nn.Module): 55 | _models = {'alexnet': AlexNetFeatures} 56 | 57 | def __init__(self, 58 | model: Union[str, nn.Module] = 'alexnet', 59 | linear_mapping: Optional[str] = None, 60 | target: Optional[Tensor] = None, 61 | squared: bool = 
False) -> None: 62 | super(LPIPS, self).__init__() 63 | 64 | if isinstance(model, str): 65 | self.features = self._models[model]() 66 | else: 67 | self.features = model 68 | 69 | self.linear_mapping = None 70 | if linear_mapping is not None: 71 | convs = [] 72 | sd = torch.load(linear_mapping) 73 | for k, weight in sd.items(): 74 | out_channels, in_channels = weight.shape[:2] 75 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False) 76 | conv.weight.data.copy_(weight) 77 | convs.append(conv) 78 | self.linear_mapping = nn.ModuleList(convs) 79 | 80 | self.target_features = None 81 | if target is not None: 82 | self.to(target.device) 83 | self.target_features = self.features(target) 84 | 85 | self.squared = squared 86 | 87 | def forward(self, input: Tensor, target: Optional[Tensor] = None) -> Tensor: 88 | input_features = self.features(input) 89 | if target is None and self.target_features is not None: 90 | target_features = self.target_features 91 | elif target is not None: 92 | target_features = self.features(target) 93 | else: 94 | raise ValueError('Must provide targets (either in init or in forward).') 95 | 96 | if self.squared: 97 | return _feature_difference(input_features, target_features).square().sum(dim=1) 98 | 99 | return torch.norm(_feature_difference(input_features, target_features), p=2, dim=1) 100 | -------------------------------------------------------------------------------- /adv_lib/distances/structural_similarity.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/pytorch/pytorch/pull/22289 2 | from functools import lru_cache 3 | from typing import Tuple 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import _reduction as _Reduction 8 | from torch.nn.functional import avg_pool2d, conv2d 9 | 10 | 11 | @lru_cache() 12 | def _fspecial_gaussian(size: int, channel: int, sigma: float, device: torch.device, dtype: torch.dtype, 13 | max_size: Tuple[int, int]) -> Tensor: 14 | coords = -(torch.arange(size, device=device, dtype=dtype) - (size - 1) / 2) ** 2 / (2. 
* sigma ** 2) 15 | if max(max_size) <= size: 16 | coords_x, coords_y = torch.zeros(max_size[0], device=device, dtype=dtype), torch.zeros(max_size[1], 17 | device=device, 18 | dtype=dtype) 19 | elif max_size[0] <= size: 20 | coords_x, coords_y = torch.zeros(max_size[0], device=device, dtype=dtype), coords 21 | elif max_size[1] <= size: 22 | coords_x, coords_y = coords, torch.zeros(max_size[1], device=device, dtype=dtype) 23 | else: 24 | coords_x = coords_y = coords 25 | final_size = (min(max_size[0], size), min(max_size[1], size)) 26 | 27 | grid = coords_x.view(-1, 1) + coords_y.view(1, -1) 28 | kernel = grid.view(1, -1).softmax(-1).view(1, 1, *final_size).expand(channel, 1, -1, -1).contiguous() 29 | return kernel 30 | 31 | 32 | def _ssim(input: Tensor, target: Tensor, max_val: float, k1: float, k2: float, channel: int, 33 | kernel: Tensor) -> Tuple[Tensor, Tensor]: 34 | c1 = (k1 * max_val) ** 2 35 | c2 = (k2 * max_val) ** 2 36 | 37 | mu1 = conv2d(input, kernel, groups=channel) 38 | mu2 = conv2d(target, kernel, groups=channel) 39 | 40 | mu1_sq = mu1 ** 2 41 | mu2_sq = mu2 ** 2 42 | mu1_mu2 = mu1 * mu2 43 | 44 | sigma1_sq = conv2d(input * input, kernel, groups=channel) - mu1_sq 45 | sigma2_sq = conv2d(target * target, kernel, groups=channel) - mu2_sq 46 | sigma12 = conv2d(input * target, kernel, groups=channel) - mu1_mu2 47 | 48 | v1 = 2 * sigma12 + c2 49 | v2 = sigma1_sq + sigma2_sq + c2 50 | 51 | ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2) 52 | return ssim, v1 / v2 53 | 54 | 55 | def ssim(input: Tensor, target: Tensor, max_val: float, filter_size: int = 11, k1: float = 0.01, k2: float = 0.03, 56 | sigma: float = 1.5, size_average=None, reduce=None, reduction: str = 'mean') -> Tensor: 57 | """Measures the structural similarity index (SSIM) error.""" 58 | dim = input.dim() 59 | if dim != 4: 60 | raise ValueError('Expected 4 dimensions (got {})'.format(dim)) 61 | 62 | if input.size() != target.size(): 63 | raise ValueError('Expected input size ({}) to match target size ({}).'.format(input.size(0), target.size(0))) 64 | 65 | if size_average is not None or reduce is not None: 66 | reduction = _Reduction.legacy_get_string(size_average, reduce) 67 | 68 | channel = input.size(1) 69 | kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype, 70 | max_size=input.shape[-2:]) 71 | ret, _ = _ssim(input, target, max_val, k1, k2, channel, kernel) 72 | 73 | if reduction != 'none': 74 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret) 75 | return ret 76 | 77 | 78 | def compute_ssim(input: Tensor, target: Tensor, **kwargs) -> Tensor: 79 | c_ssim = ssim(input=input, target=target, max_val=1, reduction='none', **kwargs).mean([2, 3]) 80 | return c_ssim.mean(1) 81 | 82 | 83 | def ssim_loss(*args, **kwargs) -> Tensor: 84 | return 1 - compute_ssim(*args, **kwargs) 85 | 86 | 87 | @lru_cache() 88 | def ms_weights(device: torch.device): 89 | return torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], device=device) 90 | 91 | 92 | def ms_ssim(input: Tensor, target: Tensor, max_val: float, filter_size: int = 11, k1: float = 0.01, k2: float = 0.03, 93 | sigma: float = 1.5, size_average=None, reduce=None, reduction: str = 'mean') -> Tensor: 94 | """Measures the multi-scale structural similarity index (MS-SSIM) error.""" 95 | dim = input.dim() 96 | if dim != 4: 97 | raise ValueError('Expected 4 dimensions (got {}) from input'.format(dim)) 98 | 99 | if input.size() != target.size(): 100 | raise ValueError('Expected input size ({}) to match 
target size ({}).'.format(input.size(0), target.size(0))) 101 | 102 | if size_average is not None or reduce is not None: 103 | reduction = _Reduction.legacy_get_string(size_average, reduce) 104 | 105 | channel = input.size(1) 106 | kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype, 107 | max_size=input.shape[-2:]) 108 | 109 | weights = ms_weights(input.device).unsqueeze(-1).unsqueeze(-1) 110 | levels = weights.size(0) 111 | mssim = [] 112 | mcs = [] 113 | for i in range(levels): 114 | 115 | if i: 116 | input = avg_pool2d(input, kernel_size=2, ceil_mode=True) 117 | target = avg_pool2d(target, kernel_size=2, ceil_mode=True) 118 | 119 | if min(size := input.shape[-2:]) <= filter_size: 120 | kernel = _fspecial_gaussian(filter_size, channel, sigma, device=input.device, dtype=input.dtype, 121 | max_size=size) 122 | 123 | ssim, cs = _ssim(input, target, max_val, k1, k2, channel, kernel) 124 | ssim = ssim.mean((2, 3)) 125 | cs = cs.mean((2, 3)) 126 | mssim.append(ssim) 127 | mcs.append(cs) 128 | 129 | mssim = torch.stack(mssim) 130 | mcs = torch.stack(mcs) 131 | p1 = mcs ** weights 132 | p2 = mssim ** weights 133 | 134 | ret = torch.prod(p1[:-1], 0) * p2[-1] 135 | 136 | if reduction != 'none': 137 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret) 138 | return ret 139 | 140 | 141 | def compute_ms_ssim(input: Tensor, target: Tensor, **kwargs) -> Tensor: 142 | channel_ssim = ms_ssim(input=input, target=target, max_val=1, reduction='none', **kwargs) 143 | return channel_ssim.mean(1) 144 | 145 | 146 | def ms_ssim_loss(*args, **kwargs) -> Tensor: 147 | return 1 - compute_ms_ssim(*args, **kwargs) 148 | -------------------------------------------------------------------------------- /adv_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ForwardCounter, BackwardCounter, ImageNormalizer, normalize_model, predict_inputs, requires_grad_ -------------------------------------------------------------------------------- /adv_lib/utils/attack_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | from distutils.version import LooseVersion 4 | from functools import partial 5 | from inspect import isclass 6 | from typing import Callable, Optional, Dict, Union 7 | 8 | import numpy as np 9 | import torch 10 | import tqdm 11 | from torch import Tensor, nn 12 | from torch.nn import functional as F 13 | 14 | from adv_lib.distances.lp_norms import l0_distances, l1_distances, l2_distances, linf_distances 15 | from adv_lib.utils import ForwardCounter, BackwardCounter, predict_inputs 16 | 17 | 18 | def generate_random_targets(labels: Tensor, num_classes: int) -> Tensor: 19 | """ 20 | Generates one random target in (num_classes - 1) possibilities for each label that is different from the original 21 | label. 22 | 23 | Parameters 24 | ---------- 25 | labels: Tensor 26 | Original labels. Generated targets will be different from labels. 27 | num_classes: int 28 | Number of classes to generate the random targets from. 29 | 30 | Returns 31 | ------- 32 | targets: Tensor 33 | Random target for each label. Has the same shape as labels. 
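    Examples
    --------
    A minimal sketch; the targets are sampled uniformly at random but always differ from the labels:

    >>> import torch
    >>> labels = torch.tensor([0, 1, 2])
    >>> targets = generate_random_targets(labels, num_classes=10)
    >>> (targets != labels).all().item()
    True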
34 | 35 | """ 36 | random = torch.rand(len(labels), num_classes, device=labels.device, dtype=torch.float) 37 | random.scatter_(1, labels.unsqueeze(-1), 0) 38 | return random.argmax(1) 39 | 40 | 41 | def get_all_targets(labels: Tensor, num_classes: int) -> Tensor: 42 | """ 43 | Generates all possible targets that are different from the original labels. 44 | 45 | Parameters 46 | ---------- 47 | labels: Tensor 48 | Original labels. Generated targets will be different from labels. 49 | num_classes: int 50 | Number of classes to generate the random targets from. 51 | 52 | Returns 53 | ------- 54 | targets: Tensor 55 | Random targets for each label. shape: (len(labels), num_classes - 1). 56 | 57 | """ 58 | assert labels.ndim == 1 59 | all_possible_targets = torch.arange(num_classes - 1, dtype=torch.long, device=labels.device) 60 | all_possible_targets = all_possible_targets + (all_possible_targets >= labels.unsqueeze(1)) 61 | return all_possible_targets 62 | 63 | 64 | def run_attack(model: nn.Module, 65 | inputs: Tensor, 66 | labels: Tensor, 67 | attack: Callable, 68 | targets: Optional[Tensor] = None, 69 | batch_size: Optional[int] = None) -> dict: 70 | device = next(model.parameters()).device 71 | to_device = lambda tensor: tensor.to(device) 72 | targeted, adv_labels = False, labels 73 | if targets is not None: 74 | targeted, adv_labels = True, targets 75 | batch_size = batch_size or len(inputs) 76 | 77 | # run attack only on non already adversarial samples 78 | already_adv = [] 79 | chunks = [tensor.split(batch_size) for tensor in [inputs, adv_labels]] 80 | for (inputs_chunk, label_chunk) in zip(*chunks): 81 | batch_chunk_d, label_chunk_d = [to_device(tensor) for tensor in [inputs_chunk, label_chunk]] 82 | preds = model(batch_chunk_d).argmax(1) 83 | is_adv = (preds == label_chunk_d) if targeted else (preds != label_chunk_d) 84 | already_adv.append(is_adv.cpu()) 85 | not_adv = ~torch.cat(already_adv, 0) 86 | 87 | start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 88 | forward_counter, backward_counter = ForwardCounter(), BackwardCounter() 89 | model.register_forward_pre_hook(forward_counter) 90 | if LooseVersion(torch.__version__) >= LooseVersion('1.8'): 91 | model.register_full_backward_hook(backward_counter) 92 | else: 93 | model.register_backward_hook(backward_counter) 94 | average_forwards, average_backwards = [], [] # number of forward and backward calls per sample 95 | advs_chunks = [] 96 | chunks = [tensor.split(batch_size) for tensor in [inputs[not_adv], adv_labels[not_adv]]] 97 | total_time = 0 98 | for (inputs_chunk, label_chunk) in tqdm.tqdm(zip(*chunks), ncols=80, total=len(chunks[0])): 99 | batch_chunk_d, label_chunk_d = [to_device(tensor.clone()) for tensor in [inputs_chunk, label_chunk]] 100 | 101 | start.record() 102 | advs_chunk_d = attack(model, batch_chunk_d, label_chunk_d, targeted=targeted) 103 | 104 | # performance monitoring 105 | end.record() 106 | torch.cuda.synchronize() 107 | total_time += (start.elapsed_time(end)) / 1000 # times for cuda Events are in milliseconds 108 | average_forwards.append(forward_counter.num_samples_called / len(batch_chunk_d)) 109 | average_backwards.append(backward_counter.num_samples_called / len(batch_chunk_d)) 110 | forward_counter.reset(), backward_counter.reset() 111 | 112 | advs_chunks.append(advs_chunk_d.cpu()) 113 | if isinstance(attack, partial) and (callback := attack.keywords.get('callback')) is not None: 114 | callback.reset_windows() 115 | 116 | adv_inputs = inputs.clone() 117 | 
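    # samples that were already adversarial were skipped by the attack: they keep their original (already
    # misclassified) values, and the attacked chunks are scattered back into the full batch so that the
    # downstream metrics cover every input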
adv_inputs[not_adv] = torch.cat(advs_chunks, 0) 118 | 119 | data = { 120 | 'inputs': inputs, 121 | 'labels': labels, 122 | 'targets': adv_labels if targeted else None, 123 | 'adv_inputs': adv_inputs, 124 | 'time': total_time, 125 | 'num_forwards': sum(average_forwards) / len(chunks[0]), 126 | 'num_backwards': sum(average_backwards) / len(chunks[0]), 127 | } 128 | 129 | return data 130 | 131 | 132 | _default_metrics = OrderedDict([ 133 | ('linf', linf_distances), 134 | ('l0', l0_distances), 135 | ('l1', l1_distances), 136 | ('l2', l2_distances), 137 | ]) 138 | 139 | 140 | def compute_attack_metrics(model: nn.Module, 141 | attack_data: Dict[str, Union[Tensor, float]], 142 | batch_size: Optional[int] = None, 143 | metrics: Dict[str, Callable] = _default_metrics) -> Dict[str, Union[Tensor, float]]: 144 | inputs, labels, targets, adv_inputs = map(attack_data.get, ['inputs', 'labels', 'targets', 'adv_inputs']) 145 | if adv_inputs.min() < 0 or adv_inputs.max() > 1: 146 | warnings.warn('Values of produced adversarials are not in the [0, 1] range -> Clipping to [0, 1].') 147 | adv_inputs.clamp_(min=0, max=1) 148 | device = next(model.parameters()).device 149 | to_device = lambda tensor: tensor.to(device) 150 | 151 | batch_size = batch_size or len(inputs) 152 | chunks = [tensor.split(batch_size) for tensor in [inputs, labels, adv_inputs]] 153 | all_predictions = [[] for _ in range(6)] 154 | distances = {k: [] for k in metrics.keys()} 155 | metrics = {k: v().to(device) if (isclass(v.func) if isinstance(v, partial) else False) else v for k, v in 156 | metrics.items()} 157 | 158 | append = lambda list, data: list.append(data.cpu()) 159 | for inputs_chunk, labels_chunk, adv_chunk in zip(*chunks): 160 | inputs_chunk, adv_chunk = map(to_device, [inputs_chunk, adv_chunk]) 161 | clean_preds, adv_preds = [predict_inputs(model, chunk.to(device)) for chunk in [inputs_chunk, adv_chunk]] 162 | list(map(append, all_predictions, [*clean_preds, *adv_preds])) 163 | for metric, metric_func in metrics.items(): 164 | distances[metric].append(metric_func(adv_chunk, inputs_chunk).detach().cpu()) 165 | 166 | logits, probs, preds, logits_adv, probs_adv, preds_adv = [torch.cat(l) for l in all_predictions] 167 | for metric in metrics.keys(): 168 | distances[metric] = torch.cat(distances[metric], 0) 169 | 170 | accuracy_orig = (preds == labels).float().mean().item() 171 | if targets is not None: 172 | success = (preds_adv == targets) 173 | labels = targets 174 | else: 175 | success = (preds_adv != labels) 176 | 177 | prob_orig = probs.gather(1, labels.unsqueeze(1)).squeeze(1) 178 | prob_adv = probs_adv.gather(1, labels.unsqueeze(1)).squeeze(1) 179 | labels_infhot = torch.zeros_like(logits_adv).scatter_(1, labels.unsqueeze(1), float('inf')) 180 | real = logits_adv.gather(1, labels.unsqueeze(1)).squeeze(1) 181 | other = (logits_adv - labels_infhot).amax(dim=1) 182 | diff_vs_max_adv = (real - other) 183 | nll = F.cross_entropy(logits, labels, reduction='none') 184 | nll_adv = F.cross_entropy(logits_adv, labels, reduction='none') 185 | 186 | data = { 187 | 'time': attack_data['time'], 188 | 'num_forwards': attack_data['num_forwards'], 189 | 'num_backwards': attack_data['num_backwards'], 190 | 'targeted': targets is not None, 191 | 'preds': preds, 192 | 'adv_preds': preds_adv, 193 | 'accuracy_orig': accuracy_orig, 194 | 'success': success, 195 | 'probs_orig': prob_orig, 196 | 'probs_adv': prob_adv, 197 | 'logit_diff_adv': diff_vs_max_adv, 198 | 'nll': nll, 199 | 'nll_adv': nll_adv, 200 | 'distances': distances, 201 | } 202 | 203 | 
return data 204 | 205 | 206 | def print_metrics(metrics: dict) -> None: 207 | np.set_printoptions(formatter={'float': '{:0.3f}'.format}, threshold=16, edgeitems=3, 208 | linewidth=120) # To print arrays with less precision 209 | print('Original accuracy: {:.2%}'.format(metrics['accuracy_orig'])) 210 | print('Attack done in: {:.2f}s with {:.4g} forwards and {:.4g} backwards.'.format( 211 | metrics['time'], metrics['num_forwards'], metrics['num_backwards'])) 212 | success = metrics['success'].numpy() 213 | fail = bool(success.mean() != 1) 214 | print('Attack success: {:.2%}'.format(success.mean()) + fail * ' - {}'.format(success)) 215 | for distance, values in metrics['distances'].items(): 216 | data = values.numpy() 217 | print('{}: {} - Average: {:.3f} - Median: {:.3f}'.format(distance, data, data.mean(), np.median(data)) + 218 | fail * ' | Avg over success: {:.3f}'.format(data[success].mean())) 219 | attack_type = 'targets' if metrics['targeted'] else 'correct' 220 | print('Logit({} class) - max_Logit(other classes): {} - Average: {:.2f}'.format( 221 | attack_type, metrics['logit_diff_adv'].numpy(), metrics['logit_diff_adv'].numpy().mean())) 222 | print('NLL of target/pred class: {:.3f}'.format(metrics['nll_adv'].numpy().mean())) 223 | -------------------------------------------------------------------------------- /adv_lib/utils/color_conversions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | _ycbcr_conversions = { 5 | 'rec_601': (0.299, 0.587, 0.114), 6 | 'rec_709': (0.2126, 0.7152, 0.0722), 7 | 'rec_2020': (0.2627, 0.678, 0.0593), 8 | 'smpte_240m': (0.212, 0.701, 0.087), 9 | } 10 | 11 | 12 | def rgb_to_ycbcr(input: Tensor, standard: str = 'rec_2020'): 13 | kr, kg, kb = _ycbcr_conversions[standard] 14 | conversion_matrix = torch.tensor([[kr, kg, kb], 15 | [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb), 0.5], 16 | [0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)]], device=input.device) 17 | return torch.einsum('mc,nchw->nmhw', conversion_matrix, input) 18 | 19 | 20 | def ycbcr_to_rgb(input: Tensor, standard: str = 'rec_2020'): 21 | kr, kg, kb = _ycbcr_conversions[standard] 22 | conversion_matrix = torch.tensor([[1, 0, 2 - 2 * kr], 23 | [1, -kb / kg * (2 - 2 * kb), -kr / kg * (2 - 2 * kr)], 24 | [1, 2 - 2 * kb, 0]], device=input.device) 25 | return torch.einsum('mc,nchw->nmhw', conversion_matrix, input) 26 | 27 | 28 | _xyz_conversions = { 29 | 'CIE_RGB': ((0.4887180, 0.3106803, 0.2006017), 30 | (0.1762044, 0.8129847, 0.0108109), 31 | (0.0000000, 0.0102048, 0.9897952)), 32 | 'sRGB': ((0.4124564, 0.3575761, 0.1804375), 33 | (0.2126729, 0.7151522, 0.0721750), 34 | (0.0193339, 0.1191920, 0.9503041)) 35 | } 36 | 37 | 38 | def rgb_to_xyz(input: Tensor, rgb_space: str = 'sRGB'): 39 | conversion_matrix = torch.tensor(_xyz_conversions[rgb_space], device=input.device) 40 | # Inverse sRGB companding 41 | v = torch.where(input <= 0.04045, input / 12.92, ((input + 0.055) / 1.055) ** 2.4) 42 | return torch.einsum('mc,nchw->nmhw', conversion_matrix, v) 43 | 44 | 45 | _delta = 6 / 29 46 | 47 | 48 | def cielab_func(input: Tensor) -> Tensor: 49 | # torch.where produces NaNs in backward if one of the choice produces NaNs or infs in backward (here .pow(1/3)) 50 | return torch.where(input > _delta ** 3, input.clamp(min=_delta ** 3).pow(1 / 3), input / (3 * _delta ** 2) + 4 / 29) 51 | 52 | 53 | def cielab_inverse_func(input: Tensor) -> Tensor: 54 | return torch.where(input > _delta, input.pow(3), 3 * _delta ** 2 * (input - 
4 / 29)) 55 | 56 | 57 | _cielab_conversions = { 58 | 'illuminant_d50': (96.4212, 100, 82.5188), 59 | 'illuminant_d65': (95.0489, 100, 108.884), 60 | } 61 | 62 | 63 | def rgb_to_cielab(input: Tensor, standard: str = 'illuminant_d65') -> Tensor: 64 | # Convert to XYZ 65 | XYZ_input = rgb_to_xyz(input=input) 66 | 67 | Xn, Yn, Zn = _cielab_conversions[standard] 68 | L_star = 116 * cielab_func(XYZ_input.narrow(1, 1, 1) / Yn) - 16 69 | a_star = 500 * (cielab_func(XYZ_input.narrow(1, 0, 1) / Xn) - cielab_func(XYZ_input.narrow(1, 1, 1) / Yn)) 70 | b_star = 200 * (cielab_func(XYZ_input.narrow(1, 1, 1) / Yn) - cielab_func(XYZ_input.narrow(1, 2, 1) / Zn)) 71 | return torch.cat((L_star, a_star, b_star), 1) 72 | -------------------------------------------------------------------------------- /adv_lib/utils/image_selection.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | def select_images(model: nn.Module, dataset: Dataset, num_images: int, correct_only: bool = False, 10 | random: bool = False) -> Tuple[Tensor, Tensor]: 11 | device = next(model.parameters()).device 12 | loader = DataLoader(dataset=dataset, batch_size=1, shuffle=random) 13 | selected_images, selected_labels = [], [] 14 | for (image, label) in loader: 15 | if correct_only: 16 | correct = model(image.to(device)).argmax(1) == label.to(device) 17 | if correct.all(): 18 | selected_images.append(image), selected_labels.append(label) 19 | elif not correct_only: 20 | selected_images.append(image), selected_labels.append(label) 21 | 22 | if len(selected_images) == num_images: 23 | break 24 | else: 25 | print('Could only find {} correctly classified images.'.format(len(selected_images))) 26 | 27 | return torch.cat(selected_images, 0), torch.cat(selected_labels, 0) 28 | -------------------------------------------------------------------------------- /adv_lib/utils/lagrangian_penalties/__init__.py: -------------------------------------------------------------------------------- 1 | from .all_penalties import all_penalties -------------------------------------------------------------------------------- /adv_lib/utils/lagrangian_penalties/all_penalties.py: -------------------------------------------------------------------------------- 1 | from .penalty_functions import * 2 | from .univariate_functions import * 3 | 4 | univariates_P4 = { 5 | 'Quad': Quadratic, 6 | 'FourThirds': FourThirds, 7 | 'Cosh': Cosh, 8 | } 9 | 10 | univariates_P5_P6_P7 = { 11 | 'Exp': Exp, 12 | 'LogExp': LogExp, 13 | 'LogQuad_1': LogQuad_1, 14 | 'LogQuad_2': LogQuad_2, 15 | 'HyperExp': HyperExp, 16 | 'HyperQuad': HyperQuad, 17 | 'DualLogQuad': DualLogQuad, 18 | 'CubicQuad': CubicQuad, 19 | 'ExpQuad': ExpQuad, 20 | 'LogBarrierQuad': LogBarrierQuad, 21 | 'HyperBarrierQuad': HyperBarrierQuad, 22 | 'HyperLogQuad': HyperLogQuad, 23 | 'SmoothPlus': SmoothPlus, 24 | 'NNSmoothPlus': NNSmoothPlus, 25 | 'ExpSmoothPlus': ExpSmoothPlus, 26 | } 27 | 28 | univariates_P8 = { 29 | 'LogExp': LogExp, 30 | 'LogQuad_1': LogQuad_1, 31 | 'LogQuad_2': LogQuad_2, 32 | 'HyperExp': HyperExp, 33 | 'HyperQuad': HyperQuad, 34 | 'DualLogQuad': DualLogQuad, 35 | 'CubicQuad': CubicQuad, 36 | 'ExpQuad': ExpQuad, 37 | 'LogBarrierQuad': LogBarrierQuad, 38 | 'HyperBarrierQuad': HyperBarrierQuad, 39 | 'HyperLogQuad': HyperLogQuad, 40 | } 41 | 42 | univariates_P9 = { 43 | 'SmoothPlus': SmoothPlus, 44 | 
'NNSmoothPlus': NNSmoothPlus, 45 | 'ExpSmoothPlus': ExpSmoothPlus, 46 | } 47 | 48 | combinations = { 49 | 'PHRQuad': PHRQuad, 50 | 'P1': P1, 51 | 'P2': P2, 52 | 'P3': P3, 53 | 'P4': (P4, univariates_P4), 54 | 'P5': (P5, univariates_P5_P6_P7), 55 | 'P6': (P6, univariates_P5_P6_P7), 56 | 'P7': (P7, univariates_P5_P6_P7), 57 | 'P8': (P8, univariates_P8), 58 | 'P9': (P9, univariates_P9), 59 | } 60 | 61 | all_penalties = {} 62 | for p_name, penalty in combinations.items(): 63 | if isinstance(penalty, tuple): 64 | for θ_name, θ in penalty[1].items(): 65 | all_penalties['_'.join([p_name, θ_name])] = penalty[0](θ()) 66 | else: 67 | all_penalties[p_name] = penalty 68 | -------------------------------------------------------------------------------- /adv_lib/utils/lagrangian_penalties/penalty_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | __all__ = [ 5 | 'PHRQuad', 6 | 'P1', 7 | 'P2', 8 | 'P3', 9 | 'P4', 10 | 'P5', 11 | 'P6', 12 | 'P7', 13 | 'P8', 14 | 'P9' 15 | ] 16 | 17 | 18 | def PHRQuad(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 19 | return ((μ + ρ * y).relu().square() - μ ** 2).div(2 * ρ) 20 | 21 | 22 | def P1(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 23 | y_sup = μ * y + 0.5 * ρ * y ** 2 + ρ ** 2 * y ** 3 24 | y_mid = μ * y + 0.5 * ρ * y ** 2 25 | y_inf = - μ ** 2 / (2 * ρ) 26 | return torch.where(y >= 0, y_sup, torch.where(y <= -μ / ρ, y_inf, y_mid)) 27 | 28 | 29 | def P2(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 30 | y_sup = μ * y + μ * ρ * y ** 2 + 1 / 6 * ρ ** 2 * y ** 3 31 | y_inf = μ * y / (1 - ρ * y.clamp(max=0)) 32 | return torch.where(y >= 0, y_sup, y_inf) 33 | 34 | 35 | def P3(y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 36 | y_sup = μ * y + μ * ρ * y ** 2 37 | y_inf = μ * y / (1 - ρ * y.clamp(max=0)) 38 | return torch.where(y >= 0, y_sup, y_inf) 39 | 40 | 41 | class GenericPenaltyLagrangian: 42 | def __init__(self, θ): 43 | self.θ = θ # univariate function 44 | 45 | 46 | class P4(GenericPenaltyLagrangian): 47 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 48 | y_sup = μ * y + self.θ(ρ * y) / ρ 49 | y_inf = self.θ.min(ρ, μ) 50 | return torch.where(self.θ.sup(y, ρ, μ), y_sup, y_inf) 51 | 52 | 53 | class P5(GenericPenaltyLagrangian): 54 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 55 | return self.θ.mul * self.θ(ρ * y) * μ / ρ 56 | 57 | 58 | class P6(GenericPenaltyLagrangian): 59 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 60 | return self.θ.mul * self.θ(ρ * μ * y) / ρ 61 | 62 | 63 | class P7(GenericPenaltyLagrangian): 64 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 65 | return self.θ.mul * self.θ(ρ * y / μ) * (μ ** 2) / ρ 66 | 67 | 68 | class P8(GenericPenaltyLagrangian): 69 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 70 | tilde_x = self.θ.tilde(μ) 71 | return (self.θ(ρ * y + tilde_x) - self.θ(tilde_x)) / ρ 72 | 73 | 74 | class P9(GenericPenaltyLagrangian): 75 | # Penalty-Lagrangian functions associated with P9 are not well defined for ρ <= μ, so we set ρ = max({ρ, 2μ}) 76 | def __call__(self, y: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 77 | ρ_adjusted = torch.max(ρ, 2 * μ) 78 | tilde_x = self.θ.tilde(μ, ρ_adjusted) 79 | return self.θ(ρ_adjusted * y + tilde_x) - self.θ(tilde_x) 80 | -------------------------------------------------------------------------------- /adv_lib/utils/lagrangian_penalties/scripts/plot_penalties.py: 
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | from cycler import cycler
4 | from torch.autograd import grad
5 |
6 | from adv_lib.utils.lagrangian_penalties.all_penalties import all_penalties
7 |
8 | fig, ax = plt.subplots(figsize=(12, 12))
9 | x = torch.linspace(-10, 5, 15001)
10 | x[10000] = 0  # make sure the grid contains exactly 0 before enabling gradients
11 | x.requires_grad_(True)  # in-place ops are not allowed on leaf tensors that already require grad
12 | styles = (cycler(linestyle=['-', '--', '-.']) * cycler(color=plt.rcParams['axes.prop_cycle']))
13 | ax.set_prop_cycle(styles)
14 |
15 | unique_penalties_idxs = list(range(len(all_penalties)))
16 |
17 | for i in unique_penalties_idxs:
18 |     penalty = list(all_penalties.keys())[i]
19 |     ρ = torch.tensor(1.)
20 |     μ = torch.tensor(1.)
21 |     y = all_penalties[penalty](x, ρ, μ)
22 |     if y.isnan().any():
23 |         print('nan in y for {}'.format(penalty))
24 |     grads = grad(y.sum(), x, only_inputs=True)[0]
25 |     if grads.isnan().any():
26 |         print('nan in grad for {}'.format(penalty))
27 |     if not torch.allclose(grads[10000], μ):
28 |         print("P'(0, ρ, μ) = {:.3g} ≠ μ".format(grads[10000].item()))
29 |     ax.plot(x.detach().numpy(), y.detach().numpy(),
30 |             label=r'{}: $\nabla P(0)$:{:.3g}'.format(penalty, grads[10000].item()))
31 |
32 | ax.set_xlim(-10, 5)
33 | ax.set_ylim(-5, 10)
34 | ax.legend(loc=2, prop={'size': 6})
35 | ax.set_aspect('equal')
36 | ax.set_xlabel(r'$x$')
37 | ax.set_ylabel(r'$Penalty(x)$')
38 | ax.grid(True, linestyle='--')
39 |
40 | plt.tight_layout()
41 | plt.show()
42 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/scripts/plot_univariates.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 |
3 | import matplotlib.pyplot as plt
4 | import torch
5 | from cycler import cycler
6 |
7 | from adv_lib.utils.lagrangian_penalties import univariate_functions
8 |
9 | fig, ax = plt.subplots(figsize=(8, 8))
10 | x = torch.linspace(-10, 5, 15001)
11 |
12 | styles = (cycler(linestyle=['-', '--']) * cycler(color=plt.rcParams['axes.prop_cycle']))
13 | ax.set_prop_cycle(styles)
14 |
15 | for univariate in univariate_functions.__all__:
16 |     if isfunction(univariate_functions.__dict__[univariate]):
17 |         univariate_function = univariate_functions.__dict__[univariate]
18 |     else:
19 |         univariate_function = univariate_functions.__dict__[univariate]()
20 |     y = univariate_function(x)
21 |     ax.plot(x, y, label=univariate)
22 |
23 | ax.set_xlim(-10, 5)
24 | ax.set_ylim(-5, 10)
25 | ax.legend(loc=2)
26 | ax.set_aspect('equal')
27 | ax.set_xlabel(r'$x$')
28 | ax.set_ylabel(r'$Univariate(x)$')
29 | ax.grid(True, linestyle='--')
30 |
31 | plt.tight_layout()
32 | plt.show()
33 |
--------------------------------------------------------------------------------
/adv_lib/utils/lagrangian_penalties/univariate_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import functional as F
6 |
7 | __all__ = [
8 |     'Quadratic',
9 |     'FourThirds',
10 |     'Cosh',
11 |     'Exp',
12 |     'LogExp',
13 |     'LogQuad_1',
14 |     'LogQuad_2',
15 |     'HyperExp',
16 |     'HyperQuad',
17 |     'DualLogQuad',
18 |     'CubicQuad',
19 |     'ExpQuad',
20 |     'LogBarrierQuad',
21 |     'HyperBarrierQuad',
22 |     'HyperLogQuad',
23 |     'SmoothPlus',
24 |     'NNSmoothPlus',
25 |     'ExpSmoothPlus',
26 | ]
27 |
28 |
29 | def safe_exp(x: Tensor) -> Tensor:
30 |     return torch.exp(x.clamp(max=87.5))  # exp overflows float32 slightly above this threshold
31 |
32 |
33 | class Quadratic:
34 
| @staticmethod 35 | def __call__(t: Tensor) -> Tensor: 36 | return 0.5 * t ** 2 37 | 38 | @staticmethod 39 | def min(ρ: Tensor, μ: Tensor) -> Tensor: 40 | return - μ ** 2 / (2 * ρ) 41 | 42 | @staticmethod 43 | def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 44 | return μ + ρ * t >= 0 45 | 46 | 47 | class FourThirds: 48 | @staticmethod 49 | def __call__(t: Tensor) -> Tensor: 50 | return 0.75 * t.abs().pow(4 / 3) 51 | 52 | @staticmethod 53 | def min(ρ: Tensor, μ: Tensor) -> Tensor: 54 | return - μ ** 4 / (4 * ρ) 55 | 56 | @staticmethod 57 | def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 58 | return ρ.pow(1 / 3) * t.sign() * t.abs().pow(1 / 3) + μ >= 0 59 | 60 | 61 | class Cosh: 62 | @staticmethod 63 | def __call__(t: Tensor) -> Tensor: 64 | return t.cosh() - 1 65 | 66 | @staticmethod 67 | def min(ρ: Tensor, μ: Tensor) -> Tensor: 68 | return μ / ρ * torch.asinh(-μ) + 1 / ρ * (torch.cosh(torch.asinh(-μ)) - 1) 69 | 70 | @staticmethod 71 | def sup(t: Tensor, ρ: Tensor, μ: Tensor) -> Tensor: 72 | return torch.sinh(μ + (ρ * t)) >= 0 73 | 74 | 75 | class Exp: 76 | mul = 1 77 | 78 | @staticmethod 79 | def __call__(t: Tensor) -> Tensor: 80 | return safe_exp(t) - 1 81 | 82 | @staticmethod 83 | def tilde(μ: Tensor) -> Tensor: 84 | return torch.log(μ) 85 | 86 | 87 | class LogExp: 88 | mul = 1 89 | 90 | @staticmethod 91 | def __call__(t: Tensor) -> Tensor: 92 | y_sup = safe_exp(2 * t - 1) + math.log(2) - 1 93 | y_inf = -torch.log(1 - t.clamp(max=0.5)) 94 | return torch.where(t >= 0.5, y_sup, y_inf) 95 | 96 | @staticmethod 97 | def tilde(μ: Tensor) -> Tensor: 98 | y_sup = 0.5 * (torch.log(μ / 2) + 1) 99 | y_inf = (μ - 1) / μ 100 | return torch.where(μ >= 2, y_sup, y_inf) 101 | 102 | 103 | class LogQuad_1: 104 | mul = 1 105 | 106 | @staticmethod 107 | def __call__(t: Tensor) -> Tensor: 108 | y_sup = 2 * t ** 2 + math.log(2) - 0.5 109 | y_inf = -torch.log(1 - t.clamp(max=0.5)) 110 | return torch.where(t >= 0.5, y_sup, y_inf) 111 | 112 | @staticmethod 113 | def tilde(μ: Tensor) -> Tensor: 114 | y_sup = μ / 4 115 | y_inf = (μ - 1) / μ 116 | return torch.where(μ >= 2, y_sup, y_inf) 117 | 118 | 119 | class LogQuad_2: 120 | mul = 1 121 | 122 | @staticmethod 123 | def __call__(t: Tensor) -> Tensor: 124 | y_sup = 0.5 * t ** 2 + t 125 | y_inf = -0.25 * torch.log(-2 * t.clamp(max=-0.5)) - 0.375 126 | return torch.where(t >= -0.5, y_sup, y_inf) 127 | 128 | @staticmethod 129 | def tilde(μ: Tensor) -> Tensor: 130 | y_sup = μ - 1 131 | y_inf = -1 / (4 * μ) 132 | return torch.where(μ >= 0.5, y_sup, y_inf) 133 | 134 | 135 | class HyperExp: 136 | mul = 1 137 | 138 | @staticmethod 139 | def __call__(t: Tensor) -> Tensor: 140 | y_sup = safe_exp(4 * t - 2) 141 | y_inf = t / (1 - t.clamp(max=0.5)) 142 | return torch.where(t >= 0.5, y_sup, y_inf) 143 | 144 | @staticmethod 145 | def tilde(μ: Tensor) -> Tensor: 146 | y_sup = torch.log(μ / 4) / 4 + 0.5 147 | y_inf = 1 - 1 / torch.sqrt(μ) 148 | return torch.where(μ >= 4, y_sup, y_inf) 149 | 150 | 151 | class HyperQuad: 152 | mul = 1 153 | 154 | @staticmethod 155 | def __call__(t: Tensor) -> Tensor: 156 | y_sup = 8 * t ** 2 - 4 * t + 1 157 | y_inf = t / (1 - t.clamp(max=0.5)) 158 | return torch.where(t >= 0.5, y_sup, y_inf) 159 | 160 | @staticmethod 161 | def tilde(μ: Tensor) -> Tensor: 162 | y_sup = (μ + 4) / 16 163 | y_inf = 1 - 1 / torch.sqrt(μ) 164 | return torch.where(μ >= 4, y_sup, y_inf) 165 | 166 | 167 | class DualLogQuad: 168 | mul = 1 169 | 170 | @staticmethod 171 | def __call__(t: Tensor) -> Tensor: 172 | return (1 + t + torch.sqrt((1 + t) ** 2 + 8)) ** 2 / 16 + 
torch.log( 173 | 0.25 * (1 + t + torch.sqrt((1 + t) ** 2 + 8))) - 1 174 | 175 | @staticmethod 176 | def tilde(μ: Tensor) -> Tensor: 177 | return 2 * μ - 1 / μ - 1 178 | 179 | 180 | class CubicQuad: 181 | mul = 8 182 | 183 | @staticmethod 184 | def __call__(t: Tensor) -> Tensor: 185 | y_sup = 0.5 * t ** 2 186 | y_inf = 1 / 6 * (t + 0.5).clamp(min=0) ** 3 - 1 / 24 187 | return torch.where(t >= 0.5, y_sup, y_inf) 188 | 189 | @staticmethod 190 | def tilde(μ: Tensor) -> Tensor: 191 | y_sup = μ 192 | y_inf = torch.sqrt(2 * μ) - 0.5 193 | return torch.where(μ >= 0.5, y_sup, y_inf) 194 | 195 | 196 | class ExpQuad: 197 | mul = 1 198 | 199 | @staticmethod 200 | def __call__(t: Tensor) -> Tensor: 201 | y_sup = math.exp(0.5) * (0.5 * t ** 2 + 0.5 * t + 0.625) 202 | y_inf = safe_exp(t) 203 | return torch.where(t >= 0.5, y_sup, y_inf) 204 | 205 | @staticmethod 206 | def tilde(μ: Tensor) -> Tensor: 207 | y_sup = μ * math.exp(-0.5) - 0.5 208 | y_inf = torch.log(μ) 209 | return torch.where(μ >= math.exp(0.5), y_sup, y_inf) 210 | 211 | 212 | class LogBarrierQuad: 213 | mul = 0.25 214 | 215 | @staticmethod 216 | def __call__(t: Tensor) -> Tensor: 217 | y_sup = 2 * t ** 2 + 4 * t + 0.5 + math.log(2) 218 | y_inf = -torch.log(-t.clamp(max=-0.5)) - 1 219 | return torch.where(t >= -0.5, y_sup, y_inf) 220 | 221 | @staticmethod 222 | def tilde(μ: Tensor) -> Tensor: 223 | y_sup = (μ - 4) / 4 224 | y_inf = -1 / μ 225 | return torch.where(μ >= 2, y_sup, y_inf) 226 | 227 | 228 | class HyperBarrierQuad: 229 | mul = 1 / 12 230 | 231 | @staticmethod 232 | def __call__(t: Tensor) -> Tensor: 233 | y_sup = 8 * t ** 2 + 12 * t + 6 234 | y_inf = -1 / t.clamp(max=-0.5) 235 | return torch.where(t >= -0.5, y_sup, y_inf) 236 | 237 | @staticmethod 238 | def tilde(μ: Tensor) -> Tensor: 239 | y_sup = (μ - 12) / 16 240 | y_inf = -1 / torch.sqrt(μ) 241 | return torch.where(μ >= 4, y_sup, y_inf) 242 | 243 | 244 | class HyperLogQuad: 245 | mul = 0.125 246 | 247 | @staticmethod 248 | def __call__(t: Tensor) -> Tensor: 249 | y_sup = 8 * t ** 2 + 8 * t + 1.5 + 2 * math.log(2) 250 | y_mid = -torch.log(-t.clamp(-1, -0.25)) 251 | y_inf = 4 / (1 - t.clamp(max=-1)) - 2 252 | return torch.where(t >= -0.25, y_sup, torch.where(t <= -1, y_inf, y_mid)) 253 | 254 | @staticmethod 255 | def tilde(μ: Tensor) -> Tensor: 256 | y_sup = (μ - 8) / 16 257 | y_mid = -1 / μ 258 | y_inf = 1 - 2 / torch.sqrt(μ) 259 | return torch.where(μ >= 4, y_sup, torch.where(μ <= 1, y_inf, y_mid)) 260 | 261 | 262 | class SmoothPlus: 263 | mul = 2 264 | 265 | @staticmethod 266 | def __call__(t: Tensor) -> Tensor: 267 | return 0.5 * (t + torch.sqrt(t ** 2 + 4)) 268 | 269 | @staticmethod 270 | def tilde(μ: Tensor, ρ: Tensor) -> Tensor: 271 | return (2 * μ - ρ) / torch.sqrt(μ * ρ - μ ** 2) 272 | 273 | 274 | class NNSmoothPlus: 275 | mul = 2 276 | 277 | @staticmethod 278 | def __call__(t: Tensor) -> Tensor: 279 | return F.softplus(t) 280 | 281 | @staticmethod 282 | def tilde(μ: Tensor, ρ: Tensor) -> Tensor: 283 | return torch.log(μ / (ρ - μ)) 284 | 285 | 286 | class ExpSmoothPlus: 287 | mul = 2 288 | 289 | @staticmethod 290 | def __call__(t: Tensor) -> Tensor: 291 | y_sup = t + 0.5 * safe_exp(-t) 292 | y_inf = 0.5 * safe_exp(t) 293 | return torch.where(t >= 0, y_sup, y_inf) 294 | 295 | @staticmethod 296 | def tilde(μ: Tensor, ρ: Tensor) -> Tensor: 297 | y_sup = torch.log(ρ / (2 * (ρ - μ))) 298 | y_inf = torch.log(2 * μ / ρ) 299 | return torch.where(μ >= 0.5 * ρ, y_sup, y_inf) 300 | -------------------------------------------------------------------------------- 
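A note on the univariate functions above: for every function θ that defines a `tilde`, the expression appears to invert the derivative, i.e. θ'(tilde(μ)) = μ (the SmoothPlus variants take ρ as an extra argument and invert at μ/ρ instead). The following sketch, which is not part of the library, checks this numerically with autograd for a few of the single-argument functions:

import torch
from torch.autograd import grad

from adv_lib.utils.lagrangian_penalties import univariate_functions

# tilde appears to be the inverse of the derivative of θ: check θ'(tilde(μ)) == μ
for name in ['LogExp', 'LogQuad_1', 'HyperQuad', 'DualLogQuad']:
    θ = univariate_functions.__dict__[name]()
    μ = torch.tensor([0.5, 1, 2, 8], dtype=torch.double)
    t = θ.tilde(μ).requires_grad_(True)
    dθ = grad(θ(t).sum(), t, only_inputs=True)[0]
    print(name, torch.allclose(dθ, μ))  # expected: True for all four functions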
/adv_lib/utils/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def difference_of_logits(logits: Tensor, labels: Tensor, labels_infhot: Optional[Tensor] = None) -> Tensor: 8 | if labels_infhot is None: 9 | labels_infhot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), float('inf')) 10 | 11 | class_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1) 12 | other_logits = (logits - labels_infhot).amax(dim=1) 13 | return class_logits - other_logits 14 | 15 | 16 | def difference_of_logits_ratio(logits: Tensor, labels: Tensor, labels_infhot: Optional[Tensor] = None, 17 | targeted: bool = False, ε: float = 0) -> Tensor: 18 | """Difference of Logits Ratio from https://arxiv.org/abs/2003.01690. This version is modified such that the DLR is 19 | always positive if argmax(logits) == labels""" 20 | logit_dists = difference_of_logits(logits=logits, labels=labels, labels_infhot=labels_infhot) 21 | 22 | if targeted: 23 | top4_logits = torch.topk(logits, k=4, dim=1).values 24 | logit_normalization = top4_logits[:, 0] - (top4_logits[:, -2] + top4_logits[:, -1]) / 2 25 | else: 26 | top3_logits = torch.topk(logits, k=3, dim=1).values 27 | logit_normalization = top3_logits[:, 0] - top3_logits[:, -1] 28 | 29 | return (logit_dists + ε) / (logit_normalization + 1e-8) 30 | -------------------------------------------------------------------------------- /adv_lib/utils/projections.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | # distutils was removed in Python 3.12: compare the "major.minor" prefix of torch.__version__ 7 | # directly (ignoring any local suffix such as "+cu118") instead of using LooseVersion 8 | _torch_version = tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) 9 | use_tensors_in_clamp = _torch_version >= (1, 9) 10 | 11 | 12 | @torch.no_grad() 13 | def clamp(x: Tensor, lower: Tensor, upper: Tensor, inplace: bool = False) -> Tensor: 14 | """Clamp based on lower and upper Tensor bounds. The clamping method depends on the torch version: clamping with 15 | tensors was introduced in torch 1.9.""" 16 | δ_clamped = x if inplace else None 17 | if use_tensors_in_clamp: 18 | δ_clamped = torch.clamp(x, min=lower, max=upper, out=δ_clamped) 19 | else: 20 | δ_clamped = torch.maximum(x, lower, out=δ_clamped) 21 | δ_clamped = torch.minimum(δ_clamped, upper, out=δ_clamped) 22 | return δ_clamped 23 | 24 | 25 | def clamp_(x: Tensor, lower: Tensor, upper: Tensor) -> Tensor: 26 | """In-place alias for clamp.""" 27 | return clamp(x=x, lower=lower, upper=upper, inplace=True) 28 | 29 | 30 | def simplex_projection(x: Tensor, ε: Union[float, Tensor] = 1, inplace: bool = False) -> Tensor: 31 | """ 32 | Simplex projection based on sorting. 33 | 34 | Parameters 35 | ---------- 36 | x : Tensor 37 | Batch of vectors to project on the simplex. 38 | ε : float or Tensor 39 | Size of the simplex, defaults to 1 for the probability simplex. 40 | inplace : bool 41 | Can optionally do the operation in-place. 42 | 43 | Returns 44 | ------- 45 | projected_x : Tensor 46 | Batch of projected vectors on the simplex.
47 | """ 48 | u = x.sort(dim=1, descending=True)[0] 49 | ε = ε.unsqueeze(1) if isinstance(ε, Tensor) else x.new_full((), ε) 50 | indices = torch.arange(x.size(1), device=x.device) 51 | cumsum = torch.cumsum(u, dim=1).sub_(ε).div_(indices + 1) 52 | K = (cumsum < u).long().mul_(indices).amax(dim=1, keepdim=True) 53 | τ = cumsum.gather(1, K) 54 | x = x.sub_(τ) if inplace else x - τ 55 | return x.clamp_(min=0) 56 | 57 | 58 | def l1_ball_euclidean_projection(x: Tensor, ε: Union[float, Tensor], inplace: bool = False) -> Tensor: 59 | """ 60 | Compute Euclidean projection onto the L1 ball for a batch. 61 | 62 | min ||x - u||_2 s.t. ||u||_1 <= eps 63 | 64 | Inspired by the corresponding numpy version by Adrien Gaidon. 65 | Adapted from Tony Duan's implementation https://gist.github.com/tonyduan/1329998205d88c566588e57e3e2c0c55 66 | 67 | Parameters 68 | ---------- 69 | x: Tensor 70 | Batch of tensors to project. 71 | ε: float or Tensor 72 | Radius of L1-ball to project onto. Can be a single value for all tensors in the batch or a batch of values. 73 | inplace : bool 74 | Can optionally do the operation in-place. 75 | 76 | Returns 77 | ------- 78 | projected_x: Tensor 79 | Batch of projected tensors with the same shape as x. 80 | 81 | Notes 82 | ----- 83 | The complexity of this algorithm is in O(dlogd) as it involves sorting x. 84 | 85 | References 86 | ---------- 87 | [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions 88 | John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra. 89 | International Conference on Machine Learning (ICML 2008) 90 | """ 91 | if (to_project := x.norm(p=1, dim=1) > ε).any(): 92 | x_to_project = x[to_project] 93 | ε_ = ε[to_project] if isinstance(ε, Tensor) else x_to_project.new_full((1,), ε) 94 | if not inplace: 95 | x = x.clone() 96 | simplex_proj = simplex_projection(x_to_project.abs(), ε=ε_, inplace=True) 97 | x[to_project] = simplex_proj.copysign_(x_to_project) 98 | return x 99 | else: 100 | return x 101 | -------------------------------------------------------------------------------- /adv_lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | 7 | 8 | class ForwardCounter: 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.num_samples_called = 0 14 | 15 | def __call__(self, module, input) -> None: 16 | self.num_samples_called += len(input[0]) 17 | 18 | 19 | class BackwardCounter: 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.num_samples_called = 0 25 | 26 | def __call__(self, module, grad_input, grad_output) -> None: 27 | self.num_samples_called += len(grad_output[0]) 28 | 29 | 30 | class ImageNormalizer(nn.Module): 31 | def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None: 32 | super(ImageNormalizer, self).__init__() 33 | 34 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) 35 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) 36 | 37 | def forward(self, input: Tensor) -> Tensor: 38 | return (input - self.mean) / self.std 39 | 40 | 41 | def normalize_model(model: nn.Module, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> nn.Module: 42 | layers = OrderedDict([ 43 | ('normalize', ImageNormalizer(mean, std)), 44 | ('model', model) 45 | ]) 46 | return nn.Sequential(layers) 47 | 48 | 49 | def 
requires_grad_(model: nn.Module, requires_grad: bool) -> None: 50 | for param in model.parameters(): 51 | param.requires_grad_(requires_grad) 52 | 53 | 54 | def predict_inputs(model: nn.Module, inputs: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 55 | logits = model(inputs) 56 | probabilities = torch.softmax(logits, 1) 57 | predictions = logits.argmax(1) 58 | return logits, probabilities, predictions 59 | -------------------------------------------------------------------------------- /adv_lib/utils/visdom_logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import Enum 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import visdom 7 | from torch import Tensor 8 | 9 | 10 | class ChartTypes(Enum): 11 | line = 1 12 | image = 2 13 | 14 | 15 | class ChartData: 16 | def __init__(self): 17 | self.window = None 18 | self.type = None 19 | self.x_list = [] 20 | self.y_list = [] 21 | self.other_data = None 22 | self.to_plot = {} 23 | 24 | 25 | class VisdomLogger: 26 | def __init__(self, port: int): 27 | self.vis = visdom.Visdom(port=port) 28 | self.windows = defaultdict(lambda: ChartData()) 29 | 30 | @staticmethod 31 | def as_unsqueezed_tensor(data: Union[float, List[float], Tensor]) -> Tensor: 32 | data = torch.as_tensor(data).detach() 33 | return data.unsqueeze(0) if data.ndim == 0 else data 34 | 35 | def accumulate_line(self, names: Union[str, List[str]], x: Union[float, Tensor], 36 | y: Union[float, Tensor, List[Tensor]], title: str = '', **kwargs) -> None: 37 | if isinstance(names, str): 38 | names = [names] 39 | data = self.windows['$'.join(names)] 40 | update = None if data.window is None else 'append' 41 | 42 | if isinstance(y, (int, float)): 43 | Y = torch.tensor([y]) 44 | elif isinstance(y, list): 45 | Y = torch.stack(list(map(self.as_unsqueezed_tensor, y)), 1) 46 | elif isinstance(y, Tensor): 47 | Y = self.as_unsqueezed_tensor(y) 48 | 49 | if isinstance(x, (int, float)): 50 | X = torch.tensor([x]) 51 | elif isinstance(x, Tensor): 52 | X = self.as_unsqueezed_tensor(x) 53 | 54 | if Y.ndim == 2 and X.ndim == 1: 55 | X = X.unsqueeze(1).expand(len(X), Y.shape[1])  # broadcast X to Y's shape (expand is not in-place) 56 | 57 | if len(data.to_plot) == 0: 58 | data.to_plot = {'X': X, 'Y': Y, 'win': data.window, 'update': update, 59 | 'opts': {'legend': names, 'title': title, **kwargs}} 60 | else: 61 | data.to_plot['X'] = torch.cat((data.to_plot['X'], X), 0) 62 | data.to_plot['Y'] = torch.cat((data.to_plot['Y'], Y), 0) 63 | 64 | def update_lines(self) -> None: 65 | for window, data in self.windows.items(): 66 | if len(data.to_plot) != 0: 67 | win = self.vis.line(**data.to_plot) 68 | 69 | data.x_list.append(data.to_plot['X']) 70 | data.y_list.append(data.to_plot['Y']) 71 | 72 | # Update the window 73 | data.window = win 74 | data.type = ChartTypes.line 75 | 76 | data.to_plot = {} 77 | 78 | def line(self, names: Union[str, List[str]], x: Union[float, Tensor], y: Union[float, Tensor, List[Tensor]], 79 | title: str = '', **kwargs) -> None: 80 | self.accumulate_line(names=names, x=x, y=y, title=title, **kwargs) 81 | self.update_lines() 82 | 83 | def images(self, name: str, images: Tensor, mean_std: Optional[Tuple[List[float], List[float]]] = None, 84 | title: str = '') -> None: 85 | data = self.windows[name] 86 | 87 | if mean_std is not None: 88 | images = images * torch.as_tensor(mean_std[1]).view(-1, 1, 1) + torch.as_tensor(mean_std[0]).view(-1, 1, 1)  # un-normalize: x * std + mean, broadcast over the channel dim 89 | 90 | win = self.vis.images(images, win=data.window, opts={'legend': [name], 'title': title}) 91 | 92 | # Update the window 93 |
data.window = win 94 | data.other_data = images 95 | data.type = ChartTypes.image 96 | 97 | def reset_windows(self): 98 | self.windows.clear() 99 | 100 | def save(self, filename): 101 | to_save = {} 102 | for (name, data) in self.windows.items(): 103 | type = data.type 104 | if type == ChartTypes.line: 105 | to_save[name] = (type, torch.cat(data.x_list, dim=0).cpu(), torch.cat(data.y_list, dim=0).cpu()) 106 | elif type == ChartTypes.image: 107 | to_save[name] = (type, data.other_data.cpu()) 108 | 109 | torch.save(to_save, filename) 110 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "adv-lib" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name = "Jerome Rony", email = "jerome.rony@gmail.com" }, 10 | ] 11 | description = "Library of various adversarial attacks resources in PyTorch" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | license = { file = "LICENSE" } 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "Development Status :: 3 - Alpha", 18 | "Intended Audience :: Developers", 19 | "Intended Audience :: Science/Research", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | ] 22 | dependencies = [ 23 | "torch>=1.8.0", 24 | "torchvision>=0.9.0", 25 | "tqdm>=4.48.0", 26 | "visdom>=0.1.8", 27 | ] 28 | 29 | [tool.setuptools.packages.find] 30 | include = ["adv_lib*"] 31 | namespaces = false 32 | 33 | [tool.setuptools.dynamic] 34 | version = { attr = "adv_lib.__version__" } 35 | 36 | [project.optional-dependencies] 37 | test = ["scikit-image", "pytest"] 38 | 39 | [project.urls] 40 | Repository = "https://github.com/jeromerony/adversarial-library.git" -------------------------------------------------------------------------------- /tests/distances/test_color_difference.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd import grad 4 | 5 | from adv_lib.distances.color_difference import ciede2000_color_difference 6 | 7 | # test data from http://www2.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf 8 | TEST_DATA = [ 9 | # L*1 a*1 b*1 L*2 a*2 b*2 ΔE_00 10 | [50.0000, 2.677200, -79.7751, 50.0000, 0.000000, -82.7485, 2.0425], 11 | [50.0000, 3.157100, -77.2803, 50.0000, 0.000000, -82.7485, 2.8615], 12 | [50.0000, 2.836100, -74.0200, 50.0000, 0.000000, -82.7485, 3.4412], 13 | [50.0000, -1.38020, -84.2814, 50.0000, 0.000000, -82.7485, 1.0000], 14 | [50.0000, -1.18480, -84.8006, 50.0000, 0.000000, -82.7485, 1.0000], 15 | [50.0000, -0.90090, -85.5211, 50.0000, 0.000000, -82.7485, 1.0000], 16 | [50.0000, 0.000000, 0.000000, 50.0000, -1.00000, 2.000000, 2.3669], 17 | [50.0000, -1.00000, 2.000000, 50.0000, 0.000000, 0.000000, 2.3669], 18 | [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.000900, 7.1792], 19 | [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001000, 7.1792], 20 | [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001100, 7.2195], 21 | [50.0000, 2.490000, -0.00100, 50.0000, -2.49000, 0.001200, 7.2195], 22 | [50.0000, -0.00100, 2.490000, 50.0000, 0.000900, -2.49000, 4.8045], 23 | [50.0000, -0.00100, 2.490000, 50.0000, 0.001000, -2.49000, 4.8045], 24 | [50.0000, -0.00100, 2.490000, 50.0000, 0.001100, -2.49000, 4.7461], 25 | [50.0000, 2.500000, 0.000000, 50.0000, 0.000000, -2.50000, 
4.3065], 26 | [50.0000, 2.500000, 0.000000, 73.0000, 25.00000, -18.0000, 27.1492], 27 | [50.0000, 2.500000, 0.000000, 61.0000, -5.00000, 29.00000, 22.8977], 28 | [50.0000, 2.500000, 0.000000, 56.0000, -27.0000, -3.00000, 31.9030], 29 | [50.0000, 2.500000, 0.000000, 58.0000, 24.00000, 15.00000, 19.4535], 30 | [50.0000, 2.500000, 0.000000, 50.0000, 3.173600, 0.585400, 1.0000], 31 | [50.0000, 2.500000, 0.000000, 50.0000, 3.297200, 0.000000, 1.0000], 32 | [50.0000, 2.500000, 0.000000, 50.0000, 1.863400, 0.575700, 1.0000], 33 | [50.0000, 2.500000, 0.000000, 50.0000, 3.259200, 0.335000, 1.0000], 34 | [60.2574, -34.0099, 36.26770, 60.4626, -34.1751, 39.43870, 1.2644], 35 | [63.0109, -31.0961, -5.86630, 62.8187, -29.7946, -4.08640, 1.2630], 36 | [61.2901, 3.719600, -5.39010, 61.4292, 2.248000, -4.96200, 1.8731], 37 | [35.0831, -44.1164, 3.793300, 35.0232, -40.0716, 1.590100, 1.8645], 38 | [22.7233, 20.09040, -46.6940, 23.0331, 14.97300, -42.5619, 2.0373], 39 | [36.4612, 47.85800, 18.38520, 36.2715, 50.50650, 21.22310, 1.4146], 40 | [90.8027, -2.08310, 1.441000, 91.1528, -1.64350, 0.044700, 1.4441], 41 | [90.9257, -0.54060, -0.92080, 88.6381, -0.89850, -0.72390, 1.5381], 42 | [6.77470, -0.29080, -2.42470, 5.87140, -0.09850, -2.22860, 0.6377], 43 | [2.07760, 0.079500, -1.13500, 0.90330, -0.06360, -0.55140, 0.9082], 44 | ] 45 | 46 | 47 | @pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) 48 | def test_ciede2000_color_difference_value(dtype: torch.dtype) -> None: 49 | test_data = torch.tensor(TEST_DATA, dtype=dtype) 50 | Lab1 = test_data.narrow(1, 0, 3).clone() 51 | Lab2 = test_data.narrow(1, 3, 3).clone() 52 | test_ΔE_00 = test_data[:, -1].clone() 53 | 54 | ΔE_00_1 = ciede2000_color_difference(Lab1, Lab2) 55 | ΔE_00_2 = ciede2000_color_difference(Lab2, Lab1) 56 | ΔE_00_0 = ciede2000_color_difference(Lab1, Lab1) 57 | 58 | assert torch.equal(ΔE_00_0, torch.zeros_like(ΔE_00_0)) # check identical inputs 59 | assert torch.equal(ΔE_00_1, ΔE_00_2) # check symmetry 60 | assert torch.allclose(ΔE_00_1, test_ΔE_00, rtol=1e-4, atol=1e-5) # check correctness 61 | 62 | 63 | @pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) 64 | def test_ciede2000_color_difference_grad(dtype: torch.dtype) -> None: 65 | test_data = torch.tensor(TEST_DATA, dtype=dtype) 66 | Lab1 = test_data.narrow(1, 0, 3).clone() 67 | Lab2 = test_data.narrow(1, 3, 3).clone() 68 | 69 | # check that gradients are not NaN 70 | Lab1.requires_grad_(True) 71 | 72 | # check for identical inputs 73 | ΔE_00 = ciede2000_color_difference(Lab1, Lab1, ε=1e-12) 74 | ΔE_00_grad = grad(ΔE_00.sum(), Lab1, only_inputs=True)[0] 75 | assert not torch.isnan(ΔE_00_grad).any() 76 | assert torch.equal(ΔE_00_grad, torch.zeros_like(ΔE_00_grad)) 77 | 78 | # check for different inputs 79 | ΔE_00 = ciede2000_color_difference(Lab1, Lab2, ε=1e-12) 80 | ΔE_00_grad = grad(ΔE_00.sum(), Lab1, only_inputs=True)[0] 81 | assert not torch.isnan(ΔE_00_grad).any() 82 | -------------------------------------------------------------------------------- /tests/distances/test_structural_similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from skimage import data 5 | from skimage.metrics import structural_similarity 6 | 7 | from adv_lib.distances.structural_similarity import compute_ssim 8 | 9 | 10 | @pytest.mark.parametrize('dtype', [np.float32, np.float64]) 11 | def test_compute_ssim_gray(dtype: np.dtype) -> None: 12 | # test for gray level images 13 | 
np_gray_img = data.camera().astype(dtype) / 255 14 | pt_gray_img = torch.as_tensor(np_gray_img) 15 | 16 | for sigma in [0, 0.01, 0.03, 0.1, 0.3]: 17 | noise = torch.randn_like(pt_gray_img) * sigma 18 | 19 | noisy_pt_gray_img = (pt_gray_img + noise).clamp(0, 1) 20 | noisy_np_gray_img = noisy_pt_gray_img.numpy() 21 | 22 | skimage_ssim = structural_similarity(noisy_np_gray_img, np_gray_img, win_size=11, sigma=1.5, 23 | use_sample_covariance=False, gaussian_weights=True, data_range=1) 24 | adv_lib_ssim = compute_ssim(noisy_pt_gray_img.unsqueeze(0).unsqueeze(1), 25 | pt_gray_img.unsqueeze(0).unsqueeze(1)) 26 | abs_diff = abs(skimage_ssim - adv_lib_ssim.item()) 27 | assert abs_diff < 2e-5 28 | 29 | 30 | @pytest.mark.parametrize('dtype', [np.float32, np.float64]) 31 | def test_compute_ssim_color(dtype: np.dtype) -> None: 32 | # test for color images 33 | np_color_img = data.astronaut().astype(dtype) / 255 34 | pt_color_img = torch.as_tensor(np_color_img) 35 | 36 | for sigma in [0, 0.01, 0.03, 0.1, 0.3]: 37 | noise = torch.randn_like(pt_color_img) * sigma 38 | 39 | noisy_pt_color_img = (pt_color_img + noise).clamp(0, 1) 40 | noisy_np_color_img = noisy_pt_color_img.numpy() 41 | 42 | skimage_ssim = structural_similarity(noisy_np_color_img, np_color_img, win_size=11, sigma=1.5, 43 | channel_axis=-1, use_sample_covariance=False, gaussian_weights=True, 44 | data_range=1)  # channel_axis=-1 replaces the multichannel argument removed from recent scikit-image 45 | adv_lib_ssim = compute_ssim(noisy_pt_color_img.permute(2, 0, 1).unsqueeze(0), 46 | pt_color_img.permute(2, 0, 1).unsqueeze(0)) 47 | 48 | abs_diff = abs(skimage_ssim - adv_lib_ssim.item()) 49 | assert abs_diff < 1e-5 50 | -------------------------------------------------------------------------------- /tests/utils/lagrangian_penalties/test_penalty_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd import grad, gradcheck 4 | 5 | from adv_lib.utils.lagrangian_penalties import all_penalties 6 | 7 | 8 | @pytest.mark.parametrize('penalty', list(all_penalties.values())) 9 | def test_grad(penalty) -> None: 10 | y = torch.randn(512, dtype=torch.double, requires_grad=True) 11 | ρ = torch.randn(512, dtype=torch.double).abs_().clamp_min_(1e-3) 12 | μ = torch.randn(512, dtype=torch.double).abs_().clamp_min_(1e-6) 13 | ρ.requires_grad_(True) 14 | μ.requires_grad_(True) 15 | 16 | # check if gradients are correct compared to numerical approximations using finite differences 17 | assert gradcheck(penalty, inputs=(y, ρ, μ)) 18 | 19 | 20 | @pytest.mark.parametrize('penalty,value', [(all_penalties['P2'], 1), (all_penalties['P3'], 1)]) 21 | @pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) 22 | def test_nan_grad(penalty, value, dtype) -> None: 23 | y = torch.full((1,), value, dtype=dtype, requires_grad=True) 24 | ρ = torch.full((1,), value, dtype=dtype) 25 | μ = torch.full((1,), value, dtype=dtype) 26 | 27 | out = penalty(y, ρ, μ) 28 | g = grad(out, y, only_inputs=True)[0] 29 | 30 | assert not torch.isnan(g).any()  # check for nan in the gradient of the penalty 31 | -------------------------------------------------------------------------------- /tests/utils/lagrangian_penalties/test_univariate_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd import grad, gradcheck 4 | 5 | from adv_lib.utils.lagrangian_penalties import univariate_functions 6 | 7 | 8 | @pytest.mark.parametrize('univariate', univariate_functions.__all__) 9 | def
test_grad(univariate) -> None: 10 | t = torch.randn(512, dtype=torch.double, requires_grad=True) 11 | # check if gradients are correct compared to numerical approximations using finite differences 12 | assert gradcheck(univariate_functions.__dict__[univariate](), inputs=t) 13 | 14 | 15 | @pytest.mark.parametrize('univariate,value', [('LogExp', 1), ('LogQuad_1', 1), ('HyperExp', 1), ('HyperQuad', 1), 16 | ('LogBarrierQuad', 0), ('HyperBarrierQuad', 0), ('HyperLogQuad', 0), 17 | ('HyperLogQuad', 1)]) 18 | @pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) 19 | def test_nan_grad(univariate, value, dtype) -> None: 20 | t = torch.full((1,), value, dtype=dtype, requires_grad=True) 21 | 22 | univariate_func = univariate_functions.__dict__[univariate]() 23 | out = univariate_func(t) 24 | g = grad(out, t, only_inputs=True)[0] 25 | 26 | assert not torch.isnan(g).any()  # check for nan in the gradient of the univariate function 27 | --------------------------------------------------------------------------------
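Taken together, the utilities in this section support the typical workflow around an attack: normalize_model wraps a classifier so it accepts inputs in [0, 1], predict_inputs queries it, and difference_of_logits is positive exactly when a sample is classified as its label. A minimal usage sketch follows; the toy convnet and the CIFAR-10 normalization statistics are illustrative placeholders, not part of the library:

import torch
from torch import nn

from adv_lib.utils.losses import difference_of_logits
from adv_lib.utils.utils import normalize_model, predict_inputs, requires_grad_

# Toy 10-class convnet standing in for a real classifier (illustrative only).
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
# Prepend normalization so the wrapped model takes inputs in [0, 1].
model = normalize_model(model, mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616))
requires_grad_(model, False)  # freeze parameters: attacks only need gradients w.r.t. inputs
model.eval()

inputs = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
logits, probabilities, predictions = predict_inputs(model, inputs)

# A positive value means the sample is (still) classified as its label.
print(difference_of_logits(logits, labels))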