├── .env.example ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── code-quality-main.yaml │ ├── code-quality-pr.yaml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── LICENSE-3RD-PARTY ├── Makefile ├── README.md ├── data ├── .gitkeep ├── arch.png └── loop.png ├── docs └── attack_in_armory.md ├── examples ├── README.md ├── anomalib_adversary │ ├── .gitignore │ ├── README.md │ ├── anomalib_adversary │ │ └── callbacks.py │ ├── configs │ │ └── .gitkeep │ └── pyproject.toml ├── art_wrapper │ └── adversary_in_art.py ├── carla_overhead_object_detection │ ├── Makefile │ ├── README.md │ ├── configs │ │ ├── datamodule │ │ │ ├── armory_carla_over_objdet.yaml │ │ │ └── armory_carla_over_objdet_perturbable_mask.yaml │ │ └── experiment │ │ │ └── ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml │ ├── requirements.txt │ └── tests │ │ ├── conftest.py │ │ ├── helpers │ │ ├── test_configs.py │ │ └── test_experiments.py ├── fiftyone │ └── README.md └── robust_bench │ ├── README.md │ ├── configs │ ├── attack │ │ └── classification_autoattack.yaml │ ├── experiment │ │ └── CIFAR10_RobustBench.yaml │ └── model │ │ └── classifier_robustbench.yaml │ ├── requirements.txt │ └── tests │ ├── helpers │ └── test_experiments.py ├── hydra_plugins └── hydra_mart │ ├── __init__.py │ └── mart.py ├── logs └── .gitkeep ├── mart ├── __init__.py ├── __main__.py ├── attack │ ├── __init__.py │ ├── adversary.py │ ├── adversary_wrapper.py │ ├── composer │ │ ├── __init__.py │ │ ├── modular.py │ │ ├── patch.py │ │ └── visualizer.py │ ├── enforcer.py │ ├── gain.py │ ├── gradient_modifier.py │ ├── initializer │ │ ├── __init__.py │ │ ├── base.py │ │ └── vision.py │ ├── objective │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classification.py │ │ └── object_detection.py │ ├── perturber.py │ └── projector.py ├── callbacks │ ├── __init__.py │ ├── adversary_connector.py │ ├── eval_mode.py │ ├── gradients.py │ ├── logging.py │ ├── metrics.py │ ├── no_grad_mode.py │ ├── progress_bar.py │ └── visualizer.py ├── configs │ ├── __init__.py │ ├── attack │ │ ├── adaptive_gradient_ascent.yaml │ │ ├── adversary.yaml │ │ ├── classification_fgsm_linf.yaml │ │ ├── classification_pgd_linf.yaml │ │ ├── composer │ │ │ ├── additive.yaml │ │ │ ├── default.yaml │ │ │ ├── mask_additive.yaml │ │ │ ├── modules │ │ │ │ ├── additive.yaml │ │ │ │ ├── fake_renderer.yaml │ │ │ │ ├── mask.yaml │ │ │ │ ├── overlay.yaml │ │ │ │ ├── pert_extract_rect.yaml │ │ │ │ ├── pert_image_base.yaml │ │ │ │ ├── pert_rect_perspective.yaml │ │ │ │ └── pert_rect_size.yaml │ │ │ ├── overlay.yaml │ │ │ └── perturber │ │ │ │ ├── default.yaml │ │ │ │ ├── initializer │ │ │ │ ├── constant.yaml │ │ │ │ ├── image.yaml │ │ │ │ ├── uniform.yaml │ │ │ │ └── uniform_lp.yaml │ │ │ │ ├── projector │ │ │ │ ├── linf.yaml │ │ │ │ ├── linf_additive_range.yaml │ │ │ │ ├── lp_additive_range.yaml │ │ │ │ ├── mask_range.yaml │ │ │ │ └── range.yaml │ │ │ │ └── universal.yaml │ │ ├── enforcer │ │ │ ├── constraints │ │ │ │ ├── integer.yaml │ │ │ │ ├── lp.yaml │ │ │ │ ├── mask.yaml │ │ │ │ └── pixel_range.yaml │ │ │ └── default.yaml │ │ ├── fgm.yaml │ │ ├── gain │ │ │ ├── cross_entropy.yaml │ │ │ ├── dlr.yaml │ │ │ ├── modular.yaml │ │ │ ├── rcnn_class_background.yaml │ │ │ └── rcnn_training_loss.yaml │ │ ├── gradient_ascent.yaml │ │ ├── gradient_modifier │ │ │ ├── lp_normalizer.yaml │ │ │ └── sign.yaml │ │ ├── linf.yaml │ │ ├── mask.yaml │ │ ├── object_detection_lp_patch_adversary.yaml │ │ ├── 
object_detection_lp_patch_adversary_simulation.yaml │ │ ├── object_detection_mask_adversary.yaml │ │ ├── object_detection_mask_adversary_missed.yaml │ │ ├── object_detection_patch_adversary.yaml │ │ ├── object_detection_patch_adversary_simulation.yaml │ │ ├── objective │ │ │ ├── misclassification.yaml │ │ │ ├── object_detection_missed.yaml │ │ │ └── zero_ap.yaml │ │ └── pgd.yaml │ ├── batch_c15n │ │ ├── dict.yaml │ │ ├── dict_imagenet_normalized.yaml │ │ ├── input_only.yaml │ │ ├── list.yaml │ │ ├── list_image_01.yaml │ │ ├── transform │ │ │ ├── 255_to_imagenet.yaml │ │ │ ├── divided_by_255.yaml │ │ │ ├── imagenet_to_255.yaml │ │ │ └── times_255_and_round.yaml │ │ └── tuple.yaml │ ├── callbacks │ │ ├── adversary_connector.yaml │ │ ├── attack_in_eval_mode.yaml │ │ ├── default.yaml │ │ ├── early_stopping.yaml │ │ ├── gradient_monitor.yaml │ │ ├── image_visualizer.yaml │ │ ├── lr_monitor.yaml │ │ ├── model_checkpoint.yaml │ │ ├── model_summary.yaml │ │ ├── no_grad_mode.yaml │ │ ├── none.yaml │ │ ├── progress_bar.yaml │ │ └── rich_progress_bar.yaml │ ├── datamodule │ │ ├── carla_patch.yaml │ │ ├── carla_patch_rendering.yaml │ │ ├── cifar10.yaml │ │ ├── coco.yaml │ │ ├── coco_perturbable_mask.yaml │ │ ├── default.yaml │ │ ├── dummy_classification.yaml │ │ ├── fiftyone.yaml │ │ ├── fiftyone_perturbable_mask.yaml │ │ └── imagenet.yaml │ ├── debug │ │ ├── default.yaml │ │ ├── fdr.yaml │ │ ├── limit.yaml │ │ ├── overfit.yaml │ │ └── profiler.yaml │ ├── experiment │ │ ├── CIFAR10_CNN.yaml │ │ ├── CIFAR10_CNN_Adv.yaml │ │ ├── COCO_TorchvisionFasterRCNN.yaml │ │ ├── COCO_TorchvisionFasterRCNN_Adv.yaml │ │ ├── COCO_TorchvisionRetinaNet.yaml │ │ └── ImageNet_Timm.yaml │ ├── extras │ │ └── default.yaml │ ├── hparams_search │ │ └── mnist_optuna.yaml │ ├── hydra │ │ └── default.yaml │ ├── lightning.yaml │ ├── logger │ │ ├── comet.yaml │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── mlflow.yaml │ │ ├── neptune.yaml │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── metric │ │ ├── accuracy.yaml │ │ └── average_precision.yaml │ ├── model │ │ ├── classifier.yaml │ │ ├── classifier_cifar10_cnn.yaml │ │ ├── classifier_timm.yaml │ │ ├── cnn_7layer_bn2.yaml │ │ ├── modular.yaml │ │ ├── torchvision_faster_rcnn.yaml │ │ ├── torchvision_object_detection.yaml │ │ └── torchvision_retinanet.yaml │ ├── optimization │ │ ├── adaptive_sgd.yaml │ │ ├── lr_scheduler │ │ │ └── scheduler │ │ │ │ └── reduced_lr_on_plateau.yaml │ │ └── super_convergence.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ └── sgd.yaml │ ├── paths │ │ └── default.yaml │ └── trainer │ │ ├── ddp.yaml │ │ ├── ddp_sim.yaml │ │ ├── default.yaml │ │ ├── gpu.yaml │ │ └── mps.yaml ├── datamodules │ ├── __init__.py │ ├── modular.py │ └── vision │ │ ├── __init__.py │ │ ├── coco.py │ │ └── fiftyone.py ├── generate_config.py ├── models │ ├── __init__.py │ ├── modular.py │ └── vision │ │ ├── __init__.py │ │ └── dual_mode.py ├── nn │ ├── __init__.py │ └── nn.py ├── optim │ ├── __init__.py │ └── optimizer.py ├── tasks │ ├── __init__.py │ └── lightning.py ├── transforms │ ├── __init__.py │ ├── batch_c15n.py │ ├── transforms.py │ └── vision │ │ ├── __init__.py │ │ ├── objdet │ │ ├── __init__.py │ │ ├── extended.py │ │ └── torchvision_ref.py │ │ └── transforms.py └── utils │ ├── __init__.py │ ├── adapters.py │ ├── config.py │ ├── export.py │ ├── imports.py │ ├── lightning.py │ ├── monkey_patch.py │ ├── optimization.py │ ├── pylogger.py │ ├── rich_utils.py │ ├── silent.py │ └── utils.py ├── notebooks └── .gitkeep ├── pyproject.toml ├── scripts └── schedule.sh └── tests ├── 
__init__.py ├── conftest.py ├── helpers ├── __init__.py ├── dataset_generator.py ├── dataset_utils.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_adversary.py ├── test_adversary_connector.py ├── test_composer.py ├── test_configs.py ├── test_dependency.py ├── test_enforcer.py ├── test_experiments.py ├── test_gradient.py ├── test_initializer.py ├── test_perturber.py ├── test_projector.py ├── test_utils.py └── test_visualizer.py /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Type of change 13 | 14 | Please check all relevant options. 15 | 16 | - [ ] Improvement (non-breaking) 17 | - [ ] Bug fix (non-breaking) 18 | - [ ] New feature (non-breaking) 19 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 20 | - [ ] This change requires a documentation update 21 | 22 | ## Testing 23 | 24 | Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration. 25 | 26 | - [ ] `pytest` 27 | - [ ] `CUDA_VISIBLE_DEVICES=0 python -m mart experiment=CIFAR10_CNN_Adv trainer=gpu trainer.precision=16` reports 70% (21 sec/epoch). 28 | - [ ] `CUDA_VISIBLE_DEVICES=0,1 python -m mart experiment=CIFAR10_CNN_Adv trainer=ddp trainer.precision=16 trainer.devices=2 model.optimizer.lr=0.2 trainer.max_steps=2925 datamodule.ims_per_batch=256 datamodule.world_size=2` reports 70% (14 sec/epoch). 29 | 30 | ## Before submitting 31 | 32 | - [ ] The title is **self-explanatory** and the description **concisely** explains the PR 33 | - [ ] My **PR does only one thing**, instead of bundling different changes together 34 | - [ ] I list all the **breaking changes** introduced by this pull request 35 | - [ ] I have commented my code 36 | - [ ] I have added tests that prove my fix is effective or that my feature works 37 | - [ ] New and existing unit tests pass locally with my changes 38 | - [ ] I have run pre-commit hooks with `pre-commit run -a` command without errors 39 | 40 | ## Did you have fun? 41 | 42 | Make sure you had fun coding 🙃 43 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | ignore: 13 | - dependency-name: "pytorch-lightning" 14 | update-types: ["version-update:semver-patch"] 15 | - dependency-name: "torchmetrics" 16 | update-types: ["version-update:semver-patch"] 17 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-main.yaml: -------------------------------------------------------------------------------- 1 | # Same as `code-quality-pr.yaml` but triggered on commit to main branch 2 | # and runs on all files (instead of only the changed ones) 3 | 4 | name: Code Quality Main 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | code-quality: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 22 | 23 | - name: Run pre-commits 24 | uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 25 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-pr.yaml: -------------------------------------------------------------------------------- 1 | # This workflow finds which files were changed, prints them, 2 | # and runs `pre-commit` on those files. 3 | 4 | # Inspired by the sktime library: 5 | # https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml 6 | 7 | name: Code Quality PR 8 | 9 | on: 10 | pull_request: 11 | branches: [main, "release/*"] 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | code-quality: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 25 | 26 | - name: Find modified files 27 | id: file_changes 28 | uses: trilom/file-changes-action@a6ca26c14274c33b15e6499323aac178af06ad4b # v1.2.4 29 | with: 30 | output: " " 31 | 32 | - name: List modified files 33 | run: echo '${{ steps.file_changes.outputs.files}}' 34 | 35 | - name: Run pre-commits 36 | uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 37 | with: 38 | extra_args: --files ${{ steps.file_changes.outputs.files}} 39 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | run_tests: 13 | runs-on: ${{ matrix.os }} 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: ["ubuntu-latest"] 19 | python-version: ["3.9"] 20 | 21 | timeout-minutes: 30 22 | 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 26 | 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@3542bca2639a428e1796aaa6a2ffef0c0f575566 # 
v3.1.4 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -e .[full] 36 | 37 | - name: List dependencies 38 | run: | 39 | python -m pip list 40 | 41 | - name: Run pytest 42 | run: | 43 | pytest -v 44 | 45 | # upload code coverage report 46 | code-coverage: 47 | runs-on: ubuntu-latest 48 | 49 | steps: 50 | - name: Checkout 51 | uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 52 | 53 | - name: Set up Python 3.10 54 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 55 | with: 56 | python-version: "3.10" 57 | 58 | - name: Install dependencies 59 | run: | 60 | python -m pip install --upgrade pip 61 | pip install -e .[full] 62 | 63 | - name: Run tests and collect coverage 64 | run: pytest --cov mart 65 | 66 | - name: Upload coverage to Codecov 67 | uses: codecov/codecov-action@ab904c41d6ece82784817410c45d8b8c02684457 # v3.1.6 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | data/ 150 | logs/ 151 | .env 152 | .autoenv 153 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 22.6.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v2.32.1 37 | hooks: 38 | - id: pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | # Use a commit in the master branch to avoid the conflict with pre-commit 44 | # Source: https://github.com/PyCQA/docformatter/issues/289 45 | rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 46 | hooks: 47 | - id: docformatter 48 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 49 | 50 | # python check (PEP8), programming errors and code complexity 51 | - repo: https://github.com/PyCQA/flake8 52 | rev: 7.0.0 53 | hooks: 54 | - id: flake8 55 | # ignore E203 because black is used for formatting. 56 | # W503 and W504 are conflicting rules. We ignore W503 in favor of W504. 
57 | args: 58 | [ 59 | "--ignore", 60 | "E203,E501,F401,F403,F841,W503", 61 | "--exclude", 62 | "logs/*,data/*", 63 | ] 64 | 65 | # python security linter 66 | - repo: https://github.com/PyCQA/bandit 67 | rev: "1.7.1" 68 | hooks: 69 | - id: bandit 70 | args: ["-s", "B101"] 71 | 72 | # yaml formatting 73 | - repo: https://github.com/pre-commit/mirrors-prettier 74 | rev: v2.7.1 75 | hooks: 76 | - id: prettier 77 | types: [yaml] 78 | 79 | # jupyter notebook cell output clearing 80 | - repo: https://github.com/kynan/nbstripout 81 | rev: 0.5.0 82 | hooks: 83 | - id: nbstripout 84 | 85 | # md formatting 86 | - repo: https://github.com/executablebooks/mdformat 87 | rev: 0.7.22 88 | hooks: 89 | - id: mdformat 90 | args: ["--number"] 91 | additional_dependencies: 92 | - mdformat-gfm 93 | - mdformat-tables 94 | - mdformat_frontmatter 95 | # - mdformat-toc 96 | # - mdformat-black 97 | 98 | # word spelling linter 99 | - repo: https://github.com/codespell-project/codespell 100 | rev: v2.1.0 101 | hooks: 102 | - id: codespell 103 | args: 104 | - --skip=logs/**,data/** 105 | - --ignore-words-list=abc,def,gard,fo 106 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Gracefully accepting constructive criticism 20 | - Focusing on what is best for the community 21 | - Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | - The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | - Trolling, insulting/derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at webadmin@linux.intel.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | For answers to common questions about this code of conduct, see 74 | https://www.contributor-covenant.org/faq 75 | 76 | [homepage]: https://www.contributor-covenant.org 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Intel Labs 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | .PHONY: help 3 | help: ## Show help 4 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 5 | 6 | .PHONY: clean 7 | clean: ## Clean autogenerated files 8 | rm -rf dist 9 | find . -type f -name "*.DS_Store" -ls -delete 10 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 11 | find . | grep -E ".pytest_cache" | xargs rm -rf 12 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 13 | rm -f .coverage 14 | 15 | .PHONY: clean-logs 16 | clean-logs: ## Clean logs 17 | rm -r logs/** 18 | 19 | .PHONY: style 20 | style: ## Run pre-commit hooks 21 | pre-commit run -a 22 | 23 | .PHONY: sync 24 | sync: ## Merge changes from main branch to your current branch 25 | git fetch --all 26 | git merge main 27 | 28 | .PHONY: test 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | .PHONY: test-full 33 | test-full: ## Run all tests 34 | pytest 35 | 36 | .PHONY: debug 37 | debug: ## Enter debugging mode with pdb, an example. 38 | # 39 | # tips: 40 | # - use "breakpoint()" to set breakpoint 41 | # - use "h" to print all commands 42 | # - use "n" to execute the next line 43 | # - use "c" to run until the breakpoint is hit 44 | # - use "l" to print src code around current line, "ll" for full function code 45 | # - docs: https://docs.python.org/3/library/pdb.html 46 | # 47 | python -m pdb -m mart experiment=CIFAR10_CNN debug=default 48 | 49 | .PHONY: cifar_train 50 | cifar_train: ## Adversarial training for a CIFAR-10 model. 51 | python -m mart experiment=CIFAR10_CNN_Adv \ 52 | fit=true \ 53 | trainer=gpu 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Modular Adversarial Robustness Toolkit 4 | 5 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) 6 | 7 | 8 | 9 | A unified optimization-based framework 10 | 11 |
12 | 13 | ## Description 14 | 15 | The **Modular Adversarial Robustness Toolkit** makes it easy to compose novel attacks for evaluating the adversarial robustness of deep learning models. Thanks to the modular design of its optimization-based attack framework, you can compose powerful attacks from off-the-shelf PyTorch components such as optimizers and learning-rate schedulers. The unified framework also supports advanced features, such as early stopping, that improve attack efficiency. 16 | 17 |
18 | 19 | *Figure: Modular Design* 20 | 21 | 22 |
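Concretely, "optimization-based" means the attack loop is ordinary gradient optimization of a perturbation, so standard PyTorch pieces drop in unchanged. The following is a minimal illustrative sketch of that idea in plain PyTorch. It is not MART's API; the real building blocks (perturbers, initializers, projectors, gain functions, enforcers) live in `mart/attack/` and are composed through the configs:

```python
import torch
import torch.nn.functional as F


def pgd_like_attack(model, input, target, eps=8 / 255, steps=10, lr=2 / 255):
    """A PGD-style attack composed from off-the-shelf PyTorch parts (sketch)."""
    perturbation = torch.zeros_like(input, requires_grad=True)  # "initializer"
    # Any torch.optim optimizer (plus an LR scheduler, if desired) can drive the attack.
    optimizer = torch.optim.SGD([perturbation], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        gain = F.cross_entropy(model(input + perturbation), target)  # "gain"
        (-gain).backward()  # minimizing -gain maximizes the classification loss
        optimizer.step()
        with torch.no_grad():  # "projector": keep the perturbation in the Linf ball
            perturbation.clamp_(-eps, eps)
    return (input + perturbation).detach()
```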
23 | 24 | ## Installation 25 | 26 | ### Using pip 27 | 28 | ```bash 29 | pip install mart[full]@https://github.com/IntelLabs/MART/archive/refs/tags/<version>.zip 30 | ``` 31 | 32 | Replace `<version>` with the MART version you want to install. For example: 33 | 34 | ```bash 35 | pip install mart[full]@https://github.com/IntelLabs/MART/archive/refs/tags/v0.2.1.zip 36 | ``` 37 | 38 | ### Manual installation 39 | 40 | ```bash 41 | # clone project 42 | git clone https://github.com/IntelLabs/MART 43 | cd MART 44 | 45 | # [OPTIONAL] create conda environment 46 | # Recommend Python 3.9 47 | conda create -n myenv python=3.9 48 | conda activate myenv 49 | 50 | # [OPTIONAL] or create virtualenv environment 51 | python3 -m venv .venv 52 | source .venv/bin/activate 53 | 54 | # Install Modular Adversarial Robustness Toolkit, if you plan to create your own `configs` folder elsewhere. 55 | pip install -e .[full] 56 | 57 | # [OPTIONAL] install pre-commit hooks 58 | # this will trigger the pre-commit checks in each `git commit` command. 59 | pre-commit install 60 | 61 | # If your CUDA version is not 10.2, you need to uninstall pytorch and torchvision, and 62 | # then reinstall them according to platform instructions at https://pytorch.org/get-started/ 63 | # FYI, this is what we do: 64 | # $ pip uninstall torch torchvision 65 | # $ pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 66 | 67 | ``` 68 | 69 | ## How to run 70 | 71 | The toolkit comes with built-in experiment configurations in [mart/configs](mart/configs). 72 | 73 | For example, you can run a fast adversarial training experiment on CIFAR-10 with `python -m mart experiment=CIFAR10_CNN_Adv`. 74 | Running on a GPU makes it even faster: `CUDA_VISIBLE_DEVICES=0 python -m mart experiment=CIFAR10_CNN_Adv trainer=gpu trainer.precision=16`. 75 | 76 | You can see other examples in [examples](/examples). 77 | 78 | ## Acknowledgements 79 | 80 | This material is based upon work supported by the Defense Advanced Research Projects Agency (DARPA) under Contract No. HR001119S0026. 81 | 82 | ## Disclaimer 83 | 84 | This “research quality code” is provided by Intel “As Is” without any express or implied warranty of any kind. Intel does not warrant or assume responsibility for the accuracy or completeness of any information, text, graphics, links or other items within the code. A thorough security review has not been performed on this code. Additionally, this repository will not be actively maintained and as such may contain components that are out of date, or contain known security vulnerabilities. Proceed with caution.
85 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/data/.gitkeep -------------------------------------------------------------------------------- /data/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/data/arch.png -------------------------------------------------------------------------------- /data/loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/data/loop.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # How to use Modular Adversarial Robustness Toolkit 2 | 3 | We provide examples of how to use the toolkit in your project. 4 | 5 | A typical procedure is: 6 | 7 | 1. Install the toolkit as a Python package; 8 | 9 | - Install the latest by `pip install https://github.com/IntelLabs/MART/archive/refs/heads/main.zip` 10 | - Or install a released version by `pip install https://github.com/IntelLabs/MART/archive/refs/tags/<version>.zip` 11 | 12 | 2. Create a `configs` folder; 13 | 3. Add your configurations in `configs`; 14 | 4. Run experiments from the folder that contains `configs`. 15 | 16 | The toolkit searches configurations in the order of `./configs` and `mart.configs`. 17 | Local configurations in `./configs` precede those built-in configurations in `mart/configs` if they share the same name. 18 | 19 | You can find specific examples in sub-folders. 20 | -------------------------------------------------------------------------------- /examples/anomalib_adversary/.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | datasets 3 | -------------------------------------------------------------------------------- /examples/anomalib_adversary/configs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/examples/anomalib_adversary/configs/.gitkeep -------------------------------------------------------------------------------- /examples/anomalib_adversary/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "anomalib_adversary" 3 | version = "0.1.0a" 4 | description = "Evaluating robustness of anomaly detection models in Anomalib." 5 | authors = [ 6 | {name = "Intel Corporation"} 7 | ] 8 | 9 | dependencies = [ 10 | # Lock to an earlier version of anomalib because they later decorate anomalib.models.image.winclip.torch_model.WinClipModel.forward() with torch.no_grad().
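# (Adversarial optimization needs gradients with respect to the model input, which a no_grad() forward would block.)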
11 | "anomalib[full] @ git+https://github.com/openvinotoolkit/anomalib.git@241c14787bded6cd3cb5241b74673477601479ce", 12 | "mart @ https://github.com/IntelLabs/MART/archive/refs/tags/v0.6.1.zip", 13 | "torch-rotation==0.1.2", 14 | ] 15 | 16 | [tool.setuptools.packages.find] 17 | include = ["anomalib_adversary*"] 18 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/Makefile: -------------------------------------------------------------------------------- 1 | # Download and extract dataset of carla_over_obj_det 2 | CARLA_OVERHEAD_DATASET_TRAIN ?= data/carla_over_obj_det/train/kwcoco_annotations.json 3 | CARLA_OVERHEAD_DATASET_DEV ?= data/carla_over_obj_det/dev/kwcoco_annotations.json 4 | 5 | data/carla_over_obj_det/carla_over_od_dev_2.0.0.tar.gz: 6 | mkdir -p $(@D) 7 | wget -O $@ https://armory-public-data.s3.us-east-2.amazonaws.com/carla/carla_over_od_dev_2.0.0.tar.gz 8 | 9 | $(CARLA_OVERHEAD_DATASET_DEV): data/carla_over_obj_det/carla_over_od_dev_2.0.0.tar.gz 10 | tar -zxf $< -C data/carla_over_obj_det 11 | 12 | data/carla_over_obj_det/carla_over_od_train_val_1.0.0.tar.gz: 13 | mkdir -p $(@D) 14 | wget -O $@ https://armory-public-data.s3.us-east-2.amazonaws.com/carla/carla_over_od_train_val_1.0.0.tar.gz 15 | 16 | $(CARLA_OVERHEAD_DATASET_TRAIN): data/carla_over_obj_det/carla_over_od_train_val_1.0.0.tar.gz 17 | tar -zxf $< -C data/carla_over_obj_det 18 | 19 | 20 | .PHONY: carla_train 21 | carla_train: $(CARLA_OVERHEAD_DATASET_TRAIN) $(CARLA_OVERHEAD_DATASET_DEV) ## Train Faster R-CNN with the CarlaOverObjDet dataset from Armory. 22 | python -m mart \ 23 | experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 24 | trainer=gpu \ 25 | trainer.precision=16 \ 26 | fit=true \ 27 | tags=["regular_training","backbone_ImageNetPretrained"] \ 28 | 29 | 30 | # You need to specify weights of target model in [model.modules.losses_and_detections.model.weights_fpath]. 31 | .PHONY: carla_attack 32 | carla_attack: $(CARLA_OVERHEAD_DATASET_TRAIN) $(CARLA_OVERHEAD_DATASET_DEV) ## Evaluate adversarial robustness of a pretrained model. 33 | python -m mart \ 34 | experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 35 | trainer=gpu \ 36 | fit=false \ 37 | model.modules.losses_and_detections.model.weights_fpath=null \ 38 | +attack@model.modules.input_adv_test=object_detection_mask_adversary \ 39 | model.modules.input_adv_test.optimizer.lr=5 \ 40 | model.modules.input_adv_test.max_iters=50 \ 41 | +model.test_sequence.seq001.input_adv_test._call_with_args_=[input,target] \ 42 | +model.test_sequence.seq001.input_adv_test.model=model \ 43 | +model.test_sequence.seq001.input_adv_test.step=step \ 44 | model.test_sequence.seq010.preprocessor=[input_adv_test] \ 45 | # tags=["MaskPGD50_LR5"] 46 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This example shows how to use MART to train an object detection model on the Carla overhead dataset. 
4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ```bash 12 | # train on 1 GPU 13 | CUDA_VISIBLE_DEVICES=0 \ 14 | python -m mart experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 15 | task_name=1GPU_ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 16 | trainer=gpu \ 17 | fit=true 18 | 19 | # train on multiple GPUs using Distributed Data Parallel 20 | CUDA_VISIBLE_DEVICES=0,1 \ 21 | python -m mart experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 22 | task_name=2GPUs_ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \ 23 | fit=true \ 24 | trainer=ddp \ 25 | trainer.devices=2 \ 26 | datamodule.ims_per_batch=4 \ 27 | model.optimizer.lr=0.025 \ 28 | trainer.max_steps=5244 29 | ``` 30 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - coco 3 | 4 | train_dataset: 5 | root: ${paths.data_dir}/carla_over_obj_det/train 6 | modalities: ["rgb"] 7 | annFile: ${paths.data_dir}/carla_over_obj_det/train/kwcoco_annotations.json 8 | 9 | val_dataset: 10 | root: ${paths.data_dir}/carla_over_obj_det/val 11 | modalities: ["rgb"] 12 | annFile: ${paths.data_dir}/carla_over_obj_det/val/kwcoco_annotations.json 13 | 14 | test_dataset: 15 | root: ${paths.data_dir}/carla_over_obj_det/dev 16 | modalities: ["rgb"] 17 | annFile: ${paths.data_dir}/carla_over_obj_det/dev/kwcoco_annotations.json 18 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet_perturbable_mask.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - armory_carla_over_objdet 3 | 4 | train_dataset: 5 | transforms: 6 | transforms: 7 | - _target_: torchvision.transforms.ToTensor 8 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 9 | - _target_: mart.transforms.ConvertCocoPolysToMask 10 | - _target_: mart.transforms.RandomHorizontalFlip 11 | p: 0.5 12 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 13 | - _target_: mart.transforms.Denormalize 14 | center: 0 15 | scale: 255 16 | - _target_: torch.fake_quantize_per_tensor_affine 17 | _partial_: true 18 | # (x/1+0).round().clamp(0, 255) * 1 19 | scale: 1 20 | zero_point: 0 21 | quant_min: 0 22 | quant_max: 255 23 | 24 | val_dataset: 25 | transforms: 26 | transforms: 27 | - _target_: torchvision.transforms.ToTensor 28 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 29 | - _target_: mart.transforms.ConvertCocoPolysToMask 30 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 31 | - _target_: mart.transforms.Denormalize 32 | center: 0 33 | scale: 255 34 | - _target_: torch.fake_quantize_per_tensor_affine 35 | _partial_: true 36 | # (x/1+0).round().clamp(0, 255) * 1 37 | scale: 1 38 | zero_point: 0 39 | quant_min: 0 40 | quant_max: 255 41 | 42 | test_dataset: 43 | transforms: 44 | transforms: 45 | - _target_: torchvision.transforms.ToTensor 46 | - _target_: mart.transforms.ConvertCocoPolysToMask 47 | # Add masks of perturbable regions. 
48 | - _target_: mart.transforms.LoadPerturbableMask 49 | perturb_mask_folder: ${paths.data_dir}/carla_over_obj_det/dev/foreground_mask/ 50 | - _target_: mart.transforms.Denormalize 51 | center: 0 52 | scale: 255 53 | - _target_: torch.fake_quantize_per_tensor_affine 54 | _partial_: true 55 | # (x/1+0).round().clamp(0, 255) * 1 56 | scale: 1 57 | zero_point: 0 58 | quant_min: 0 59 | quant_max: 255 60 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - COCO_TorchvisionFasterRCNN 5 | - override /datamodule: armory_carla_over_objdet_perturbable_mask 6 | 7 | task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN" 8 | tags: ["regular_training"] 9 | 10 | optimized_metric: "test_metrics/map" 11 | 12 | model: 13 | modules: 14 | losses_and_detections: 15 | model: 16 | num_classes: 3 17 | weights: null 18 | 19 | optimizer: 20 | lr: 0.0125 21 | momentum: 0.9 22 | weight_decay: 1e-4 23 | 24 | trainer: 25 | # 3,600 training images, batch_size=2, 6 epochs 26 | # max_steps = 3600 / 2 * 6 27 | max_steps: 10800 28 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/requirements.txt: -------------------------------------------------------------------------------- 1 | mart[full] @ git+https://github.com/IntelLabs/MART.git 2 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import os 8 | from pathlib import Path 9 | 10 | import pyrootutils 11 | import pytest 12 | from hydra import compose, initialize 13 | from hydra.core.global_hydra import GlobalHydra 14 | from omegaconf import DictConfig 15 | 16 | root = Path(os.getcwd()) 17 | pyrootutils.set_root(path=root, dotenv=True, pythonpath=True) 18 | 19 | experiments_names = [ 20 | "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN", 21 | ] 22 | 23 | 24 | # Loads the configuration file from a given experiment 25 | def get_cfg(experiment): 26 | with initialize(version_base="1.2", config_path="../configs"): 27 | params = "experiment=" + experiment 28 | cfg = compose(config_name="lightning.yaml", return_hydra_config=True, overrides=[params]) 29 | return cfg 30 | 31 | 32 | @pytest.fixture(scope="function", params=experiments_names) 33 | def cfg_experiment(request) -> DictConfig: 34 | cfg = get_cfg(request.param) 35 | 36 | yield cfg 37 | 38 | GlobalHydra.instance().clear() 39 | -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/tests/helpers: -------------------------------------------------------------------------------- 1 | ../../../tests/helpers/ -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/tests/test_configs.py: -------------------------------------------------------------------------------- 1 | ../../../tests/test_configs.py -------------------------------------------------------------------------------- /examples/carla_overhead_object_detection/tests/test_experiments.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import pytest 5 | from hydra.core.global_hydra import GlobalHydra 6 | 7 | from tests.helpers.dataset_generator import FakeCOCODataset 8 | from tests.helpers.run_if import RunIf 9 | from tests.helpers.run_sh_command import run_sh_command 10 | 11 | module = "mart" 12 | 13 | 14 | @pytest.fixture(scope="function") 15 | def carla_cfg(tmp_path) -> Dict: 16 | # Generate fake CARLA dataset on disk at tmp_path 17 | dataset = FakeCOCODataset(tmp_path, config=carla_ds, name="carla_over_obj_det") 18 | dataset.generate(num_images=2, num_annotations_per_image=2) 19 | 20 | cfg = { 21 | "trainer": [ 22 | "++trainer.fast_dev_run=3", 23 | ], 24 | "datamodel": [ 25 | "++paths.data_dir=" + str(tmp_path), 26 | "datamodule.num_workers=0", 27 | ], 28 | } 29 | yield cfg 30 | 31 | GlobalHydra.instance().clear() 32 | 33 | 34 | carla_ds = { 35 | "train": { 36 | "folder": "train", 37 | "modalities": ["rgb"], 38 | "ann_folder": "train", 39 | "ann_file": "kwcoco_annotations.json", 40 | }, 41 | "val": { 42 | "folder": "val", 43 | "modalities": ["rgb"], 44 | "ann_folder": "val", 45 | "ann_file": "kwcoco_annotations.json", 46 | }, 47 | "test": { 48 | "folder": "dev", 49 | "modalities": ["foreground_mask", "rgb"], 50 | "ann_folder": "dev", 51 | "ann_file": "kwcoco_annotations.json", 52 | }, 53 | } 54 | 55 | 56 | @RunIf(sh=True) 57 | @pytest.mark.slow 58 | def test_armory_carla_fasterrcnn_experiment(carla_cfg, tmp_path): 59 | """Test Armory CARLA TorchVision FasterRCNN experiment.""" 60 | overrides = carla_cfg["trainer"] + carla_cfg["datamodel"] 61 | command = [ 62 | "-m", 63 | module, 64 | "experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN", 65 | "+attack@model.modules.input_adv_test=object_detection_mask_adversary", 66 | "hydra.sweep.dir=" + str(tmp_path), 67 | "optimized_metric=training/loss_objectness", 68 | ] + overrides 69 | run_sh_command(command) 70 | -------------------------------------------------------------------------------- /examples/fiftyone/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This example shows how to use MART with the FiftyOne integration. FiftyOne is an open-source tool for building high-quality datasets of images and videos. With this integration, MART delegates the data handling to FiftyOne. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install "mart[fiftyone] @ git+https://github.com/IntelLabs/MART.git" 9 | ``` 10 | 11 | ## FiftyOne commands to load (index) datasets 12 | 13 | We use COCO-2017 as an example. Unfortunately, FiftyOne does not support person-keypoints annotations yet.
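The same indexing can also be scripted through FiftyOne's Python API instead of the CLI. Below is a sketch equivalent to the first zoo command in the next section (same dataset name and kwargs; `load_zoo_dataset` downloads the split if it is not already on disk):

```python
import fiftyone.zoo as foz

# Script equivalent of:
#   fiftyone zoo datasets load coco-2017 -s train -n coco-2017-instances-train \
#     -k include_id=true label_types=detections,segmentations
dataset = foz.load_zoo_dataset(
    "coco-2017",
    split="train",
    dataset_name="coco-2017-instances-train",
    label_types=["detections", "segmentations"],
    include_id=True,
)
print(dataset)
```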
14 | 15 | ## Download and load zoo datasets 16 | 17 | ```bash 18 | fiftyone zoo datasets load \ 19 | coco-2017 \ 20 | -s train \ 21 | -n coco-2017-instances-train \ 22 | -k include_id=true label_types=detections,segmentations 23 | 24 | fiftyone zoo datasets load \ 25 | coco-2017 \ 26 | -s validation \ 27 | -n coco-2017-instances-validation \ 28 | -k include_id=true label_types=detections,segmentations 29 | ``` 30 | 31 | ## Load local datasets 32 | 33 | ```bash 34 | fiftyone datasets create \ 35 | --name coco-2017-instances-validation \ 36 | --dataset-dir /path/to/datasets/coco/ \ 37 | --type fiftyone.types.COCODetectionDataset \ 38 | --kwargs \ 39 | data_path="val2017" \ 40 | labels_path=/path/to/datasets/coco/annotations/instances_val2017.json \ 41 | persistent=true \ 42 | include_id=true 43 | ``` 44 | 45 | ## Use the FiftyOne datamodule 46 | 47 | ```yaml 48 | datamodule: 49 | train_dataset: 50 | dataset_name: coco-2017-instances-train 51 | gt_field: segmentations 52 | val_dataset: 53 | dataset_name: coco-2017-instances-validation 54 | gt_field: segmentations 55 | ``` 56 | -------------------------------------------------------------------------------- /examples/robust_bench/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This example shows how to use MART to evaluate the adversarial robustness of models from RobustBench. 4 | 5 | Note that the attack algorithm here is not optimal; it is only for demonstration purposes. 6 | 7 | The `requirements.txt` lists the MART and RobustBench dependencies. 8 | 9 | The `./configs` folder contains configurations of the target model `classifier_robustbench` and the MART experiment `CIFAR10_RobustBench`. 10 | 11 | The configuration files in `./configs` precede those in `mart.configs` (MART's built-in configs).
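To check which copy of a config actually wins, you can compose the config tree programmatically. Below is a sketch that uses the same Hydra compose API as the test fixtures in this repository; it assumes you run it from this example's folder, next to `./configs`:

```python
from hydra import compose, initialize

# Compose the experiment config the same way `python -m mart` would;
# entries in ./configs shadow same-named entries in mart.configs.
with initialize(version_base="1.2", config_path="configs"):
    cfg = compose(
        config_name="lightning.yaml",
        return_hydra_config=True,
        overrides=["experiment=CIFAR10_RobustBench"],
    )
    # The local classifier_robustbench.yaml should win:
    print(cfg.model.modules.logits["_target_"])  # robustbench.utils.load_model
```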
12 | 13 | ## Installation 14 | 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## How to run 20 | 21 | ```bash 22 | # run on CPU 23 | python -m mart experiment=CIFAR10_RobustBench \ 24 | trainer=default \ 25 | fit=false \ 26 | +trainer.limit_test_batches=1 \ 27 | +attack@model.modules.input_adv_test=classification_eps8_pgd10_step1 28 | 29 | # run on GPU 30 | CUDA_VISIBLE_DEVICES=0 \ 31 | python -m mart experiment=CIFAR10_RobustBench \ 32 | trainer=gpu \ 33 | fit=false \ 34 | +trainer.limit_test_batches=1 \ 35 | +attack@model.modules.input_adv_test=classification_eps8_pgd10_step1 \ 36 | +model.test_sequence.seq005=input_adv_test \ 37 | model.test_sequence.seq010.preprocessor=["input_adv_test"] 38 | 39 | # Evaluate with AutoAttack, expect 0.6171875 40 | CUDA_VISIBLE_DEVICES=0 \ 41 | python -m mart experiment=CIFAR10_RobustBench \ 42 | trainer=gpu \ 43 | fit=false \ 44 | +trainer.limit_test_batches=1 \ 45 | +attack@model.modules.input_adv_test=classification_autoattack \ 46 | +model.test_sequence.seq005=input_adv_test \ 47 | model.test_sequence.seq010.preprocessor=["input_adv_test"] 48 | ``` 49 | -------------------------------------------------------------------------------- /examples/robust_bench/configs/attack/classification_autoattack.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - enforcer: default 3 | - enforcer/constraints: [lp, pixel_range] 4 | 5 | _target_: mart.attack.NormalizedAdversaryAdapter 6 | adversary: 7 | _target_: mart.utils.adapters.PartialInstanceWrapper 8 | partial: 9 | _target_: autoattack.AutoAttack 10 | _partial_: true 11 | # AutoAttack needs to specify device for PyTorch tensors: cpu/cuda 12 | # We can not use ${trainer.accelerator} because the vocabulary is different: cpu/gpu 13 | # device: cpu 14 | norm: Linf 15 | # 8/255 16 | eps: 0.03137254901960784 17 | version: custom 18 | attacks_to_run: 19 | - apgd-dlr 20 | wrapper: 21 | _target_: mart.utils.adapters.CallableAdapter 22 | _partial_: true 23 | redirecting_fn: run_standard_evaluation 24 | enforcer: 25 | constraints: 26 | lp: 27 | eps: 8 28 | -------------------------------------------------------------------------------- /examples/robust_bench/configs/experiment/CIFAR10_RobustBench.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - CIFAR10_CNN 5 | - override /model: classifier_robustbench 6 | 7 | task_name: "CIFAR10_RobustBench" 8 | tags: ["adv_trained"] 9 | -------------------------------------------------------------------------------- /examples/robust_bench/configs/model/classifier_robustbench.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - classifier 3 | 4 | modules: 5 | preprocessor: 6 | # Convert [0, 255] input to [0, 1] 7 | _target_: torchvision.transforms.Normalize 8 | mean: 0 9 | std: 255 10 | 11 | logits: 12 | _target_: robustbench.utils.load_model 13 | model_name: Gowal2021Improving_R18_ddpm_100m 14 | model_dir: ${paths.data_dir} 15 | dataset: cifar10 16 | threat_model: Linf 17 | -------------------------------------------------------------------------------- /examples/robust_bench/requirements.txt: -------------------------------------------------------------------------------- 1 | mart[full] @ git+https://github.com/IntelLabs/MART.git 2 | robustbench @ git+https://github.com/RobustBench/robustbench.git@9a590683b7daecf963244dea402529f0d728c727 3 | 
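For orientation, the adapter chain in `configs/attack/classification_autoattack.yaml` above amounts to roughly the following plain Python. This is a simplified sketch, not the exact MART call path; the real glue is `NormalizedAdversaryAdapter` in `mart/attack/adversary_wrapper.py`, shown later in this repository:

```python
from functools import partial

from autoattack import AutoAttack

# PartialInstanceWrapper: keep an AutoAttack "recipe" that still awaits the model.
make_autoattack = partial(
    AutoAttack,
    norm="Linf",
    eps=8 / 255,
    version="custom",
    attacks_to_run=["apgd-dlr"],
)


def adversary(model):
    attack = make_autoattack(model)
    # CallableAdapter(redirecting_fn="run_standard_evaluation") makes the instance
    # callable; run_standard_evaluation(x, y) returns adversarial inputs in [0, 1].
    return lambda x, y: attack.run_standard_evaluation(x, y)
```

The adapter rescales inputs to [0, 1] before calling the attack and back to [0, 255] afterwards, which is why the enforcer's `lp` constraint uses `eps: 8` in the 0-255 pixel domain.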
-------------------------------------------------------------------------------- /examples/robust_bench/tests/helpers: -------------------------------------------------------------------------------- 1 | ../../../tests/helpers/ -------------------------------------------------------------------------------- /examples/robust_bench/tests/test_experiments.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import pytest 4 | from hydra.core.global_hydra import GlobalHydra 5 | 6 | from tests.helpers.run_if import RunIf 7 | from tests.helpers.run_sh_command import run_sh_command 8 | 9 | module = "mart" 10 | 11 | 12 | # common configuration for classification related tests. 13 | @pytest.fixture(scope="function") 14 | def classification_cfg() -> Dict: 15 | cfg = { 16 | "trainer": [ 17 | "++trainer.fast_dev_run=3", 18 | ], 19 | "datamodel": [ 20 | "datamodule=dummy_classification", 21 | "datamodule.ims_per_batch=2", 22 | "datamodule.num_workers=0", 23 | ], 24 | } 25 | yield cfg 26 | 27 | GlobalHydra.instance().clear() 28 | 29 | 30 | @RunIf(sh=True) 31 | def test_cifar10_cnn_autoattack_experiment(classification_cfg, tmp_path): 32 | """Test CIFAR10 CNN AutoAttack experiment.""" 33 | overrides = classification_cfg["datamodel"] 34 | command = [ 35 | "-m", 36 | module, 37 | "experiment=CIFAR10_CNN", 38 | "hydra.sweep.dir=" + str(tmp_path), 39 | "++datamodule.train_dataset.image_size=[3,32,32]", 40 | "++datamodule.train_dataset.num_classes=10", 41 | "fit=false", 42 | "+attack@model.modules.input_adv_test=classification_autoattack", 43 | '+model.modules.input_adv_test.adversary.partial.device="cpu"', 44 | "+trainer.limit_test_batches=1", 45 | ] + overrides 46 | run_sh_command(command) 47 | -------------------------------------------------------------------------------- /hydra_plugins/hydra_mart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/hydra_plugins/hydra_mart/__init__.py -------------------------------------------------------------------------------- /hydra_plugins/hydra_mart/mart.py: -------------------------------------------------------------------------------- 1 | from hydra.core.config_search_path import ConfigSearchPath 2 | from hydra.plugins.search_path_plugin import SearchPathPlugin 3 | from omegaconf import OmegaConf 4 | 5 | 6 | class HydraMartSearchPathPlugin(SearchPathPlugin): 7 | def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: 8 | # Add mart.configs to search path 9 | search_path.append("hydra-mart", "pkg://mart.configs") 10 | 11 | 12 | OmegaConf.register_new_resolver("negate", lambda x: -x) 13 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/logs/.gitkeep -------------------------------------------------------------------------------- /mart/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from mart import attack as attack 4 | from mart import nn as nn 5 | from mart import optim as optim 6 | from mart import transforms as transforms 7 | from mart import utils as utils 8 | from mart.utils.imports import _HAS_LIGHTNING 9 | 10 | if _HAS_LIGHTNING: 11 | from mart 
import datamodules as datamodules 12 | from mart import models as models 13 | 14 | __version__ = importlib.metadata.version(__package__ or __name__) 15 | -------------------------------------------------------------------------------- /mart/__main__.py: -------------------------------------------------------------------------------- 1 | # this file acts as a robust starting point for launching hydra runs and multiruns 2 | # can be run from any place 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import hydra 9 | import pyrootutils 10 | from omegaconf import DictConfig 11 | 12 | from mart import utils 13 | 14 | log = utils.get_pylogger(__name__) 15 | 16 | # project root setup 17 | # uses the current working directory as root. 18 | # sets PROJECT_ROOT environment variable (used in `configs/paths/default.yaml`) 19 | # loads environment variables from ".env" if exists 20 | # adds root dir to the PYTHONPATH (so this file can be run from any place) 21 | # https://github.com/ashleve/pyrootutils 22 | # FIXME: Get rid of pyrootutils if we don't infer config.paths.root from PROJECT_ROOT. 23 | root = Path(os.getcwd()) 24 | pyrootutils.set_root(path=root, dotenv=True, pythonpath=True) 25 | 26 | config_path = root / "configs" 27 | if not config_path.exists(): 28 | log.warning(f"No config directory found at {config_path}!") 29 | config_path = "configs" 30 | 31 | 32 | @hydra.main(version_base="1.2", config_path=str(config_path), config_name="lightning.yaml") 33 | def main(cfg: DictConfig) -> float: 34 | 35 | if cfg.resume is None and ("datamodule" not in cfg or "model" not in cfg): 36 | log.fatal("") 37 | log.fatal("Please specify an experiment to run, e.g.") 38 | log.fatal( 39 | "$ python -m mart experiment=CIFAR10_CNN fit=false +trainer.limit_test_batches=1" 40 | ) 41 | log.fatal("or specify a checkpoint to resume, e.g.") 42 | log.fatal("$ python -m mart resume=logs/my_task_name/checkpoints/last.ckpt") 43 | log.fatal("") 44 | return -1 45 | 46 | # imports can be nested inside @hydra.main to optimize tab completion 47 | # https://github.com/facebookresearch/hydra/issues/934 48 | from mart.tasks.lightning import lightning 49 | from mart.utils import get_metric_value, get_resume_checkpoint 50 | 51 | # Resume and modify configs at the earliest point. 
52 | # The actual checkpoint path is in cfg.ckpt_path 53 | cfg = get_resume_checkpoint(cfg) 54 | 55 | # train the model 56 | metric_dict, _ = lightning(cfg) 57 | 58 | # safely retrieve metric value for hydra-based hyperparameter optimization 59 | metric_value = get_metric_value( 60 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 61 | ) 62 | 63 | # return optimized metric 64 | return metric_value 65 | 66 | 67 | if __name__ == "__main__": 68 | ret = main() 69 | if ret is not None and ret < 0: 70 | sys.exit(ret) 71 | else: 72 | sys.exit(0) 73 | -------------------------------------------------------------------------------- /mart/attack/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils.imports import _HAS_LIGHTNING 2 | from .adversary_wrapper import * 3 | from .composer import * 4 | from .enforcer import * 5 | from .gain import * 6 | from .gradient_modifier import * 7 | from .initializer import * 8 | from .objective import * 9 | from .perturber import * 10 | from .projector import * 11 | 12 | if _HAS_LIGHTNING: 13 | from .adversary import * 14 | -------------------------------------------------------------------------------- /mart/attack/adversary_wrapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | from typing import TYPE_CHECKING, Any, Callable, Iterable 10 | 11 | import torch 12 | 13 | if TYPE_CHECKING: 14 | from .enforcer import Enforcer 15 | 16 | __all__ = ["NormalizedAdversaryAdapter"] 17 | 18 | 19 | class NormalizedAdversaryAdapter(torch.nn.Module): 20 | """A wrapper for running external classification adversaries in MART. 21 | 22 | External adversaries commonly take input of NCHW-[0,1] and return input_adv in the same format. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | adversary: Callable[[Callable], Callable], 28 | enforcer: Enforcer, 29 | ): 30 | """ 31 | 32 | Args: 33 | adversary (functools.partial): A partial of an adversary object that awaits the model. 34 | enforcer (Enforcer): Enforces constraints on the adversarial input. 35 | """ 36 | super().__init__() 37 | 38 | self.adversary = adversary 39 | self.enforcer = enforcer 40 | 41 | def forward( 42 | self, 43 | input: torch.Tensor | Iterable[torch.Tensor], 44 | target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], 45 | model: torch.nn.Module | None = None, 46 | **kwargs, 47 | ): 48 | 49 | # Shortcut. Input is already updated in the attack loop. 50 | if model is None: 51 | return input 52 | 53 | # Input NCHW [0,1]; Output logits. 54 | def model_wrapper(x): 55 | output = model(input=x * 255, target=target, model=None, **kwargs) 56 | logits = output["logits"] 57 | return logits 58 | 59 | attack = self.adversary(model_wrapper) 60 | input_adv = attack(input / 255, target) 61 | 62 | # Round to integer, in case of imprecise scaling.
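# (Editorial note: the wrapped attack operates on NCHW-[0,1] tensors, while MART pipelines
# use the 0-255 pixel scale, hence the /255 before and *255 after the attack call.)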
63 | input_adv = (input_adv * 255).round() 64 | self.enforcer(input_adv, input=input, target=target) 65 | 66 | return input_adv 67 | -------------------------------------------------------------------------------- /mart/attack/composer/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils.imports import _HAS_TORCHVISION 2 | from .modular import * 3 | 4 | if _HAS_TORCHVISION: 5 | from .patch import * 6 | from .visualizer import * 7 | -------------------------------------------------------------------------------- /mart/attack/composer/visualizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import torch 8 | from torchvision.transforms.functional import to_pil_image 9 | 10 | 11 | class ComposerImageVisualizer: 12 | def __call__(self, output): 13 | for key, value in output.items(): 14 | if isinstance(value, torch.Tensor): 15 | to_pil_image(value / 255).save(f"{key}.png") 16 | -------------------------------------------------------------------------------- /mart/attack/gain.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | 11 | __all__ = ["Gain"] 12 | 13 | 14 | class Gain(torch.nn.Module): 15 | """Gain functions must be differentiable so we inherit from nn.Module.""" 16 | 17 | pass 18 | 19 | 20 | class RoIHeadTargetClass(Gain): 21 | """This gain function encourages logits to be classified as a particular class, e.g. background 22 | (class_index==0 in RCNN).""" 23 | 24 | def __init__(self, class_index: Optional[int] = 0, targeted: Optional[bool] = True) -> None: 25 | super().__init__() 26 | 27 | self.gain = torch.nn.CrossEntropyLoss() 28 | self.class_index = class_index 29 | self.targeted = targeted 30 | 31 | def forward( 32 | self, roi_heads_class_logits: torch.Tensor, proposals: torch.Tensor 33 | ) -> torch.Tensor: 34 | """ 35 | 36 | Args: 37 | roi_heads_class_logits (torch.Tensor): Class logits from roi_heads. 38 | proposals (Iterable[torch.Tensor]): Only used to determine how many proposals there are for each input. 39 | 40 | Returns: 41 | torch.Tensor: A gain vector with a separate gain value for each input. 42 | """ 43 | target = [self.class_index] * len(roi_heads_class_logits) 44 | device = roi_heads_class_logits.device 45 | target = torch.tensor(target, device=device) 46 | 47 | # Split class logits by input. 48 | boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals] 49 | roi_heads_class_logits_list = roi_heads_class_logits.split(boxes_per_image, 0) 50 | target_list = target.split(boxes_per_image, 0) 51 | 52 | gains = [] 53 | for batch_logits, batch_target in zip(roi_heads_class_logits_list, target_list): 54 | gain = self.gain(batch_logits, batch_target) 55 | if self.targeted: 56 | gain = -gain 57 | gains.append(gain) 58 | 59 | gains = torch.stack(gains) 60 | 61 | return gains 62 | 63 | 64 | class RegionProposalScore(Gain): 65 | """A gain function that encourages background or foreground in region proposals. 66 | 67 | rpn_objectness is the input to the sigmoid; the lower the value, the more likely the proposal is background.
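Illustrative sketch (editorial; values hypothetical): for rpn_objectness of
[tensor([2.0, -1.0])], sigmoid yields probabilities of roughly [0.88, 0.27]; with
background=True the gain is [-0.88, -0.27], so maximizing the gain pushes
objectness scores down, i.e. toward background.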
68 | """ 69 | 70 | def __init__(self, background: Optional[bool] = True) -> None: 71 | """""" 72 | super().__init__() 73 | 74 | self.background = background 75 | 76 | def forward(self, rpn_objectness: torch.Tensor) -> torch.Tensor: 77 | logits = torch.cat([logits.reshape(-1) for logits in rpn_objectness]) 78 | # TODO: We may remove sigmoid. 79 | probs = torch.sigmoid(logits) 80 | # prob_mean = probs.mean() 81 | if self.background: 82 | # Encourage background. 83 | return -probs 84 | else: 85 | # Encourage foreground. 86 | return probs 87 | -------------------------------------------------------------------------------- /mart/attack/gradient_modifier.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | 13 | __all__ = ["GradientModifier"] 14 | 15 | 16 | class GradientModifier: 17 | """Gradient modifier base class.""" 18 | 19 | def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: 20 | if isinstance(parameters, torch.Tensor): 21 | parameters = [parameters] 22 | 23 | [self.modify_(parameter) for parameter in parameters] 24 | 25 | @torch.no_grad() 26 | def modify_(self, parameter: torch.Tensor) -> None: 27 | pass 28 | 29 | 30 | class Sign(GradientModifier): 31 | @torch.no_grad() 32 | def modify_(self, parameter: torch.Tensor) -> None: 33 | parameter.grad.sign_() 34 | 35 | 36 | class LpNormalizer(GradientModifier): 37 | """Scale gradients by a certain L-p norm.""" 38 | 39 | def __init__(self, p: int | float): 40 | self.p = float(p) 41 | 42 | @torch.no_grad() 43 | def modify_(self, parameter: torch.Tensor) -> None: 44 | p_norm = torch.norm(parameter.grad.detach(), p=self.p) 45 | parameter.grad.detach().div_(p_norm) 46 | -------------------------------------------------------------------------------- /mart/attack/initializer/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils.imports import _HAS_TORCHVISION 2 | from .base import * 3 | 4 | if _HAS_TORCHVISION: 5 | from .vision import * 6 | -------------------------------------------------------------------------------- /mart/attack/initializer/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | 13 | 14 | class Initializer: 15 | """Initializer base class.""" 16 | 17 | @torch.no_grad() 18 | def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: 19 | if isinstance(parameters, torch.Tensor): 20 | parameters = [parameters] 21 | 22 | [self.initialize_(parameter) for parameter in parameters] 23 | 24 | @torch.no_grad() 25 | def initialize_(self, parameter: torch.Tensor) -> None: 26 | pass 27 | 28 | 29 | class Constant(Initializer): 30 | def __init__(self, constant: int | float = 0): 31 | self.constant = constant 32 | 33 | @torch.no_grad() 34 | def initialize_(self, parameter: torch.Tensor) -> None: 35 | torch.nn.init.constant_(parameter, self.constant) 36 | 37 | 38 | class Uniform(Initializer): 39 | def __init__(self, min: int | float, max: int | float, round: bool = False): 40 | self.min = min 41 | self.max = max 42 | self.round = round 43 | 44 | @torch.no_grad() 
45 | def initialize_(self, parameter: torch.Tensor) -> None: 46 | torch.nn.init.uniform_(parameter, self.min, self.max) 47 | if self.round: 48 | parameter.round_() 49 | 50 | 51 | class UniformLp(Initializer): 52 | def __init__(self, eps: int | float, p: int | float = torch.inf): 53 | self.eps = eps 54 | self.p = p 55 | 56 | @torch.no_grad() 57 | def initialize_(self, parameter: torch.Tensor) -> None: 58 | torch.nn.init.uniform_(parameter, -self.eps, self.eps) 59 | # TODO: make sure the first dim is the batch dim. 60 | if self.p is not torch.inf: 61 | # We don't do tensor.renorm_() because the first dim is not the batch dim. 62 | pert_norm = parameter.norm(p=self.p) 63 | parameter.mul_(self.eps / pert_norm) 64 | -------------------------------------------------------------------------------- /mart/attack/initializer/vision.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import logging 8 | 9 | import torch 10 | import torchvision 11 | import torchvision.transforms.functional as F 12 | 13 | from .base import Initializer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Image(Initializer): 19 | def __init__(self, path: str, scale: int = 1): 20 | self.image = torchvision.io.read_image(path, torchvision.io.ImageReadMode.RGB) / scale 21 | 22 | @torch.no_grad() 23 | def initialize_(self, parameter: torch.Tensor) -> None: 24 | image = self.image 25 | 26 | if image.shape != parameter.shape: 27 | logger.info(f"Resizing image from {image.shape} to {parameter.shape}...") 28 | image = F.resize(image, parameter.shape[1:]) 29 | 30 | parameter.copy_(image) 31 | -------------------------------------------------------------------------------- /mart/attack/objective/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils.imports import _HAS_TORCHVISION 2 | from .base import * 3 | from .classification import * 4 | 5 | if _HAS_TORCHVISION: 6 | from .object_detection import * 7 | -------------------------------------------------------------------------------- /mart/attack/objective/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import abc 8 | 9 | import torch 10 | 11 | __all__ = ["Objective"] 12 | 13 | 14 | class Objective(abc.ABC): 15 | """Objectives do not need to be differentiable so we do not inherit from nn.Module.""" 16 | 17 | @abc.abstractmethod 18 | def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 19 | raise NotImplementedError 20 | -------------------------------------------------------------------------------- /mart/attack/objective/classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from typing import Any, Callable, Dict, Union 8 | 9 | import torch 10 | 11 | from .base import Objective 12 | 13 | __all__ = ["Mispredict", "RandomTarget"] 14 | 15 | 16 | class Mispredict(Objective): 17 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 18 | # FIXME: I don't like this argmax call. It feels like this should receive input tensor of 19 | # the same shape as target? 
20 | mispredictions = input.argmax(dim=-1) != target 21 | return mispredictions 22 | 23 | 24 | class RandomTarget(Objective): 25 | def __init__( 26 | self, nb_classes: int, gain_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] 27 | ) -> None: 28 | self.nb_classes = nb_classes 29 | self.gain_fn = gain_fn 30 | 31 | def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 32 | # FIXME: It may be better if we make sure that the pseudo target is different from target. 33 | pseudo_target = torch.randint_like(target, low=0, high=self.nb_classes) 34 | return self.gain_fn(logits, pseudo_target) 35 | -------------------------------------------------------------------------------- /mart/attack/objective/object_detection.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import abc 8 | from typing import List, Optional, Tuple, Union 9 | 10 | import torch 11 | from torchvision.ops import box_iou 12 | 13 | from .base import Objective 14 | 15 | __all__ = ["ZeroAP", "Missed"] 16 | 17 | 18 | class ZeroAP(Objective): 19 | """Determine if predictions yield zero Average Precision.""" 20 | 21 | def __init__( 22 | self, 23 | iou_threshold: Optional[float] = 0.5, 24 | confidence_threshold: Optional[float] = 0.5, 25 | ) -> None: 26 | super().__init__() 27 | 28 | self.iou_threshold = iou_threshold 29 | self.confidence_threshold = confidence_threshold 30 | 31 | def __call__( 32 | self, preds: Union[torch.Tensor, List], target: Union[torch.Tensor, List, Tuple] 33 | ) -> torch.Tensor: 34 | # For each class in target, 35 | # if any confident pred_box has an IoU of at least iou_threshold with a gt_box, 36 | # the objective is not achieved (False). 37 | # Otherwise, it is achieved (True). 38 | achieved_list = [] 39 | 40 | for pred, gt in zip(preds, target): 41 | achieved = True 42 | for gt_cls in set(gt["labels"].cpu().numpy()): 43 | # Ground truth boxes with the same class. 44 | gt_boxes = gt["boxes"][gt["labels"] == gt_cls] 45 | # Predictions of the same class with high enough confidence. 46 | pred_boxes_idx = torch.logical_and( 47 | pred["labels"] == gt_cls, 48 | pred["scores"] >= self.confidence_threshold, 49 | ) 50 | pred_boxes = pred["boxes"][pred_boxes_idx] 51 | 52 | iou_pairs = box_iou(gt_boxes, pred_boxes) 53 | if iou_pairs.numel() > 0 and iou_pairs.max().item() >= self.iou_threshold: 54 | achieved = False 55 | break 56 | 57 | achieved_list.append(achieved) 58 | 59 | device = target[0]["boxes"].device 60 | achieved_tensor = torch.tensor(achieved_list, device=device) 61 | 62 | return achieved_tensor 63 | 64 | 65 | class Missed(Objective): 66 | """The objective of the adversary is to turn every AP error into a missed detection, i.e.
no object 67 | is detected and no false positive is produced.""" 68 | 69 | def __init__(self, confidence_threshold: Optional[float] = 0.5) -> None: 70 | super().__init__() 71 | 72 | self.confidence_threshold = confidence_threshold 73 | 74 | def __call__( 75 | self, preds: Union[torch.Tensor, List], target: Union[torch.Tensor, List, Tuple] 76 | ) -> torch.Tensor: 77 | achieved_list = [] 78 | 79 | for pred in preds: 80 | if (pred["scores"] >= self.confidence_threshold).sum().item() > 0: 81 | achieved_list.append(False) 82 | else: 83 | achieved_list.append(True) 84 | 85 | device = preds[0]["boxes"].device 86 | achieved_tensor = torch.tensor(achieved_list, device=device) 87 | 88 | return achieved_tensor 89 | -------------------------------------------------------------------------------- /mart/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # All callbacks here depend on Lightning, so we don't import mart.callbacks by default. 2 | from ..utils.imports import _HAS_TORCHVISION 3 | from .adversary_connector import * 4 | from .eval_mode import * 5 | from .gradients import * 6 | from .logging import * 7 | from .metrics import * 8 | from .no_grad_mode import * 9 | from .progress_bar import * 10 | 11 | if _HAS_TORCHVISION: 12 | from .visualizer import * 13 | -------------------------------------------------------------------------------- /mart/callbacks/eval_mode.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from lightning.pytorch.callbacks import Callback 8 | 9 | __all__ = ["AttackInEvalMode"] 10 | 11 | 12 | class AttackInEvalMode(Callback): 13 | """Switch the model into eval mode during attack.""" 14 | 15 | def __init__(self): 16 | self.training_mode_status = None 17 | 18 | def on_train_start(self, trainer, model): 19 | self.training_mode_status = model.training 20 | model.train(False) 21 | 22 | def on_train_end(self, trainer, model): 23 | assert self.training_mode_status is not None 24 | 25 | # Restore the previous training mode of the model.
26 | model.train(self.training_mode_status) 27 | -------------------------------------------------------------------------------- /mart/callbacks/gradients.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | from collections.abc import Iterable 10 | 11 | from lightning.pytorch.callbacks import Callback 12 | from lightning.pytorch.utilities import grad_norm 13 | 14 | __all__ = ["GradientMonitor"] 15 | 16 | 17 | class GradientMonitor(Callback): 18 | def __init__( 19 | self, 20 | norm_types: float | int | str | Iterable[float | int | str], 21 | frequency: int = 100, 22 | histogram: bool = True, 23 | clipped: bool = True, 24 | ): 25 | if isinstance(norm_types, str) or not isinstance(norm_types, Iterable): 26 | norm_types = [norm_types] 27 | 28 | self.norm_types = norm_types 29 | self.frequency = frequency 30 | self.histogram = histogram 31 | self.clipped = clipped 32 | 33 | self.should_log = False 34 | 35 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, unused=0): 36 | self.should_log = batch_idx % self.frequency == 0 37 | 38 | def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx): 39 | if not self.should_log: 40 | return 41 | 42 | # Pre-clipping 43 | self.log_grad_norm(trainer, pl_module, self.norm_types) 44 | 45 | if self.histogram: 46 | self.log_grad_histogram(trainer, pl_module) 47 | 48 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0): 49 | if not self.clipped: 50 | return 51 | 52 | if not self.should_log: 53 | return 54 | 55 | # Post-clipping 56 | postfix = ".clipped_grad" 57 | 58 | self.log_grad_norm(trainer, pl_module, self.norm_types, postfix=postfix) 59 | 60 | if self.histogram: 61 | self.log_grad_histogram(trainer, pl_module, postfix=postfix) 62 | 63 | def log_grad_norm(self, trainer, pl_module, norm_types, prefix="gradients/", postfix=""): 64 | for norm_type in norm_types: 65 | norms = grad_norm(pl_module, norm_type) 66 | norms = {f"{prefix}{key}{postfix}": value for key, value in norms.items()} 67 | 68 | pl_module.log_dict(norms) 69 | 70 | def log_grad_histogram(self, trainer, pl_module, prefix="gradients/", postfix=".grad"): 71 | for name, param in pl_module.named_parameters(): 72 | if not param.requires_grad: 73 | continue 74 | 75 | self.log_histogram(trainer, f"{prefix}{name}{postfix}", param.grad) 76 | 77 | def log_histogram(self, trainer, name, values): 78 | # Add histogram to each logger that supports it 79 | for logger in trainer.loggers: 80 | # FIXME: Should we just use isinstance(logger.experiment, SummaryWriter)?
81 | if not hasattr(logger.experiment, "add_histogram"): 82 | continue 83 | 84 | logger.experiment.add_histogram(name, values, global_step=trainer.global_step) 85 | -------------------------------------------------------------------------------- /mart/callbacks/logging.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2025 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | import logging 10 | from typing import Sequence 11 | 12 | from lightning.pytorch.callbacks import Callback 13 | from torch import Tensor 14 | 15 | from ..nn.nn import DotDict 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | __all__ = ["Logging"] 20 | 21 | 22 | class Logging(Callback): 23 | """For models returning a dictionary, we can configure the callback to log scalars from the 24 | outputs, calculate and log metrics.""" 25 | 26 | def __init__( 27 | self, 28 | train_step_log: Sequence | dict = None, 29 | val_step_log: Sequence | dict = None, 30 | test_step_log: Sequence | dict = None, 31 | ): 32 | super().__init__() 33 | 34 | # Be backwards compatible by turning list into dict where each item is its own key-value 35 | if isinstance(train_step_log, Sequence): 36 | train_step_log = {item: {"key": item, "prog_bar": True} for item in train_step_log} 37 | train_step_log = train_step_log or {} 38 | 39 | # Be backwards compatible by turning list into dict where each item is its own key-value 40 | if isinstance(val_step_log, Sequence): 41 | val_step_log = {item: {"key": item, "prog_bar": True} for item in val_step_log} 42 | val_step_log = val_step_log or {} 43 | 44 | # Be backwards compatible by turning list into dict where each item is its own key-value 45 | if isinstance(test_step_log, Sequence): 46 | test_step_log = {item: {"key": item, "prog_bar": True} for item in test_step_log} 47 | test_step_log = test_step_log or {} 48 | 49 | self.step_log = { 50 | "train": train_step_log, 51 | "val": val_step_log, 52 | "test": test_step_log, 53 | } 54 | 55 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 56 | return self.on_batch_end(outputs, prefix="train") 57 | 58 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 59 | return self.on_batch_end(outputs, prefix="val") 60 | 61 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 62 | return self.on_batch_end(outputs, prefix="test") 63 | 64 | # 65 | # Utilities 66 | # 67 | def on_batch_end(self, outputs, *, prefix: str): 68 | # Convert to DotDict, so that we can use a dot-connected string as a key to find a value deep in the dictionary. 69 | outputs = DotDict(outputs) 70 | 71 | step_log = self.step_log[prefix] 72 | for log_name, cfg in step_log.items(): 73 | key, prog_bar = cfg["key"], cfg["prog_bar"] 74 | self.log(f"{prefix}/{log_name}", outputs[key], prog_bar=prog_bar) 75 | -------------------------------------------------------------------------------- /mart/callbacks/no_grad_mode.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from lightning.pytorch.callbacks import Callback 8 | 9 | __all__ = ["ModelParamsNoGrad"] 10 | 11 | 12 | class ModelParamsNoGrad(Callback): 13 | """No gradient for model parameters during attack. 14 | 15 | This callback should not change the result. 
Use it only if it makes the attack run faster. 16 | """ 17 | 18 | def on_train_start(self, trainer, model): 19 | for param in model.parameters(): 20 | param.requires_grad_(False) 21 | 22 | def on_train_end(self, trainer, model): 23 | for param in model.parameters(): 24 | param.requires_grad_(True) 25 | -------------------------------------------------------------------------------- /mart/callbacks/progress_bar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from typing import Any 8 | 9 | import lightning.pytorch as pl 10 | from lightning.pytorch.callbacks import TQDMProgressBar 11 | from lightning.pytorch.utilities.rank_zero import rank_zero_only 12 | 13 | __all__ = ["ProgressBar"] 14 | 15 | 16 | class ProgressBar(TQDMProgressBar): 17 | """Display progress bar of attack iterations with the gain value.""" 18 | 19 | def __init__(self, *args, enable=True, **kwargs): 20 | if "process_position" not in kwargs: 21 | # Automatically place the progress bar by rank if position is not specified. 22 | # rank starts at 0 23 | rank_id = rank_zero_only.rank 24 | # Adversary progress bars start at position 1, because the main progress bar takes position 0. 25 | process_position = rank_id + 1 26 | kwargs["process_position"] = process_position 27 | 28 | super().__init__(*args, **kwargs) 29 | 30 | if not enable: 31 | self.disable() 32 | 33 | def init_train_tqdm(self): 34 | bar = super().init_train_tqdm() 35 | bar.leave = False 36 | bar.set_description("Attack") 37 | bar.unit = "iter" 38 | 39 | return bar 40 | 41 | def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: 42 | super().on_train_epoch_start(trainer) 43 | 44 | # So that it does not display Epoch n. 45 | rank_id = rank_zero_only.rank 46 | self.train_progress_bar.set_description(f"Attack@rank{rank_id}") 47 | -------------------------------------------------------------------------------- /mart/callbacks/visualizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import os 8 | 9 | import torch 10 | from lightning.pytorch.callbacks import Callback 11 | from torchvision.transforms import ToPILImage 12 | 13 | __all__ = ["PerturbedImageVisualizer"] 14 | 15 | 16 | class PerturbedImageVisualizer(Callback): 17 | """Save adversarial images as files.""" 18 | 19 | def __init__(self, folder): 20 | super().__init__() 21 | 22 | # FIXME: This should use the Trainer's logging directory.
23 | self.folder = folder 24 | self.convert = ToPILImage() 25 | 26 | if not os.path.isdir(self.folder): 27 | os.makedirs(self.folder) 28 | 29 | def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): 30 | # Save canonical input and target for on_train_end 31 | self.input = batch[0] 32 | self.target = batch[1] 33 | 34 | def on_train_end(self, trainer, model): 35 | # FIXME: We should really just save this to outputs instead of recomputing adv_input 36 | with torch.no_grad(): 37 | adv_input, _target = model(self.input, self.target) 38 | 39 | for img, tgt in zip(adv_input, self.target): 40 | fname = tgt["file_name"] 41 | fpath = os.path.join(self.folder, fname) 42 | im = self.convert(img / 255) 43 | im.save(fpath) 44 | -------------------------------------------------------------------------------- /mart/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/mart/configs/__init__.py -------------------------------------------------------------------------------- /mart/configs/attack/adaptive_gradient_ascent.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /optimization@: adaptive_sgd 3 | 4 | optimizer: 5 | # Override the optimizer settings borrowed from the LitModular training config. 6 | maximize: True 7 | lr: ${..lr} 8 | 9 | lr_scheduler: 10 | # The adaptive learning rate scheduler monitors some variable and adapts to it. 11 | monitor: ??? 12 | scheduler: 13 | # Halve the learning rate on plateau. 14 | factor: 0.5 15 | # Minimum learning rate of 1, i.e. 1/255 in the [0,1] pixel scale. 16 | min_lr: 1 17 | # Set verbose true to debug the learning rate. 18 | verbose: true 19 | # We usually try to maximize something in Adversary. 20 | mode: max 21 | 22 | max_iters: ??? 23 | lr: ??? 24 | -------------------------------------------------------------------------------- /mart/configs/attack/adversary.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /callbacks@callbacks: [progress_bar] 3 | 4 | _target_: mart.attack.Adversary 5 | _convert_: all 6 | optimizer: 7 | maximize: True 8 | lr_scheduler: null 9 | composer: ??? 10 | gain: ??? 11 | gradient_modifier: null 12 | objective: null 13 | enforcer: ??? 14 | attacker: null 15 | -------------------------------------------------------------------------------- /mart/configs/attack/classification_fgsm_linf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - fgm 4 | - linf 5 | - composer: additive 6 | - gradient_modifier: sign 7 | - gain: cross_entropy 8 | - objective: misclassification 9 | 10 | eps: ??? 11 | max_iters: 1 12 | -------------------------------------------------------------------------------- /mart/configs/attack/classification_pgd_linf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - pgd 4 | - linf 5 | - composer: additive 6 | - gradient_modifier: sign 7 | - gain: cross_entropy 8 | - objective: misclassification 9 | 10 | eps: ??? 11 | lr: ??? 12 | max_iters: ???
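# Editorial usage sketch (attachment point and values are illustrative, not prescribed):
#   python -m mart experiment=CIFAR10_CNN fit=false \
#     +attack@model.modules.input_adv_test=classification_pgd_linf \
#     model.modules.input_adv_test.eps=8 \
#     model.modules.input_adv_test.lr=2 \
#     model.modules.input_adv_test.max_iters=10
# eps and lr are in the 0-255 pixel scale used throughout these configs.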
13 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/additive.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - modules: [additive] 4 | 5 | sequence: 6 | seq010: 7 | additive: [perturbation, input] 8 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - perturber: default 3 | 4 | _target_: mart.attack.Composer 5 | return_final_output_only: true 6 | modules: 7 | ??? 8 | # Example: additive, mask, overlay 9 | sequence: 10 | ??? 11 | # Wire modules, with input, 12 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/mask_additive.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - modules: [mask, additive] 4 | 5 | sequence: 6 | seq010: 7 | mask: [perturbation, target.perturbable_mask] 8 | seq020: 9 | additive: [mask, input] 10 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/additive.yaml: -------------------------------------------------------------------------------- 1 | additive: 2 | _target_: mart.attack.composer.Additive 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/fake_renderer.yaml: -------------------------------------------------------------------------------- 1 | fake_renderer: 2 | _target_: mart.attack.composer.FakeRenderer 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/mask.yaml: -------------------------------------------------------------------------------- 1 | mask: 2 | _target_: mart.attack.composer.Mask 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/overlay.yaml: -------------------------------------------------------------------------------- 1 | overlay: 2 | _target_: mart.attack.composer.Overlay 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/pert_extract_rect.yaml: -------------------------------------------------------------------------------- 1 | pert_extract_rect: 2 | _target_: mart.attack.composer.PertExtractRect 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/pert_image_base.yaml: -------------------------------------------------------------------------------- 1 | pert_image_base: 2 | _target_: mart.attack.composer.PertImageBase 3 | fpath: ??? 
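# Editorial note: fpath points to the image that serves as the base of the patch, e.g.
# (hypothetical path) fpath: data/patch_base.png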
4 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/pert_rect_perspective.yaml: -------------------------------------------------------------------------------- 1 | pert_rect_perspective: 2 | _target_: mart.attack.composer.PertRectPerspective 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/modules/pert_rect_size.yaml: -------------------------------------------------------------------------------- 1 | pert_rect_size: 2 | _target_: mart.attack.composer.PertRectSize 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/overlay.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - modules: [overlay] 4 | 5 | sequence: 6 | seq010: 7 | overlay: [perturbation, input, target.perturbable_mask] 8 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.Perturber 2 | initializer: ??? 3 | # Avoid null projector here due to the chance of overriding projectors defined in other config files. 4 | projector: ??? 5 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/initializer/constant.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.initializer.Constant 2 | constant: ??? 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/initializer/image.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.initializer.Image 2 | path: ??? 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/initializer/uniform.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.initializer.Uniform 2 | min: ??? 3 | max: ??? 4 | round: false 5 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/initializer/uniform_lp.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.initializer.UniformLp 2 | eps: ??? 3 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/projector/linf.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.projector.Lp 2 | # p is actually torch.inf by default. 3 | p: 4 | _target_: builtins.float 5 | _args_: ["inf"] 6 | eps: ??? 7 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/projector/linf_additive_range.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.projector.LinfAdditiveRange 2 | eps: ??? 
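# Editorial note: judging by the name and the defaults below, this projector keeps the
# perturbed input (input + perturbation) within [min, max] while bounding the
# perturbation's Linf norm by eps.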
3 | min: 0 4 | max: 255 5 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/projector/lp_additive_range.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.projector.LpAdditiveRangeProjector 2 | p: ??? 3 | eps: ??? 4 | min: 0 5 | max: 255 6 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/projector/mask_range.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.projector.Compose 2 | projectors: 3 | - _target_: mart.attack.projector.Mask 4 | - _target_: mart.attack.projector.Range 5 | quantize: false 6 | min: 0 7 | max: 255 8 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/projector/range.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.projector.Range 2 | quantize: false 3 | min: 0 4 | max: 255 5 | -------------------------------------------------------------------------------- /mart/configs/attack/composer/perturber/universal.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.UniversalPerturber 2 | shape: ??? 3 | initializer: ??? 4 | # Avoid null projector here due to the chance of overriding projectors defined in other config files. 5 | projector: ??? 6 | -------------------------------------------------------------------------------- /mart/configs/attack/enforcer/constraints/integer.yaml: -------------------------------------------------------------------------------- 1 | integer: 2 | _target_: mart.attack.enforcer.Integer 3 | rtol: 0 4 | atol: 0 5 | -------------------------------------------------------------------------------- /mart/configs/attack/enforcer/constraints/lp.yaml: -------------------------------------------------------------------------------- 1 | lp: 2 | _target_: mart.attack.enforcer.Lp 3 | eps: ??? 4 | # p is inf by default. 5 | # p: ??? 6 | # Calculate Lp over the CHW dimensions for images. 7 | dim: [-1, -2, -3] 8 | -------------------------------------------------------------------------------- /mart/configs/attack/enforcer/constraints/mask.yaml: -------------------------------------------------------------------------------- 1 | mask: 2 | _target_: mart.attack.enforcer.Mask 3 | -------------------------------------------------------------------------------- /mart/configs/attack/enforcer/constraints/pixel_range.yaml: -------------------------------------------------------------------------------- 1 | pixel_range: 2 | _target_: mart.attack.enforcer.Range 3 | min: 0 4 | max: 255 5 | -------------------------------------------------------------------------------- /mart/configs/attack/enforcer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.Enforcer 2 | constraints: ??? 3 | -------------------------------------------------------------------------------- /mart/configs/attack/fgm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - composer/perturber/initializer: constant 3 | - /optimizer@optimizer: sgd 4 | 5 | max_iters: 1 6 | eps: ??? 
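# For this single-step attack, eps doubles as the step size: optimizer.lr is tied to eps below.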
7 | 8 | optimizer: 9 | lr: ${..eps} 10 | 11 | composer: 12 | perturber: 13 | initializer: 14 | constant: 0 15 | projector: 16 | eps: ${....eps} 17 | 18 | # We can turn off progress bar for one-step attack. 19 | callbacks: 20 | progress_bar: 21 | enable: false 22 | -------------------------------------------------------------------------------- /mart/configs/attack/gain/cross_entropy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - modular 3 | 4 | module: 5 | _target_: torch.nn.CrossEntropyLoss 6 | reduction: none 7 | -------------------------------------------------------------------------------- /mart/configs/attack/gain/dlr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - modular 3 | 4 | module: 5 | _target_: autoattack.autopgd_base.APGDAttack.dlr_loss 6 | # A class method needs a placeholder self as a positional argument. 7 | _partial_: true 8 | _args_: 9 | - null 10 | -------------------------------------------------------------------------------- /mart/configs/attack/gain/modular.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.nn.CallWith 2 | module: 3 | _target_: ??? 4 | _call_with_args_: 5 | - logits 6 | - target 7 | -------------------------------------------------------------------------------- /mart/configs/attack/gain/rcnn_class_background.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - modular 3 | 4 | module: 5 | _target_: mart.attack.gain.RoIHeadTargetClass 6 | # Try to classify as background. 7 | class_index: 0 8 | targeted: true 9 | _call_with_args_: 10 | - box_head.class_logits 11 | - rpn_predictor.boxes 12 | -------------------------------------------------------------------------------- /mart/configs/attack/gain/rcnn_training_loss.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.nn.CallWith 2 | module: 3 | _target_: mart.nn.Sum 4 | _call_with_args_: 5 | - "losses_and_detections.training.loss_objectness" 6 | - "losses_and_detections.training.loss_rpn_box_reg" 7 | - "losses_and_detections.training.loss_classifier" 8 | - "losses_and_detections.training.loss_box_reg" 9 | -------------------------------------------------------------------------------- /mart/configs/attack/gradient_ascent.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /optimizer@optimizer: sgd 3 | 4 | max_iters: ??? 5 | lr: ??? 6 | 7 | optimizer: 8 | maximize: True 9 | lr: ${..lr} 10 | -------------------------------------------------------------------------------- /mart/configs/attack/gradient_modifier/lp_normalizer.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.gradient_modifier.LpNormalizer 2 | p: ??? 
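# p selects the norm order; e.g. (illustrative) p: 2 rescales the gradient by its L2 norm.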
3 | -------------------------------------------------------------------------------- /mart/configs/attack/gradient_modifier/sign.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.attack.gradient_modifier.Sign 2 | -------------------------------------------------------------------------------- /mart/configs/attack/linf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - composer/perturber/projector: linf_additive_range 3 | - enforcer: default 4 | - enforcer/constraints: lp 5 | 6 | enforcer: 7 | constraints: 8 | lp: 9 | p: 10 | _target_: builtins.float 11 | _args_: ["inf"] 12 | eps: ${....eps} 13 | 14 | eps: ??? 15 | -------------------------------------------------------------------------------- /mart/configs/attack/mask.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - composer/perturber/projector: mask_range 3 | - enforcer: default 4 | - enforcer/constraints: [mask, pixel_range] 5 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_lp_patch_adversary.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - /optimizer@optimizer: adam 4 | - enforcer: default 5 | - composer: default 6 | - composer/perturber/initializer: uniform 7 | - composer/perturber/projector: linf 8 | - composer/modules: 9 | [ 10 | pert_rect_size, 11 | pert_extract_rect, 12 | pert_image_base, 13 | pert_rect_perspective, 14 | overlay, 15 | ] 16 | - gradient_modifier: sign 17 | - gain: rcnn_training_loss 18 | - objective: zero_ap 19 | - override /callbacks@callbacks: [progress_bar, image_visualizer] 20 | 21 | max_iters: ??? 22 | lr: ??? 23 | eps: ??? 24 | 25 | optimizer: 26 | maximize: True 27 | lr: ${..lr} 28 | 29 | enforcer: 30 | # No constraints with complex renderer in the pipeline. 31 | # TODO: Constraint on digital perturbation? 32 | constraints: {} 33 | 34 | composer: 35 | perturber: 36 | initializer: 37 | min: ${negate:${....eps}} 38 | max: ${....eps} 39 | projector: 40 | eps: ${....eps} 41 | modules: 42 | pert_image_base: 43 | fpath: ??? 44 | sequence: 45 | seq010: 46 | pert_rect_size: ["target.coords"] 47 | seq020: 48 | pert_extract_rect: 49 | ["perturbation", "pert_rect_size.height", "pert_rect_size.width"] 50 | seq030: 51 | pert_image_base: ["pert_extract_rect"] 52 | seq040: 53 | pert_rect_perspective: ["pert_image_base", "input", "target.coords"] 54 | seq050: 55 | overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"] 56 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_lp_patch_adversary_simulation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - object_detection_lp_patch_adversary 3 | 4 | composer: 5 | modules: 6 | fake_renderer: 7 | _target_: mart.attack.composer.FakeRenderer 8 | 9 | sequence: 10 | seq060: 11 | # Ignore output from overlay. 
12 | fake_renderer: 13 | ["pert_image_base", "pert_rect_perspective", "target.renderer"] 14 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_mask_adversary.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - gradient_ascent 4 | - mask 5 | - composer: overlay 6 | - composer/perturber/initializer: constant 7 | - gradient_modifier: sign 8 | - gain: rcnn_training_loss 9 | - objective: zero_ap 10 | 11 | max_iters: ??? 12 | lr: ??? 13 | 14 | # Start with grey perturbation in the overlay mode. 15 | composer: 16 | perturber: 17 | initializer: 18 | constant: 127 19 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_mask_adversary_missed.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - gradient_ascent 4 | - mask 5 | - composer: overlay 6 | - composer/perturber/initializer: constant 7 | - gradient_modifier: sign 8 | - gain: rcnn_class_background 9 | - objective: object_detection_missed 10 | 11 | max_iters: ??? 12 | lr: ??? 13 | 14 | # Start with grey perturbation in the overlay mode. 15 | composer: 16 | perturber: 17 | initializer: 18 | constant: 127 19 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_patch_adversary.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - adversary 3 | - /optimizer@optimizer: adam 4 | - enforcer: default 5 | - composer: default 6 | - composer/perturber/initializer: uniform 7 | - composer/perturber/projector: range 8 | - composer/modules: 9 | [pert_rect_size, pert_extract_rect, pert_rect_perspective, overlay] 10 | - gradient_modifier: sign 11 | - gain: rcnn_training_loss 12 | - objective: zero_ap 13 | - override /callbacks@callbacks: [progress_bar, image_visualizer] 14 | 15 | max_iters: ??? 16 | lr: ??? 17 | 18 | optimizer: 19 | maximize: True 20 | lr: ${..lr} 21 | 22 | enforcer: 23 | # No constraints with complex renderer in the pipeline. 24 | # TODO: Constraint on digital perturbation? 25 | constraints: {} 26 | 27 | composer: 28 | perturber: 29 | initializer: 30 | min: 0 31 | max: 255 32 | projector: 33 | min: 0 34 | max: 255 35 | sequence: 36 | seq010: 37 | pert_rect_size: ["target.coords"] 38 | seq020: 39 | pert_extract_rect: 40 | ["perturbation", "pert_rect_size.height", "pert_rect_size.width"] 41 | seq040: 42 | pert_rect_perspective: ["pert_extract_rect", "input", "target.coords"] 43 | seq050: 44 | overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"] 45 | -------------------------------------------------------------------------------- /mart/configs/attack/object_detection_patch_adversary_simulation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - object_detection_patch_adversary 3 | 4 | composer: 5 | modules: 6 | fake_renderer: 7 | _target_: mart.attack.composer.FakeRenderer 8 | 9 | sequence: 10 | seq060: 11 | # Ignore output from overlay. 
12 | fake_renderer: 13 | ["pert_extract_rect", "pert_rect_perspective", "target.renderer"] 14 | -------------------------------------------------------------------------------- /mart/configs/attack/objective/misclassification.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.nn.CallWith 2 | module: 3 | _target_: mart.attack.objective.Mispredict 4 | _call_with_args_: 5 | - ${model.output_preds_key} 6 | - ${model.output_target_key} 7 | -------------------------------------------------------------------------------- /mart/configs/attack/objective/object_detection_missed.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.nn.CallWith 2 | module: 3 | _target_: mart.attack.objective.Missed 4 | confidence_threshold: 0.0 5 | _call_with_args_: 6 | - ${model.output_preds_key} 7 | -------------------------------------------------------------------------------- /mart/configs/attack/objective/zero_ap.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.nn.CallWith 2 | module: 3 | _target_: mart.attack.objective.ZeroAP 4 | iou_threshold: 0.5 5 | confidence_threshold: 0.0 6 | _call_with_args_: 7 | - ${model.output_preds_key} 8 | - ${model.output_target_key} 9 | -------------------------------------------------------------------------------- /mart/configs/attack/pgd.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - composer/perturber/initializer: uniform 3 | - /optimizer@optimizer: sgd 4 | 5 | max_iters: ??? 6 | eps: ??? 7 | lr: ??? 8 | 9 | optimizer: 10 | lr: ${..lr} 11 | 12 | composer: 13 | perturber: 14 | initializer: 15 | min: ${negate:${....eps}} 16 | max: ${....eps} 17 | projector: 18 | eps: ${....eps} 19 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/dict.yaml: -------------------------------------------------------------------------------- 1 | # We expect the original batch to look like `{"input": tensor, ...}` with the default parameters. 2 | _target_: mart.transforms.DictBatchC15n 3 | input_key: input 4 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/dict_imagenet_normalized.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dict 3 | - transform: imagenet_to_255 4 | - transform@untransform: 255_to_imagenet 5 | 6 | input_key: image 7 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/input_only.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.transforms.InputOnlyBatchC15n 2 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/list.yaml: -------------------------------------------------------------------------------- 1 | # We expect the original batch to look like `[input, target]` with the default parameters.
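# (Editorial note: input_key: 0 selects batch[0] as the attackable input; target_size: 1
# presumably means the single remaining element is treated as the target.)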
2 | _target_: mart.transforms.ListBatchC15n 3 | input_key: 0 4 | target_size: 1 5 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/list_image_01.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - list 3 | - transform: times_255_and_round 4 | - transform@untransform: divided_by_255 5 | 6 | input_key: 0 7 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/transform/255_to_imagenet.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Normalize 2 | # from 0-1 scale statistics: mean=[0.485, 0.456, 0.406]*255 std=[0.229, 0.224, 0.225]*255 3 | mean: [123.6750, 116.2800, 103.5300] 4 | std: [58.3950, 57.1200, 57.3750] 5 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/transform/divided_by_255.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Normalize 2 | mean: 0 3 | std: 255 4 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/transform/imagenet_to_255.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: mart.transforms.Denormalize 4 | # from 0-1 scale statistics: mean=[0.485, 0.456, 0.406]*255 std=[0.229, 0.224, 0.225]*255 5 | center: 6 | _target_: torch.as_tensor 7 | data: [123.6750, 116.2800, 103.5300] 8 | scale: 9 | _target_: torch.as_tensor 10 | data: [58.3950, 57.1200, 57.3750] 11 | - _target_: torch.fake_quantize_per_tensor_affine 12 | _partial_: true 13 | # (x/1+0).round().clamp(0, 255) * 1 14 | scale: 1 15 | zero_point: 0 16 | quant_min: 0 17 | quant_max: 255 18 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/transform/times_255_and_round.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.transforms.Compose 2 | transforms: 3 | - _target_: mart.transforms.Denormalize 4 | center: 0 5 | scale: 255 6 | # Fix potential numeric error. 7 | - _target_: torch.fake_quantize_per_tensor_affine 8 | _partial_: true 9 | # (x/1+0).round().clamp(0, 255) * 1 10 | scale: 1 11 | zero_point: 0 12 | quant_min: 0 13 | quant_max: 255 14 | -------------------------------------------------------------------------------- /mart/configs/batch_c15n/tuple.yaml: -------------------------------------------------------------------------------- 1 | # We expect the original batch to look like `(input, target)` with the default parameters.
2 | _target_: mart.transforms.TupleBatchC15n 3 | input_key: 0 4 | target_size: 1 5 | -------------------------------------------------------------------------------- /mart/configs/callbacks/adversary_connector.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /batch_c15n@adversary_connector.batch_c15n: tuple 3 | 4 | adversary_connector: 5 | _target_: mart.callbacks.AdversaryConnector 6 | adversary: null 7 | train_adversary: null 8 | val_adversary: null 9 | test_adversary: null 10 | -------------------------------------------------------------------------------- /mart/configs/callbacks/attack_in_eval_mode.yaml: -------------------------------------------------------------------------------- 1 | attack_in_eval_mode: 2 | _target_: mart.callbacks.AttackInEvalMode 3 | -------------------------------------------------------------------------------- /mart/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | - model_summary 4 | - rich_progress_bar 5 | - _self_ 6 | 7 | model_checkpoint: 8 | dirpath: ${paths.output_dir}/checkpoints 9 | filename: "epoch_{epoch:03d}" 10 | monitor: "val/acc" 11 | mode: "max" 12 | save_last: True 13 | auto_insert_metric_name: False 14 | 15 | model_summary: 16 | max_depth: -1 17 | -------------------------------------------------------------------------------- /mart/configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: lightning.pytorch.callbacks.EarlyStopping 7 | monitor: ??? # quantity to be monitored, must be specified !!! 8 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 9 | patience: 3 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /mart/configs/callbacks/gradient_monitor.yaml: -------------------------------------------------------------------------------- 1 | gradient_monitor: 2 | _target_: mart.callbacks.GradientMonitor 3 | norm_types: ["inf", 2] 4 | -------------------------------------------------------------------------------- /mart/configs/callbacks/image_visualizer.yaml: -------------------------------------------------------------------------------- 1 | image_visualizer: 2 | _target_: mart.callbacks.PerturbedImageVisualizer 3 | folder: ${paths.output_dir}/adversarial_examples 4 | -------------------------------------------------------------------------------- /mart/configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 3 | logging_interval: "step" 4 | -------------------------------------------------------------------------------- /mart/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | # Save the model periodically by monitoring a quantity. 4 | # Look at the above link for more detailed information. 5 | model_checkpoint: 6 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 7 | dirpath: "${paths.output_dir}/checkpoints/" # directory to save the model file 8 | filename: "epoch_{epoch:03d}" # checkpoint filename 9 | monitor: ??? # name of the logged metric which determines when model is improving 10 | verbose: False # verbosity mode 11 | save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: 1 # save k best models (determined by above metric) 13 | mode: ??? 
# "max" means higher metric value is better, can be also "min" 14 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 15 | save_weights_only: False # if True, then only the model’s weights will be saved 16 | every_n_train_steps: null # number of training steps between checkpoints 17 | train_time_interval: null # checkpoints are monitored at the specified time interval 18 | every_n_epochs: null # number of epochs between checkpoints 19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 20 | -------------------------------------------------------------------------------- /mart/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | model_summary: 6 | _target_: lightning.pytorch.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /mart/configs/callbacks/no_grad_mode.yaml: -------------------------------------------------------------------------------- 1 | attack_in_eval_mode: 2 | _target_: mart.callbacks.ModelParamsNoGrad 3 | -------------------------------------------------------------------------------- /mart/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/mart/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /mart/configs/callbacks/progress_bar.yaml: -------------------------------------------------------------------------------- 1 | progress_bar: 2 | _target_: mart.callbacks.ProgressBar 3 | # Enable progress bar for adversary by default. 4 | enable: true 5 | -------------------------------------------------------------------------------- /mart/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | rich_progress_bar: 6 | _target_: lightning.pytorch.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /mart/configs/datamodule/carla_patch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: null 5 | 6 | val_dataset: null 7 | 8 | test_dataset: 9 | _target_: mart.datamodules.coco.CocoDetection 10 | root: ??? 
11 | annFile: ${.root}/kwcoco_annotations.json 12 | modalities: ["rgb"] 13 | transforms: 14 | _target_: mart.transforms.Compose 15 | transforms: 16 | - _target_: torchvision.transforms.ToTensor 17 | - _target_: mart.transforms.ConvertCocoPolysToMask 18 | - _target_: mart.transforms.LoadPerturbableMask 19 | perturb_mask_folder: ${....root}/foreground_mask/ 20 | - _target_: mart.transforms.LoadCoords 21 | folder: ${....root}/patch_metadata/ 22 | - _target_: mart.transforms.Denormalize 23 | center: 0 24 | scale: 255 25 | - _target_: torch.fake_quantize_per_tensor_affine 26 | _partial_: true 27 | # (x/1+0).round().clamp(0, 255) * 1 28 | scale: 1 29 | zero_point: 0 30 | quant_min: 0 31 | quant_max: 255 32 | 33 | num_workers: 0 34 | ims_per_batch: 1 35 | 36 | collate_fn: 37 | _target_: hydra.utils.get_method 38 | path: mart.datamodules.coco.collate_fn 39 | -------------------------------------------------------------------------------- /mart/configs/datamodule/carla_patch_rendering.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: null 5 | 6 | val_dataset: null 7 | 8 | test_dataset: 9 | _target_: oscar_datagen_tools.dataset.dataset.CarlaDataset 10 | simulation_run: ??? 11 | modality: "rgb" 12 | annFile: ${.simulation_run}/kwcoco_annotations.json 13 | num_insertion_ticks: 50 14 | transforms: 15 | _target_: mart.transforms.Compose 16 | transforms: 17 | - _target_: torchvision.transforms.ToTensor 18 | - _target_: mart.transforms.ConvertCocoPolysToMask 19 | - _target_: mart.transforms.LoadPerturbableMask 20 | perturb_mask_folder: ${....simulation_run}/foreground_mask/ 21 | - _target_: mart.transforms.LoadCoords 22 | folder: ${....simulation_run}/patch_metadata/ 23 | - _target_: mart.transforms.Denormalize 24 | center: 0 25 | scale: 255 26 | - _target_: torch.fake_quantize_per_tensor_affine 27 | _partial_: true 28 | # (x/1+0).round().clamp(0, 255) * 1 29 | scale: 1 30 | zero_point: 0 31 | quant_min: 0 32 | quant_max: 255 33 | 34 | collate_fn: 35 | _target_: hydra.utils.get_method 36 | path: mart.datamodules.coco.collate_fn 37 | -------------------------------------------------------------------------------- /mart/configs/datamodule/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # 50K Training examples and 10K validation/test shared examples. 
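# Note the transform idiom shared by these datamodules: ToTensor yields floats in [0, 1],
# Denormalize(center=0, scale=255) rescales them to [0, 255], and
# torch.fake_quantize_per_tensor_affine(scale=1, zero_point=0, quant_min=0, quant_max=255)
# computes round(x).clamp(0, 255), snapping values to integer pixel intensities while
# remaining differentiable, presumably so that gradient-based attacks downstream operate
# on a faithful 0-255 image representation.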
2 | defaults: 3 | - default.yaml 4 | 5 | train_dataset: 6 | _target_: torchvision.datasets.CIFAR10 7 | root: ${paths.data_dir} # paths.data_dir is specified in configs/paths/???.yaml 8 | train: true 9 | transform: 10 | _target_: torchvision.transforms.Compose 11 | transforms: 12 | - _target_: torchvision.transforms.RandomCrop 13 | size: 32 14 | padding: 4 15 | pad_if_needed: false 16 | fill: 0 17 | padding_mode: reflect 18 | - _target_: torchvision.transforms.RandomHorizontalFlip 19 | - _target_: torchvision.transforms.ToTensor 20 | - _target_: mart.transforms.Denormalize 21 | center: 0 22 | scale: 255 23 | - _target_: torch.fake_quantize_per_tensor_affine 24 | _partial_: true 25 | # (x/1+0).round().clamp(0, 255) * 1 26 | scale: 1 27 | zero_point: 0 28 | quant_min: 0 29 | quant_max: 255 30 | target_transform: null 31 | download: true 32 | 33 | val_dataset: 34 | _target_: torchvision.datasets.CIFAR10 35 | root: ${paths.data_dir} # paths.data_dir is specified in configs/paths/???.yaml 36 | train: false 37 | transform: 38 | _target_: torchvision.transforms.Compose 39 | transforms: 40 | - _target_: torchvision.transforms.ToTensor 41 | - _target_: mart.transforms.Denormalize 42 | center: 0 43 | scale: 255 44 | - _target_: torch.fake_quantize_per_tensor_affine 45 | _partial_: true 46 | # (x/1+0).round().clamp(0, 255) * 1 47 | scale: 1 48 | zero_point: 0 49 | quant_min: 0 50 | quant_max: 255 51 | target_transform: null 52 | download: true 53 | 54 | test_dataset: ${.val_dataset} 55 | 56 | num_workers: 4 57 | collate_fn: null 58 | num_classes: 10 59 | -------------------------------------------------------------------------------- /mart/configs/datamodule/coco.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: 5 | _target_: mart.datamodules.coco.CocoDetection 6 | root: ${paths.data_dir}/coco/train2017 7 | annFile: ${paths.data_dir}/coco/annotations/instances_train2017.json 8 | transforms: 9 | _target_: mart.transforms.Compose 10 | transforms: 11 | - _target_: torchvision.transforms.ToTensor 12 | - _target_: mart.transforms.ConvertCocoPolysToMask 13 | - _target_: mart.transforms.RandomHorizontalFlip 14 | p: 0.5 15 | - _target_: mart.transforms.Denormalize 16 | center: 0 17 | scale: 255 18 | - _target_: torch.fake_quantize_per_tensor_affine 19 | _partial_: true 20 | # (x/1+0).round().clamp(0, 255) * 1 21 | scale: 1 22 | zero_point: 0 23 | quant_min: 0 24 | quant_max: 255 25 | 26 | val_dataset: 27 | _target_: mart.datamodules.coco.CocoDetection 28 | root: ${paths.data_dir}/coco/val2017 29 | annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json 30 | transforms: 31 | _target_: mart.transforms.Compose 32 | transforms: 33 | - _target_: torchvision.transforms.ToTensor 34 | - _target_: mart.transforms.ConvertCocoPolysToMask 35 | - _target_: mart.transforms.Denormalize 36 | center: 0 37 | scale: 255 38 | - _target_: torch.fake_quantize_per_tensor_affine 39 | _partial_: true 40 | # (x/1+0).round().clamp(0, 255) * 1 41 | scale: 1 42 | zero_point: 0 43 | quant_min: 0 44 | quant_max: 255 45 | 46 | test_dataset: 47 | _target_: mart.datamodules.coco.CocoDetection 48 | root: ${paths.data_dir}/coco/val2017 49 | annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json 50 | transforms: 51 | _target_: mart.transforms.Compose 52 | transforms: 53 | - _target_: torchvision.transforms.ToTensor 54 | - _target_: mart.transforms.ConvertCocoPolysToMask 55 | - _target_: mart.transforms.Denormalize 56 | center: 0 57 
| scale: 255 58 | - _target_: torch.fake_quantize_per_tensor_affine 59 | _partial_: true 60 | # (x/1+0).round().clamp(0, 255) * 1 61 | scale: 1 62 | zero_point: 0 63 | quant_min: 0 64 | quant_max: 255 65 | 66 | num_workers: 2 67 | collate_fn: 68 | _target_: hydra.utils.get_method 69 | path: mart.datamodules.coco.collate_fn 70 | -------------------------------------------------------------------------------- /mart/configs/datamodule/coco_perturbable_mask.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - coco 3 | 4 | train_dataset: 5 | transforms: 6 | transforms: 7 | - _target_: torchvision.transforms.ToTensor 8 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 9 | - _target_: mart.transforms.ConvertCocoPolysToMask 10 | - _target_: mart.transforms.RandomHorizontalFlip 11 | p: 0.5 12 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 13 | - _target_: mart.transforms.Denormalize 14 | center: 0 15 | scale: 255 16 | - _target_: torch.fake_quantize_per_tensor_affine 17 | _partial_: true 18 | # (x/1+0).round().clamp(0, 255) * 1 19 | scale: 1 20 | zero_point: 0 21 | quant_min: 0 22 | quant_max: 255 23 | 24 | val_dataset: 25 | transforms: 26 | transforms: 27 | - _target_: torchvision.transforms.ToTensor 28 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 29 | - _target_: mart.transforms.ConvertCocoPolysToMask 30 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 31 | - _target_: mart.transforms.Denormalize 32 | center: 0 33 | scale: 255 34 | - _target_: torch.fake_quantize_per_tensor_affine 35 | _partial_: true 36 | # (x/1+0).round().clamp(0, 255) * 1 37 | scale: 1 38 | zero_point: 0 39 | quant_min: 0 40 | quant_max: 255 41 | 42 | test_dataset: 43 | transforms: 44 | transforms: 45 | - _target_: torchvision.transforms.ToTensor 46 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 47 | - _target_: mart.transforms.ConvertCocoPolysToMask 48 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 49 | - _target_: mart.transforms.Denormalize 50 | center: 0 51 | scale: 255 52 | - _target_: torch.fake_quantize_per_tensor_affine 53 | _partial_: true 54 | # (x/1+0).round().clamp(0, 255) * 1 55 | scale: 1 56 | zero_point: 0 57 | quant_min: 0 58 | quant_max: 255 59 | -------------------------------------------------------------------------------- /mart/configs/datamodule/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.datamodules.LitDataModule 2 | # _convert_: all 3 | 4 | train_dataset: ??? 5 | val_dataset: ??? 6 | test_dataset: ??? 7 | 8 | num_workers: 0 9 | 10 | ims_per_batch: ??? 11 | world_size: ${trainer.devices} 12 | -------------------------------------------------------------------------------- /mart/configs/datamodule/dummy_classification.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: 5 | _target_: torchvision.datasets.FakeData 6 | 7 | size: 1000 8 | image_size: ??? 9 | num_classes: ??? 
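# FakeData synthesizes random images and labels on the fly (no download needed), which makes
# it handy for smoke tests; experiments must fill in image_size and num_classes.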
10 | 11 | transform: 12 | _target_: torchvision.transforms.Compose 13 | transforms: 14 | - _target_: torchvision.transforms.ToTensor 15 | 16 | val_dataset: ${.train_dataset} 17 | test_dataset: ${.val_dataset} 18 | 19 | num_classes: ${.train_dataset.num_classes} 20 | -------------------------------------------------------------------------------- /mart/configs/datamodule/fiftyone.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: 5 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 6 | dataset_name: ??? 7 | gt_field: "ground_truth_segmentations" 8 | sample_tags: [] 9 | label_tags: [] 10 | transforms: 11 | _target_: mart.transforms.Compose 12 | transforms: 13 | - _target_: torchvision.transforms.ToTensor 14 | - _target_: mart.transforms.ConvertCocoPolysToMask 15 | - _target_: mart.transforms.RandomHorizontalFlip 16 | p: 0.5 17 | - _target_: mart.transforms.Denormalize 18 | center: 0 19 | scale: 255 20 | - _target_: torch.fake_quantize_per_tensor_affine 21 | _partial_: true 22 | # (x/1+0).round().clamp(0, 255) * 1 23 | scale: 1 24 | zero_point: 0 25 | quant_min: 0 26 | quant_max: 255 27 | 28 | val_dataset: 29 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 30 | dataset_name: ??? 31 | gt_field: ${..train_dataset.gt_field} 32 | sample_tags: [] 33 | label_tags: [] 34 | transforms: 35 | _target_: mart.transforms.Compose 36 | transforms: 37 | - _target_: torchvision.transforms.ToTensor 38 | - _target_: mart.transforms.ConvertCocoPolysToMask 39 | - _target_: mart.transforms.Denormalize 40 | center: 0 41 | scale: 255 42 | - _target_: torch.fake_quantize_per_tensor_affine 43 | _partial_: true 44 | # (x/1+0).round().clamp(0, 255) * 1 45 | scale: 1 46 | zero_point: 0 47 | quant_min: 0 48 | quant_max: 255 49 | 50 | test_dataset: 51 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 52 | dataset_name: ??? 53 | gt_field: ${..train_dataset.gt_field} 54 | sample_tags: [] 55 | label_tags: [] 56 | transforms: 57 | _target_: mart.transforms.Compose 58 | transforms: 59 | - _target_: torchvision.transforms.ToTensor 60 | - _target_: mart.transforms.ConvertCocoPolysToMask 61 | - _target_: mart.transforms.Denormalize 62 | center: 0 63 | scale: 255 64 | - _target_: torch.fake_quantize_per_tensor_affine 65 | _partial_: true 66 | # (x/1+0).round().clamp(0, 255) * 1 67 | scale: 1 68 | zero_point: 0 69 | quant_min: 0 70 | quant_max: 255 71 | 72 | num_workers: 2 73 | collate_fn: 74 | _target_: hydra.utils.get_method 75 | path: mart.datamodules.coco.collate_fn 76 | -------------------------------------------------------------------------------- /mart/configs/datamodule/fiftyone_perturbable_mask.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fiftyone 3 | 4 | train_dataset: 5 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 6 | dataset_name: ??? 7 | gt_field: "ground_truth_segmentations" 8 | sample_tags: [] 9 | label_tags: [] 10 | transforms: 11 | _target_: mart.transforms.Compose 12 | transforms: 13 | - _target_: torchvision.transforms.ToTensor 14 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 
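# (Presumably because ConvertInstanceSegmentationToPerturbable derives the perturbable mask
# from the instance masks that ConvertCocoPolysToMask produces.)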
15 | - _target_: mart.transforms.ConvertCocoPolysToMask 16 | - _target_: mart.transforms.RandomHorizontalFlip 17 | p: 0.5 18 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 19 | - _target_: mart.transforms.Denormalize 20 | center: 0 21 | scale: 255 22 | - _target_: torch.fake_quantize_per_tensor_affine 23 | _partial_: true 24 | # (x/1+0).round().clamp(0, 255) * 1 25 | scale: 1 26 | zero_point: 0 27 | quant_min: 0 28 | quant_max: 255 29 | 30 | val_dataset: 31 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 32 | dataset_name: ??? 33 | gt_field: ${..train_dataset.gt_field} 34 | sample_tags: [] 35 | label_tags: [] 36 | transforms: 37 | _target_: mart.transforms.Compose 38 | transforms: 39 | - _target_: torchvision.transforms.ToTensor 40 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 41 | - _target_: mart.transforms.ConvertCocoPolysToMask 42 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 43 | - _target_: mart.transforms.Denormalize 44 | center: 0 45 | scale: 255 46 | - _target_: torch.fake_quantize_per_tensor_affine 47 | _partial_: true 48 | # (x/1+0).round().clamp(0, 255) * 1 49 | scale: 1 50 | zero_point: 0 51 | quant_min: 0 52 | quant_max: 255 53 | 54 | test_dataset: 55 | _target_: mart.datamodules.fiftyone.FiftyOneDataset 56 | dataset_name: ??? 57 | gt_field: ${..train_dataset.gt_field} 58 | sample_tags: [] 59 | label_tags: [] 60 | transforms: 61 | _target_: mart.transforms.Compose 62 | transforms: 63 | - _target_: torchvision.transforms.ToTensor 64 | # ConvertCocoPolysToMask must be prior to ConvertInstanceSegmentationToPerturbable. 65 | - _target_: mart.transforms.ConvertCocoPolysToMask 66 | - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable 67 | - _target_: mart.transforms.Denormalize 68 | center: 0 69 | scale: 255 70 | - _target_: torch.fake_quantize_per_tensor_affine 71 | _partial_: true 72 | # (x/1+0).round().clamp(0, 255) * 1 73 | scale: 1 74 | zero_point: 0 75 | quant_min: 0 76 | quant_max: 255 77 | -------------------------------------------------------------------------------- /mart/configs/datamodule/imagenet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train_dataset: 5 | _target_: torchvision.datasets.ImageNet 6 | 7 | root: ${paths.data_dir}/imagenet/2012/ 8 | split: train 9 | transform: 10 | _target_: torchvision.transforms.Compose 11 | transforms: 12 | - _target_: torchvision.transforms.RandomResizedCrop 13 | size: 224 14 | - _target_: torchvision.transforms.RandomHorizontalFlip 15 | - _target_: torchvision.transforms.ToTensor 16 | - _target_: mart.transforms.Denormalize 17 | center: 0 18 | scale: 255 19 | - _target_: torch.fake_quantize_per_tensor_affine 20 | _partial_: true 21 | # (x/1+0).round().clamp(0, 255) * 1 22 | scale: 1 23 | zero_point: 0 24 | quant_min: 0 25 | quant_max: 255 26 | 27 | val_dataset: 28 | _target_: torchvision.datasets.ImageNet 29 | root: ${paths.data_dir}/imagenet/2012/ 30 | split: val 31 | transform: 32 | _target_: torchvision.transforms.Compose 33 | transforms: 34 | - _target_: torchvision.transforms.CenterCrop 35 | size: 224 36 | - _target_: torchvision.transforms.ToTensor 37 | - _target_: mart.transforms.Denormalize 38 | center: 0 39 | scale: 255 40 | - _target_: torch.fake_quantize_per_tensor_affine 41 | _partial_: true 42 | # (x/1+0).round().clamp(0, 255) * 1 43 | scale: 1 44 | zero_point: 0 45 | quant_min: 0 46 | quant_max: 255 47 | 48 | test_dataset: 
${.val_dataset} 49 | 50 | num_classes: 1000 51 | -------------------------------------------------------------------------------- /mart/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /mart/configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /mart/configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /mart/configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /mart/configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: 11 | _target_: lightning.pytorch.profiler.SimpleProfiler 12 | # _target_: lightning.pytorch.profiler.AdvancedProfiler 13 | # _target_: lightning.pytorch.profiler.PyTorchProfiler 14 | dirpath: ${paths.output_dir} 15 | filename: profiler_log 16 | -------------------------------------------------------------------------------- /mart/configs/experiment/CIFAR10_CNN.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /datamodule: cifar10 5 | - override /model: classifier_cifar10_cnn 6 | - override /metric: accuracy 7 | - override /optimization: 
super_convergence 8 | - override /callbacks: [model_checkpoint, lr_monitor] 9 | 10 | task_name: "CIFAR10_CNN" 11 | tags: ["benign"] 12 | 13 | optimized_metric: "test_metrics/acc" 14 | 15 | callbacks: 16 | model_checkpoint: 17 | monitor: "validation_metrics/acc" 18 | mode: "max" 19 | 20 | trainer: 21 | # 50K training images, batch_size=128 with drop_last: 390 steps/epoch * 15 epochs = 5,850 steps. 22 | max_steps: 5850 23 | precision: 32 24 | 25 | datamodule: 26 | ims_per_batch: 128 27 | world_size: 1 28 | num_workers: 8 29 | 30 | model: 31 | optimizer: 32 | lr: 0.1 33 | momentum: 0.9 34 | weight_decay: 1e-4 35 | -------------------------------------------------------------------------------- /mart/configs/experiment/CIFAR10_CNN_Adv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /attack@callbacks.adversary_connector.train_adversary: classification_fgsm_linf 5 | - /attack@callbacks.adversary_connector.test_adversary: classification_pgd_linf 6 | - override /datamodule: cifar10 7 | - override /model: classifier_cifar10_cnn 8 | - override /metric: accuracy 9 | - override /optimization: super_convergence 10 | - override /callbacks: [model_checkpoint, lr_monitor, adversary_connector] 11 | 12 | task_name: "CIFAR10_CNN_Adv" 13 | tags: ["adv", "fat"] 14 | 15 | optimized_metric: "test_metrics/acc" 16 | 17 | callbacks: 18 | model_checkpoint: 19 | monitor: "validation_metrics/acc" 20 | mode: "max" 21 | 22 | adversary_connector: 23 | train_adversary: 24 | eps: 1.75 25 | test_adversary: 26 | eps: 2 27 | lr: 1 28 | max_iters: 10 29 | 30 | trainer: 31 | # 50K training images, batch_size=128 with drop_last: 390 steps/epoch * 15 epochs = 5,850 steps. 32 | max_steps: 5850 33 | precision: 32 34 | 35 | datamodule: 36 | ims_per_batch: 128 37 | world_size: 1 38 | num_workers: 8 39 | 40 | model: 41 | optimizer: 42 | lr: 0.1 43 | momentum: 0.9 44 | weight_decay: 1e-4 45 | -------------------------------------------------------------------------------- /mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /datamodule: coco 5 | - override /model: torchvision_faster_rcnn 6 | - override /metric: average_precision 7 | - override /optimization: super_convergence 8 | - override /callbacks: [model_checkpoint, lr_monitor] 9 | 10 | task_name: "COCO_TorchvisionFasterRCNN" 11 | tags: ["regular_training"] 12 | 13 | optimized_metric: "test_metrics/map" 14 | 15 | callbacks: 16 | model_checkpoint: 17 | monitor: "validation_metrics/map" 18 | mode: "max" 19 | 20 | trainer: 21 | # 117,266 training images, 6 epochs, batch_size=2: 117266 * 6 / 2 = 351,798 steps. 22 | max_steps: 351798 23 | # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). 24 | precision: 32 25 | 26 | datamodule: 27 | ims_per_batch: 2 28 | world_size: 1 29 | 30 | model: 31 | modules: 32 | losses_and_detections: 33 | model: 34 | # Inferred by torchvision.
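# With num_classes: null and pretrained weights, torchvision derives the class count from
# the category metadata attached to the weights.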
35 | num_classes: null 36 | weights: COCO_V1 37 | 38 | optimizer: 39 | lr: 0.0125 40 | momentum: 0.9 41 | weight_decay: 1e-4 42 | -------------------------------------------------------------------------------- /mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - COCO_TorchvisionFasterRCNN 5 | - /attack@callbacks.adversary_connector.test_adversary: object_detection_mask_adversary 6 | - override /datamodule: coco_perturbable_mask 7 | - override /callbacks: [model_checkpoint, lr_monitor, adversary_connector] 8 | 9 | task_name: "COCO_TorchvisionFasterRCNN_Adv" 10 | tags: ["adv"] 11 | 12 | callbacks: 13 | adversary_connector: 14 | test_adversary: 15 | # Run a 5-step attack for demonstration purposes. 16 | max_iters: 5 17 | lr: 55 18 | -------------------------------------------------------------------------------- /mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /datamodule: coco 5 | - override /model: torchvision_retinanet 6 | - override /metric: average_precision 7 | - override /optimization: super_convergence 8 | - override /callbacks: [model_checkpoint, lr_monitor] 9 | 10 | task_name: "COCO_TorchvisionRetinaNet" 11 | tags: ["regular_training"] 12 | 13 | optimized_metric: "test_metrics/map" 14 | 15 | callbacks: 16 | model_checkpoint: 17 | monitor: "validation_metrics/map" 18 | mode: "max" 19 | 20 | trainer: 21 | # 117,266 training images, 6 epochs, batch_size=2: 117266 * 6 / 2 = 351,798 steps. 22 | max_steps: 351798 23 | precision: 16 24 | 25 | datamodule: 26 | ims_per_batch: 2 27 | world_size: 1 28 | 29 | model: 30 | modules: 31 | losses_and_detections: 32 | model: 33 | # Inferred by torchvision.
34 | num_classes: null 35 | weights: COCO_V1 36 | 37 | optimizer: 38 | lr: 0.0125 39 | momentum: 0.9 40 | weight_decay: 1e-4 41 | -------------------------------------------------------------------------------- /mart/configs/experiment/ImageNet_Timm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /datamodule: imagenet 5 | - override /model: classifier_timm 6 | - override /metric: accuracy 7 | - override /optimization: super_convergence 8 | 9 | task_name: "ImageNet_Timm" 10 | tags: ["regular_training"] 11 | 12 | optimized_metric: "test_metrics/acc" 13 | 14 | callbacks: 15 | model_checkpoint: 16 | monitor: "validation_metrics/acc" 17 | 18 | trainer: 19 | # 1.2M training images, 15 epochs, batch_size=128, max_steps=1.2e6*15/128=140625 20 | max_steps: 140625 21 | precision: 16 22 | 23 | datamodule: 24 | ims_per_batch: 128 25 | world_size: 1 26 | num_workers: 8 27 | 28 | model: 29 | optimizer: 30 | lr: 0.1 31 | momentum: 0.9 32 | weight_decay: 1e-4 33 | -------------------------------------------------------------------------------- /mart/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /mart/configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
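# (Note: "val/acc_best" below matches the upstream lightning-hydra-template; the MART
# experiments in this repo log names like "test_metrics/acc", so override this metric
# when sweeping a MART experiment.)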
11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /mart/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # Since Hydra 1.2, we need to explicitly tell it to change the working directory. 9 | # https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ 10 | job: 11 | chdir: true 12 | 13 | # output directory, generated dynamically on each run 14 | run: 15 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 16 | sweep: 17 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 18 | subdir: ${hydra.job.num} 19 | -------------------------------------------------------------------------------- /mart/configs/lightning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - datamodule: null 7 | - model: null 8 | - metric: null 9 | - optimization: null 10 | - callbacks: default.yaml 11 | - logger: [csv, tensorboard] # set logger here or use command line (e.g. `python run.py logger=wandb`) 12 | - trainer: default.yaml 13 | - paths: default.yaml 14 | - extras: default.yaml 15 | - hydra: default.yaml 16 | 17 | - experiment: null 18 | 19 | - hparams_search: null 20 | 21 | # optional local config for machine/user specific settings 22 | # it's optional since it doesn't need to exist and is excluded from version control 23 | - optional local: default.yaml 24 | 25 | # debugging config (enable through command line, e.g. 
`python train.py debug=default`) 26 | - debug: null 27 | 28 | # task name, determines output directory path 29 | task_name: "lightning" 30 | 31 | # tags to help you identify your experiments 32 | # you can overwrite this in experiment configs 33 | # overwrite from command line with `python -m mart tags="[first_tag, second_tag]"` 34 | # appending lists from command line is currently not supported :( 35 | # https://github.com/facebookresearch/hydra/issues/1547 36 | tags: ["dev"] 37 | 38 | # Train it or not. 39 | fit: True 40 | 41 | # check performance on test set, using the best model achieved during training 42 | # lightning chooses best model based on metric specified in checkpoint callback 43 | test: True 44 | 45 | # Whether to resume training using configuration and checkpoint in specified directory 46 | resume: null 47 | 48 | # seed for random number generators in pytorch, numpy and python.random 49 | seed: null 50 | -------------------------------------------------------------------------------- /mart/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /mart/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # CSV logger built into Lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /mart/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /mart/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /mart/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /mart/configs/logger/tensorboard.yaml:
-------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /mart/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /mart/configs/metric/accuracy.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | training_metrics: 4 | _target_: torchmetrics.MetricCollection 5 | _convert_: partial # metrics must be a dict 6 | metrics: 7 | acc: 8 | _target_: torchmetrics.Accuracy 9 | task: multiclass 10 | num_classes: ${datamodule.num_classes} 11 | 12 | validation_metrics: ${.training_metrics} 13 | 14 | test_metrics: ${.validation_metrics} 15 | -------------------------------------------------------------------------------- /mart/configs/metric/average_precision.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | training_metrics: 4 | _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision 5 | class_metrics: false 6 | 7 | validation_metrics: 8 | _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision 9 | class_metrics: false 10 | 11 | test_metrics: 12 | _target_: torchmetrics.collections.MetricCollection 13 | _convert_: partial 14 | metrics: 15 | map: 16 | _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision 17 | class_metrics: false 18 | json: 19 | _target_: mart.utils.export.CocoPredictionJSON 20 | prediction_file_name: ${paths.output_dir}/test_prediction.json 21 | groundtruth_file_name: ${paths.output_dir}/test_groundtruth.json 22 | -------------------------------------------------------------------------------- /mart/configs/model/classifier.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - modular 3 | 4 | # The training/validation/test sequences below are equivalent (apart from the training loss, which only the training sequence computes), just written in three different syntaxes. 5 | 6 | # The verbose version. 7 | training_sequence: 8 | seq010: 9 | preprocessor: 10 | _call_with_args_: ["input"] 11 | seq020: 12 | logits: 13 | _call_with_args_: ["preprocessor"] 14 | seq030: 15 | loss: 16 | _call_with_args_: ["logits", "target"] 17 | seq040: 18 | preds: 19 | _call_with_args_: ["logits"] 20 | 21 | # The kwargs-centric version. 22 | # We may use *args as **kwargs to avoid the lengthy _call_with_args_. 23 | # The drawback is that we would need to look up the *args names from the code. 24 | # We use a list-style sequence since we don't care about replacing any elements.
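# For example (assumed semantics), the first list entry below,
#   - preprocessor:
#       tensor: input
# calls modules.preprocessor(tensor=<batch input>) and stores the result under the key
# "preprocessor" for later steps; `tensor` is the parameter name of
# torchvision.transforms.Normalize.forward.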
25 | validation_sequence: 26 | - preprocessor: 27 | tensor: input 28 | - logits: ["preprocessor"] 29 | - preds: 30 | input: logits 31 | 32 | # The simplified version. 33 | # We treat a list as the `_call_with_args_` parameter. 34 | test_sequence: 35 | seq010: 36 | preprocessor: ["input"] 37 | seq020: 38 | logits: ["preprocessor"] 39 | seq030: 40 | preds: ["logits"] 41 | 42 | modules: 43 | preprocessor: ??? 44 | 45 | logits: ??? 46 | 47 | loss: 48 | _target_: torch.nn.CrossEntropyLoss 49 | 50 | preds: 51 | _target_: torch.nn.Softmax 52 | dim: 1 53 | -------------------------------------------------------------------------------- /mart/configs/model/classifier_cifar10_cnn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - classifier 3 | - /model@modules.logits: cnn_7layer_bn2 4 | 5 | modules: 6 | preprocessor: 7 | # Normalize [0, 255] input. 8 | _target_: torchvision.transforms.Normalize 9 | mean: [125.307, 122.961, 113.8575] 10 | std: [51.5865, 50.847, 51.255] 11 | -------------------------------------------------------------------------------- /mart/configs/model/classifier_timm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - classifier 3 | 4 | modules: 5 | preprocessor: 6 | # Convert [0, 255] input to [0, 1] 7 | _target_: torchvision.transforms.Normalize 8 | mean: 0 9 | std: 255 10 | 11 | logits: 12 | _target_: timm.models.convnext.convnext_tiny 13 | pretrained: true 14 | -------------------------------------------------------------------------------- /mart/configs/model/cnn_7layer_bn2.yaml: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/shizhouxing/Fast-Certified-Robust-Training/ 2 | _target_: torch.nn.Sequential 3 | _args_: 4 | # layer1 5 | - _target_: torch.nn.Conv2d 6 | in_channels: 3 7 | out_channels: 64 # width 8 | kernel_size: 3 9 | stride: 1 10 | padding: 1 11 | - _target_: torch.nn.BatchNorm2d 12 | num_features: 64 # width 13 | - _target_: torch.nn.ReLU 14 | 15 | # layer2 16 | - _target_: torch.nn.Conv2d 17 | in_channels: 64 # width 18 | out_channels: 64 # width 19 | kernel_size: 3 20 | stride: 1 21 | padding: 1 22 | - _target_: torch.nn.BatchNorm2d 23 | num_features: 64 # width 24 | - _target_: torch.nn.ReLU 25 | 26 | # layer3 27 | - _target_: torch.nn.Conv2d 28 | in_channels: 64 # width 29 | out_channels: 128 # 2*width 30 | kernel_size: 3 31 | stride: 2 32 | padding: 1 33 | - _target_: torch.nn.BatchNorm2d 34 | num_features: 128 # 2*width 35 | - _target_: torch.nn.ReLU 36 | 37 | # layer4 38 | - _target_: torch.nn.Conv2d 39 | in_channels: 128 # 2*width 40 | out_channels: 128 # 2*width 41 | kernel_size: 3 42 | stride: 1 43 | padding: 1 44 | - _target_: torch.nn.BatchNorm2d 45 | num_features: 128 # 2*width 46 | - _target_: torch.nn.ReLU 47 | 48 | # layer5 49 | - _target_: torch.nn.Conv2d 50 | in_channels: 128 # 2*width 51 | out_channels: 128 # 2*width 52 | kernel_size: 3 53 | stride: 1 54 | padding: 1 55 | - _target_: torch.nn.BatchNorm2d 56 | num_features: 128 57 | - _target_: torch.nn.ReLU 58 | 59 | - _target_: torch.nn.Flatten 60 | 61 | # layer 6 62 | - _target_: torch.nn.Linear 63 | in_features: 32768 # (input_width=32 // 2)*(input_height=32 // 2) * 2 * width 64 | out_features: 512 # linear_size 65 | - _target_: torch.nn.BatchNorm1d 66 | num_features: 512 # linear_size 67 | - _target_: torch.nn.ReLU 68 | 69 | # layer 7 70 | - _target_: torch.nn.Linear 71 | in_features: 512 # linear_size 72 | 
out_features: 10 # num_class 73 | -------------------------------------------------------------------------------- /mart/configs/model/modular.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.models.LitModular 2 | _convert_: all 3 | 4 | output_preds_key: "preds" 5 | output_target_key: "target" 6 | 7 | modules: ??? 8 | optimizer: ??? 9 | 10 | validation_metrics: ${.training_metrics} 11 | # We may use different metrics in test. 12 | # test_metrics: ${.training_metrics} 13 | -------------------------------------------------------------------------------- /mart/configs/model/torchvision_faster_rcnn.yaml: -------------------------------------------------------------------------------- 1 | # We simply wrap a torchvision object detection model for validation. 2 | defaults: 3 | - torchvision_object_detection 4 | 5 | # log all losses separately in training. 6 | training_step_log: 7 | loss_objectness: "losses_and_detections.training.loss_objectness" 8 | loss_rpn_box_reg: "losses_and_detections.training.loss_rpn_box_reg" 9 | loss_classifier: "losses_and_detections.training.loss_classifier" 10 | loss_box_reg: "losses_and_detections.training.loss_box_reg" 11 | 12 | training_sequence: 13 | seq010: 14 | preprocessor: ["input"] 15 | 16 | seq020: 17 | losses_and_detections: ["preprocessor", "target"] 18 | 19 | seq030: 20 | loss: 21 | # Sum up the losses. 22 | [ 23 | "losses_and_detections.training.loss_objectness", 24 | "losses_and_detections.training.loss_rpn_box_reg", 25 | "losses_and_detections.training.loss_classifier", 26 | "losses_and_detections.training.loss_box_reg", 27 | ] 28 | 29 | validation_sequence: 30 | seq010: 31 | preprocessor: ["input"] 32 | 33 | seq020: 34 | losses_and_detections: ["preprocessor", "target"] 35 | 36 | test_sequence: 37 | seq010: 38 | preprocessor: ["input"] 39 | 40 | seq020: 41 | losses_and_detections: ["preprocessor", "target"] 42 | 43 | modules: 44 | losses_and_detections: 45 | # 17s: DualModeGeneralizedRCNN 46 | # 23s: DualMode 47 | _target_: mart.models.DualModeGeneralizedRCNN 48 | model: 49 | _target_: torchvision.models.detection.fasterrcnn_resnet50_fpn 50 | num_classes: ??? 51 | -------------------------------------------------------------------------------- /mart/configs/model/torchvision_object_detection.yaml: -------------------------------------------------------------------------------- 1 | # We simply wrap a torchvision object detection model for validation. 2 | defaults: 3 | - modular 4 | 5 | training_step_log: 6 | loss: "loss" 7 | 8 | training_sequence: ??? 9 | validation_sequence: ??? 10 | test_sequence: ??? 11 | 12 | output_preds_key: "losses_and_detections.eval" 13 | 14 | modules: 15 | preprocessor: 16 | _target_: mart.transforms.TupleTransforms 17 | transforms: 18 | _target_: torchvision.transforms.Normalize 19 | mean: 0 20 | std: 255 21 | 22 | losses_and_detections: 23 | # Return losses in the training mode and predictions in the eval mode in one pass. 24 | _target_: mart.models.DualMode 25 | model: ??? 26 | 27 | loss: 28 | _target_: mart.nn.Sum 29 | -------------------------------------------------------------------------------- /mart/configs/model/torchvision_retinanet.yaml: -------------------------------------------------------------------------------- 1 | # We simply wrap a torchvision object detection model for validation. 2 | defaults: 3 | - torchvision_object_detection 4 | 5 | # log all losses separately in training. 
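# (training_step_log maps a display name to a key produced by the training sequence,
# so each loss term below is logged as its own curve.)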
6 | training_step_log: 7 | loss_classifier: "losses_and_detections.training.classification" 8 | loss_box_reg: "losses_and_detections.training.bbox_regression" 9 | 10 | training_sequence: 11 | - preprocessor: ["input"] 12 | - losses_and_detections: ["preprocessor", "target"] 13 | - loss: 14 | # Sum up the losses. 15 | [ 16 | "losses_and_detections.training.classification", 17 | "losses_and_detections.training.bbox_regression", 18 | ] 19 | 20 | validation_sequence: 21 | - preprocessor: ["input"] 22 | - losses_and_detections: ["preprocessor", "target"] 23 | 24 | test_sequence: 25 | - preprocessor: ["input"] 26 | - losses_and_detections: ["preprocessor", "target"] 27 | 28 | modules: 29 | losses_and_detections: 30 | # _target_: mart.models.DualMode 31 | model: 32 | _target_: torchvision.models.detection.retinanet_resnet50_fpn 33 | num_classes: ??? 34 | -------------------------------------------------------------------------------- /mart/configs/optimization/adaptive_sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /optimizer@optimizer: sgd 4 | - lr_scheduler/scheduler: reduced_lr_on_plateau 5 | 6 | optimizer: ??? 7 | 8 | lr_scheduler: 9 | interval: step 10 | frequency: 1 11 | monitor: ??? 12 | -------------------------------------------------------------------------------- /mart/configs/optimization/lr_scheduler/scheduler/reduced_lr_on_plateau.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 2 | _partial_: true 3 | mode: min 4 | factor: 0.1 5 | patience: 10 6 | threshold: 0.0001 7 | threshold_mode: rel 8 | cooldown: 0 9 | min_lr: 0 10 | eps: 1e-08 11 | verbose: false 12 | -------------------------------------------------------------------------------- /mart/configs/optimization/super_convergence.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /optimizer@optimizer: sgd 4 | 5 | optimizer: ??? 6 | 7 | lr_scheduler: 8 | scheduler: 9 | _target_: torch.optim.lr_scheduler.OneCycleLR 10 | _partial_: true 11 | total_steps: ${trainer.max_steps} 12 | max_lr: ${model.optimizer.lr} 13 | anneal_strategy: cos 14 | interval: step 15 | frequency: 1 16 | -------------------------------------------------------------------------------- /mart/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.optim.OptimizerFactory 2 | optimizer: 3 | _target_: hydra.utils.get_method 4 | path: torch.optim.Adam 5 | lr: ??? 6 | betas: 7 | - 0.9 8 | - 0.999 9 | eps: 1e-08 10 | weight_decay: 0 11 | bias_decay: 0 12 | norm_decay: 0 13 | -------------------------------------------------------------------------------- /mart/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: mart.optim.OptimizerFactory 2 | optimizer: 3 | _target_: hydra.utils.get_method 4 | path: torch.optim.SGD 5 | lr: ??? 6 | momentum: 0 7 | weight_decay: 0 8 | bias_decay: 0 9 | norm_decay: 0 10 | ## You may simplify the config as below if you just want to set bias_decay and norm_decay to 0. 11 | ## LitModular and Adversary will wrap the optimizer with OptimizerFactory for you. 12 | # _target_: torch.optim.SGD 13 | # _partial_: true 14 | # lr: ???
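## With the simplified form, `_partial_: true` makes hydra.utils.instantiate return
## functools.partial(torch.optim.SGD, lr=...), which the wrapping factory can later call
## with the model parameters, e.g. `partial_optimizer(model.parameters())` (call site illustrative).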
15 | -------------------------------------------------------------------------------- /mart/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /mart/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | num_nodes: 1 8 | sync_batchnorm: True 9 | -------------------------------------------------------------------------------- /mart/configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | strategy: ddp_spawn 7 | -------------------------------------------------------------------------------- /mart/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | # max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # set True to ensure deterministic results 15 | # makes training slower but gives more reproducibility than just setting seeds 16 | deterministic: False 17 | 18 | # Disable PyTorch inference mode in val/test/predict, because we may run back-propagation for the adversary.
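# (Tensors created under torch.inference_mode() can never require grad, while torch.no_grad()
# can be locally reverted with torch.enable_grad(), which gradient-based attacks need at eval time.)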
19 | inference_mode: False 20 | -------------------------------------------------------------------------------- /mart/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /mart/configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /mart/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from .modular import * 8 | from .vision import * 9 | -------------------------------------------------------------------------------- /mart/datamodules/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils.imports import _HAS_FIFTYONE, _HAS_TORCHVISION 2 | 3 | if _HAS_TORCHVISION: 4 | from .coco import * 5 | 6 | if _HAS_TORCHVISION and _HAS_FIFTYONE: 7 | from .fiftyone import * 8 | -------------------------------------------------------------------------------- /mart/datamodules/vision/coco.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import os 8 | from typing import Any, Callable, List, Optional 9 | 10 | import numpy as np 11 | from torchvision.datasets.coco import CocoDetection as CocoDetection_ 12 | from torchvision.datasets.folder import default_loader 13 | 14 | __all__ = ["CocoDetection"] 15 | 16 | 17 | class CocoDetection(CocoDetection_): 18 | """Extra features: 19 | 1. Add image_id to the target dict; 20 | 2. Add file_name to the target dict; 21 | 3. Add ability to load multiple modalities 22 | 23 | Args: 24 | See torchvision.datasets.coco.CocoDetection 25 | 26 | modalities (list of strings): A list of subfolders under root from which to load modalities. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | annFile: str, 33 | transform: Optional[Callable] = None, 34 | target_transform: Optional[Callable] = None, 35 | transforms: Optional[Callable] = None, 36 | modalities: Optional[List[str]] = None, 37 | ) -> None: 38 | # CocoDetection doesn't support transform or target_transform because 39 | # we need to manipulate the input and target at the same time. 40 | assert transform is None 41 | assert target_transform is None 42 | 43 | super().__init__(root, annFile, transform, target_transform, transforms) 44 | 45 | self.modalities = modalities 46 | 47 | def _load_image(self, id: int) -> Any: 48 | if self.modalities is None: 49 | return super()._load_image(id) 50 | 51 | # Concatenate modalities into a single array (not PIL.Image). We do this to be 52 | # compatible with existing transforms, since transforms should make the same 53 | # transformation to each modality. 54 | path = self.coco.loadImgs(id)[0]["file_name"] 55 | modalities = [ 56 | default_loader(os.path.join(self.root, modality, path)) for modality in self.modalities 57 | ] 58 | 59 | # Create numpy.ndarray by stacking modalities along the channels axis. We use numpy 60 | # because PIL does not support multi-channel images.
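# For k RGB modalities the result has shape (H, W, 3 * k); torchvision.transforms.ToTensor
# accepts such HWC ndarrays and converts them to CHW float tensors.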
61 | image = np.concatenate([np.array(modality) for modality in modalities], axis=-1) 62 | 63 | return image 64 | 65 | def _load_target(self, id: int) -> Dict[str, Any]: 66 | annotations = super()._load_target(id) 67 | file_name = self.coco.loadImgs(id)[0]["file_name"] 68 | 69 | return {"image_id": id, "file_name": file_name, "annotations": annotations} 70 | 71 | 72 | # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203 73 | def collate_fn(batch): 74 | return tuple(zip(*batch)) 75 |
-------------------------------------------------------------------------------- /mart/generate_config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | import fire 10 | from omegaconf import OmegaConf 11 | 12 | from .utils.config import ( 13 | DEFAULT_CONFIG_DIR, 14 | DEFAULT_CONFIG_NAME, 15 | DEFAULT_VERSION_BASE, 16 | compose, 17 | ) 18 | 19 | 20 | def main( 21 | *overrides, 22 | version_base: str = DEFAULT_VERSION_BASE, 23 | config_dir: str = DEFAULT_CONFIG_DIR, 24 | config_name: str = DEFAULT_CONFIG_NAME, 25 | export_node: str | None = None, 26 | resolve: bool = False, 27 | sort_keys: bool = True, 28 | ): 29 | cfg = compose( 30 | *overrides, 31 | version_base=version_base, 32 | config_dir=config_dir, 33 | config_name=config_name, 34 | export_node=export_node, 35 | ) 36 | 37 | cfg_yaml = OmegaConf.to_yaml(cfg, resolve=resolve, sort_keys=sort_keys) 38 | 39 | # OmegaConf.to_yaml() already ends with `\n`. 40 | print(cfg_yaml, end="") 41 | 42 | 43 | if __name__ == "__main__": 44 | fire.Fire(main) 45 |
-------------------------------------------------------------------------------- /mart/models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from .modular import * # noqa: F403 8 | from .vision import * # noqa: F403 9 |
-------------------------------------------------------------------------------- /mart/models/vision/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from ...utils.imports import _HAS_TORCHVISION 8 | 9 | if _HAS_TORCHVISION: 10 | from .dual_mode import * # noqa: F403 11 |
-------------------------------------------------------------------------------- /mart/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .nn import * # noqa: F403 2 |
-------------------------------------------------------------------------------- /mart/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from .optimizer import * # noqa: F403 8 |
-------------------------------------------------------------------------------- /mart/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | import torch # noqa: E402 12 | 13 | __all__
= ["OptimizerFactory"] 14 | 15 | 16 | class OptimizerFactory: 17 | """Create optimizers compatible with LightningModule. 18 | 19 | Also supports decay parameters for bias and norm modules independently. 20 | """ 21 | 22 | def __init__(self, optimizer, **kwargs): 23 | weight_decay = kwargs.get("weight_decay", 0.0) 24 | 25 | self.bias_decay = kwargs.pop("bias_decay", weight_decay) 26 | self.norm_decay = kwargs.pop("norm_decay", weight_decay) 27 | self.optimizer = optimizer 28 | self.kwargs = kwargs 29 | 30 | def __call__(self, module): 31 | # Separate parameters into biases, norms, and weights 32 | bias_params = [] 33 | norm_params = [] 34 | weight_params = [] 35 | 36 | for param_name, param in module.named_parameters(): 37 | if not param.requires_grad: 38 | continue 39 | 40 | # Find module by name 41 | module_name = ".".join(param_name.split(".")[:-1]) 42 | _, param_module = next(filter(lambda nm: nm[0] == module_name, module.named_modules())) 43 | module_kind = param_module.__class__.__name__ 44 | 45 | if "Norm" in module_kind: 46 | assert len(param.shape) == 1 47 | norm_params.append(param) 48 | elif isinstance(param, torch.nn.UninitializedParameter): 49 | # Assume lazy parameters are weights 50 | weight_params.append(param) 51 | elif len(param.shape) == 1: 52 | bias_params.append(param) 53 | else: # Assume weights 54 | weight_params.append(param) 55 | 56 | # Set decay for bias and norm parameters 57 | params = [] 58 | if len(weight_params) > 0: 59 | params.append({"params": weight_params}) # use default weight decay 60 | if len(bias_params) > 0: 61 | params.append({"params": bias_params, "weight_decay": self.bias_decay}) 62 | if len(norm_params) > 0: 63 | params.append({"params": norm_params, "weight_decay": self.norm_decay}) 64 | 65 | return self.optimizer(params, **self.kwargs) 66 | -------------------------------------------------------------------------------- /mart/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/mart/tasks/__init__.py -------------------------------------------------------------------------------- /mart/tasks/lightning.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import hydra 4 | from lightning.pytorch import ( 5 | Callback, 6 | LightningDataModule, 7 | LightningModule, 8 | Trainer, 9 | seed_everything, 10 | ) 11 | from lightning.pytorch.loggers import Logger 12 | from omegaconf import DictConfig 13 | 14 | from mart import utils 15 | 16 | log = utils.get_pylogger(__name__) 17 | 18 | 19 | @utils.task_wrapper 20 | def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 21 | """Contains training pipeline. Instantiates all PyTorch Lightning objects from config. 22 | 23 | Args: 24 | cfg (DictConfig): Configuration composed by Hydra. 25 | 26 | Returns: 27 | Optional[float]: Metric score for hyperparameter optimization. 
28 | """ 29 | 30 | # Set seed for random number generators in pytorch, numpy and python.random 31 | if cfg.get("seed"): 32 | seed_everything(cfg.seed, workers=True) 33 | 34 | # Init lightning datamodule 35 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 36 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 37 | # Init lightning model 38 | log.info(f"Instantiating model <{cfg.model._target_}>") 39 | model: LightningModule = hydra.utils.instantiate(cfg.model) 40 | 41 | # Init lightning callbacks 42 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 43 | 44 | # Init lightning loggers 45 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 46 | 47 | # Init lightning trainer 48 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 49 | trainer: Trainer = hydra.utils.instantiate( 50 | cfg.trainer, 51 | callbacks=callbacks, 52 | logger=logger, 53 | ) 54 | 55 | # Send some parameters from config to all lightning loggers 56 | object_dict = { 57 | "cfg": cfg, 58 | "datamodule": datamodule, 59 | "model": model, 60 | "callbacks": callbacks, 61 | "logger": logger, 62 | "trainer": trainer, 63 | } 64 | 65 | if logger: 66 | log.info("Logging hyperparameters!") 67 | utils.log_hyperparameters(object_dict) 68 | 69 | # ckpt_path could be None if resume=null. 70 | ckpt_path = cfg.get("ckpt_path", None) 71 | 72 | # Train the model 73 | if cfg.get("fit"): 74 | log.info("Starting training!") 75 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 76 | ckpt_path = None # make sure trainer tests trained model 77 | 78 | train_metrics = trainer.callback_metrics 79 | 80 | # Evaluate model on test set, using the best model achieved during training 81 | if cfg.get("test"): 82 | log.info("Starting testing!") 83 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 84 | 85 | test_metrics = trainer.callback_metrics 86 | 87 | # merge train and test metrics 88 | metric_dict = {**train_metrics, **test_metrics} 89 | 90 | # Print path to best checkpoint 91 | if not cfg.trainer.get("fast_dev_run"): 92 | log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}") 93 | 94 | return metric_dict, object_dict 95 | -------------------------------------------------------------------------------- /mart/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from .batch_c15n import * # noqa: F403 8 | from .transforms import * # noqa: F403 9 | from .vision import * # noqa: F403 10 | -------------------------------------------------------------------------------- /mart/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import torch 8 | 9 | __all__ = ["Cat", "Permute", "Unsqueeze", "Squeeze", "Chunk", "TupleTransforms"] 10 | 11 | 12 | class Cat: 13 | def __init__(self, dim): 14 | self.dim = dim 15 | 16 | def __call__(self, tensors): 17 | return torch.cat(tensors, dim=self.dim) 18 | 19 | 20 | class Permute: 21 | def __init__(self, *dims): 22 | self.dims = dims 23 | 24 | def __call__(self, tensor): 25 | return tensor.permute(*self.dims) 26 | 27 | 28 | class Unsqueeze: 29 | def __init__(self, dim): 30 | self.dim = dim 31 | 32 | def __call__(self, tensor): 33 | return 
tensor.unsqueeze(self.dim) 34 | 35 | 36 | class Squeeze: 37 | def __init__(self, dim=None): 38 | self.dim = dim 39 | 40 | def __call__(self, tensor): 41 | return tensor.squeeze(self.dim) 42 | 43 | 44 | 45 | # FIXME: Change to Split transform 46 | class Chunk: 47 | def __init__(self, chunks, dim=0): 48 | self.chunks = chunks 49 | self.dim = dim 50 | 51 | def __call__(self, tensor): 52 | chunks = tensor.chunk(self.chunks, dim=self.dim) 53 | 54 | return [*chunks] # tuple -> list 55 | 56 | 57 | class TupleTransforms(torch.nn.Module): 58 | def __init__(self, transforms): 59 | super().__init__() 60 | 61 | self.transforms = transforms 62 | 63 | def forward(self, x_tuple): 64 | output_tuple = tuple(self.transforms(x) for x in x_tuple) 65 | return output_tuple 66 |
-------------------------------------------------------------------------------- /mart/transforms/vision/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | 8 | from ...utils.imports import _HAS_TORCHVISION 9 | 10 | if _HAS_TORCHVISION: 11 | from .objdet import * # noqa: F403 12 | from .transforms import * # noqa: F403 13 |
-------------------------------------------------------------------------------- /mart/transforms/vision/objdet/__init__.py: -------------------------------------------------------------------------------- 1 | from ....utils.imports import _HAS_PYCOCOTOOLS, _HAS_TORCHVISION 2 | 3 | if _HAS_TORCHVISION and _HAS_PYCOCOTOOLS: 4 | from .extended import * # noqa: F403 5 |
-------------------------------------------------------------------------------- /mart/transforms/vision/objdet/torchvision_ref.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pycocotools import mask as coco_mask 3 | from torchvision.transforms import functional as F 4 | 5 | __all__ = ["convert_coco_poly_to_mask", "ConvertCocoPolysToMask"] 6 | 7 | 8 | # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/transforms.py#L10 9 | def _flip_coco_person_keypoints(kps, width): 10 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 11 | flipped_data = kps[:, flip_inds] 12 | flipped_data[..., 0] = width - flipped_data[..., 0] 13 | # Maintain COCO convention that if visibility == 0, then x, y = 0 14 | inds = flipped_data[..., 2] == 0 15 | flipped_data[inds] = 0 16 | return flipped_data 17 | 18 | 19 | # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/coco_utils.py#L30 20 | def convert_coco_poly_to_mask(segmentations, height, width): 21 | masks = [] 22 | for polygons in segmentations: 23 | rles = coco_mask.frPyObjects(polygons, height, width) 24 | mask = coco_mask.decode(rles) 25 | if len(mask.shape) < 3: 26 | mask = mask[..., None] 27 | mask = torch.as_tensor(mask, dtype=torch.uint8) 28 | mask = mask.any(dim=2) 29 | masks.append(mask) 30 | if masks: 31 | masks = torch.stack(masks, dim=0) 32 | else: 33 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 34 | return masks 35 | 36 | 37 | # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/coco_utils.py#L47 38 | # Adapted to mart.datamodules.coco.CocoDetection by adding the "file_name" field.
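# Sketch of the data flow assumed here (inferred from the classes above): the target
# produced by mart.datamodules.vision.coco.CocoDetection,
#     {"image_id": int, "file_name": str, "annotations": [<raw COCO annotation dicts>]},
# is converted below into a torchvision-style target with "boxes" (xyxy), "labels",
# optional "masks"/"keypoints", plus "image_id", "file_name", "area", and "iscrowd".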
39 | class ConvertCocoPolysToMask: 40 | def __call__(self, image, target): 41 | w, h = F.get_image_size(image) 42 | 43 | file_name = target.pop("file_name", None) 44 | image_id = target["image_id"] 45 | image_id = torch.tensor([image_id]) 46 | 47 | anno = target["annotations"] 48 | 49 | anno = [obj for obj in anno if obj["iscrowd"] == 0] 50 | 51 | boxes = [obj["bbox"] for obj in anno] 52 | # guard against no boxes via resizing 53 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 54 | boxes[:, 2:] += boxes[:, :2] 55 | boxes[:, 0::2].clamp_(min=0, max=w) 56 | boxes[:, 1::2].clamp_(min=0, max=h) 57 | 58 | classes = [obj["category_id"] for obj in anno] 59 | classes = torch.tensor(classes, dtype=torch.int64) 60 | 61 | masks = None 62 | if len(anno) > 0 and "segmentation" in anno[0]: 63 | segmentations = [obj["segmentation"] for obj in anno] 64 | masks = convert_coco_poly_to_mask(segmentations, h, w) 65 | 66 | keypoints = None 67 | if len(anno) > 0 and "keypoints" in anno[0]: 68 | keypoints = [obj["keypoints"] for obj in anno] 69 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 70 | num_keypoints = keypoints.shape[0] 71 | if num_keypoints: 72 | keypoints = keypoints.view(num_keypoints, -1, 3) 73 | 74 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 75 | boxes = boxes[keep] 76 | classes = classes[keep] 77 | if masks is not None: 78 | masks = masks[keep] 79 | if keypoints is not None: 80 | keypoints = keypoints[keep] 81 | 82 | target = {} 83 | target["boxes"] = boxes 84 | target["labels"] = classes 85 | if masks is not None: 86 | target["masks"] = masks 87 | target["image_id"] = image_id 88 | target["file_name"] = file_name 89 | if keypoints is not None: 90 | target["keypoints"] = keypoints 91 | 92 | # for conversion to coco api 93 | area = torch.tensor([obj["area"] for obj in anno]) 94 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) 95 | target["area"] = area 96 | target["iscrowd"] = iscrowd 97 | 98 | return image, target 99 |
-------------------------------------------------------------------------------- /mart/transforms/vision/transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | __all__ = ["Denormalize"] 8 | 9 | from torchvision.transforms import transforms as T 10 | 11 | 12 | class Denormalize(T.Normalize): 13 | """Unnormalize using center and scale via the existing Normalize transform, such that: 14 | 15 | output = (input * scale + center) 16 | 17 | Args: 18 | center: value to center input by 19 | scale: value to scale input by 20 | """ 21 | 22 | def __init__(self, center, scale, inplace=False): 23 | mean = -center / scale 24 | std = 1 / scale 25 | 26 | super().__init__(mean, std, inplace=inplace) 27 |
-------------------------------------------------------------------------------- /mart/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Only import components without external dependency.
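# Submodules that need optional dependencies (e.g. lightning for .lightning,
# .pylogger, and .rich_utils) are guarded below, so `import mart.utils` still
# works in a minimal install.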
2 | from .adapters import * 3 | from .imports import _HAS_LIGHTNING 4 | from .monkey_patch import * 5 | from .optimization import * 6 | from .silent import * 7 | from .utils import * 8 | 9 | if _HAS_LIGHTNING: 10 | from .lightning import * 11 | from .pylogger import * 12 | from .rich_utils import * 13 |
-------------------------------------------------------------------------------- /mart/utils/adapters.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from typing import Any, Callable 8 | 9 | __all__ = ["CallableAdapter", "PartialInstanceWrapper"] 10 | 11 | 12 | class CallableAdapter: 13 | """Adapter to make an object callable.""" 14 | 15 | def __init__(self, instance, redirecting_fn): 16 | """ 17 | 18 | Args: 19 | instance (object): instance to make callable. 20 | redirecting_fn (str): name of the function that will be invoked in the `__call__` method. 21 | """ 22 | assert instance is not None 23 | assert redirecting_fn != "" 24 | 25 | self.instance = instance 26 | self.redirecting_fn = redirecting_fn 27 | 28 | def __call__(self, *args, **kwargs): 29 | """ 30 | 31 | Args: 32 | args (Any): values to use in the callable method. 33 | kwargs (Any): keyword values to use in the callable method. 34 | """ 35 | function = getattr(self.instance, self.redirecting_fn) 36 | 37 | assert callable(function) 38 | 39 | return function(*args, **kwargs) 40 | 41 | 42 | class PartialInstanceWrapper: 43 | """Make a partial class object callable.""" 44 | 45 | def __init__(self, partial: Callable, wrapper: Callable): 46 | """ 47 | 48 | Args: 49 | partial (Callable): A partial of a class object. 50 | wrapper (Callable): An adapter that creates the `__call__` method. 51 | """ 52 | self.partial = partial 53 | self.wrapper = wrapper 54 | 55 | def __call__(self, *args: Any, **kwargs: Any) -> Callable: 56 | # Turn the partial into a class instance. 57 | instance = self.partial(*args, **kwargs) 58 | # Make the instance callable. 59 | return self.wrapper(instance) 60 |
-------------------------------------------------------------------------------- /mart/utils/config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from __future__ import annotations 8 | 9 | import os 10 | 11 | import hydra 12 | from hydra import compose as hydra_compose 13 | from hydra import initialize_config_dir 14 | from lightning.pytorch.callbacks.callback import Callback 15 | from omegaconf import OmegaConf 16 | 17 | DEFAULT_VERSION_BASE = "1.2" 18 | DEFAULT_CONFIG_DIR = "." 19 | DEFAULT_CONFIG_NAME = "lightning.yaml" 20 | 21 | __all__ = ["compose", "instantiate", "Instantiator", "CallbackInstantiator"] 22 | 23 | 24 | def compose( 25 | *overrides, 26 | version_base: str = DEFAULT_VERSION_BASE, 27 | config_dir: str = DEFAULT_CONFIG_DIR, 28 | config_name: str = DEFAULT_CONFIG_NAME, 29 | export_node: str | None = None, 30 | ): 31 | # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. 32 | if not os.path.isabs(config_dir): 33 | config_dir = os.path.abspath(config_dir) 34 | 35 | # hydra.initialize_config_dir() requires an absolute path, 36 | # while hydra.initialize() searches paths relative to mart.
37 | with initialize_config_dir(version_base=version_base, config_dir=config_dir): 38 | cfg = hydra_compose(config_name=config_name, overrides=overrides) 39 | 40 | # Export a sub-tree. 41 | if export_node is not None: 42 | for key in export_node.split("."): 43 | cfg = cfg[key] 44 | 45 | return cfg 46 | 47 | 48 | def instantiate(cfg_path): 49 | """Instantiate an object from a Hydra yaml config file.""" 50 | config = OmegaConf.load(cfg_path) 51 | obj = hydra.utils.instantiate(config) 52 | return obj 53 | 54 | 55 | class Instantiator: 56 | def __new__(cls, cfg_path): 57 | return instantiate(cfg_path) 58 | 59 | 60 | class CallbackInstantiator(Callback): 61 | """Satisfies type checking for the Lightning Callback type.""" 62 | 63 | def __new__(cls, cfg_path): 64 | obj = instantiate(cfg_path) 65 | if isinstance(obj, Callback): 66 | return obj 67 | else: 68 | raise ValueError( 69 | f"Expected to instantiate a Lightning Callback from {cfg_path}, but got {type(obj)} instead." 70 | ) 71 |
-------------------------------------------------------------------------------- /mart/utils/imports.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import logging 8 | from importlib.util import find_spec 9 | 10 | # Avoid importing .pylogger when checking imports before running other code. 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def has(module_name): 15 | module = find_spec(module_name) 16 | if module is None: 17 | logger.warning( 18 | f"{module_name} is not installed, so some features in MART are unavailable." 19 | ) 20 | return False 21 | else: 22 | return True 23 | 24 | 25 | # Do not forget to add dependency checks on CI in `tests/test_dependency.py` 26 | _HAS_FIFTYONE = has("fiftyone") 27 | _HAS_TORCHVISION = has("torchvision") 28 | _HAS_TIMM = has("timm") 29 | _HAS_PYCOCOTOOLS = has("pycocotools") 30 | _HAS_LIGHTNING = has("lightning") 31 |
-------------------------------------------------------------------------------- /mart/utils/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import logging 8 | 9 | __all__ = ["MonkeyPatch"] 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MonkeyPatch: 15 | """Temporarily replace a module's object value, i.e., its functionality.""" 16 | 17 | def __init__(self, obj, name, value, verbose=False): 18 | self.obj = obj 19 | self.name = name 20 | self.value = value 21 | self.verbose = verbose 22 | 23 | def __enter__(self): 24 | self.orig_value = getattr(self.obj, self.name) 25 | 26 | if self.orig_value == self.value: 27 | return 28 | 29 | if self.verbose: 30 | logger.info("Monkey patching %s to %s", self.orig_value, self.value) 31 | 32 | setattr(self.obj, self.name, self.value) 33 | 34 | def __exit__(self, exc_type, exc_value, traceback): 35 | if self.orig_value == self.value: 36 | return 37 | 38 | if self.verbose: 39 | logger.info("Reverting monkey patch on %s", self.orig_value) 40 | 41 | setattr(self.obj, self.name, self.orig_value) 42 |
-------------------------------------------------------------------------------- /mart/utils/optimization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | 8 |
def configure_optimizers(module, optimizer, lr_scheduler): 9 | config = {} 10 | config["optimizer"] = optimizer(module) 11 | 12 | if lr_scheduler is not None: 13 | # FIXME: I don't think this actually works correctly, but we don't have an example of an lr_scheduler that is not a DictConfig 14 | if "scheduler" in lr_scheduler: 15 | config["lr_scheduler"] = dict(lr_scheduler) 16 | config["lr_scheduler"]["scheduler"] = config["lr_scheduler"]["scheduler"]( 17 | config["optimizer"] 18 | ) 19 | else: 20 | config["lr_scheduler"] = lr_scheduler(config["optimizer"]) 21 | 22 | return config 23 |
-------------------------------------------------------------------------------- /mart/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | __all__ = ["get_pylogger"] 6 | 7 | 8 | def get_pylogger(name=__name__) -> logging.Logger: 9 | """Initializes a multi-GPU-friendly python command line logger.""" 10 | 11 | logger = logging.getLogger(name) 12 | 13 | # this ensures all logging levels get marked with the rank zero decorator 14 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 15 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 16 | for level in logging_levels: 17 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 18 | 19 | return logger 20 |
-------------------------------------------------------------------------------- /mart/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from mart.utils import pylogger 13 | 14 | __all__ = ["enforce_tags", "print_config_tree"] 15 | 16 | log = pylogger.get_pylogger(__name__) 17 | 18 | 19 | @rank_zero_only 20 | def print_config_tree( 21 | cfg: DictConfig, 22 | print_order: Sequence[str] = ( 23 | "datamodule", 24 | "model", 25 | "callbacks", 26 | "logger", 27 | "trainer", 28 | "paths", 29 | "extras", 30 | ), 31 | resolve: bool = False, 32 | save_to_file: bool = False, 33 | ) -> None: 34 | """Prints content of DictConfig using Rich library and its tree structure. 35 | 36 | Args: 37 | cfg (DictConfig): Configuration composed by Hydra. 38 | print_order (Sequence[str], optional): Determines in what order config components are printed. 39 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 40 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 41 | """ 42 | 43 | style = "dim" 44 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 45 | 46 | queue = [] 47 | 48 | # add fields from `print_order` to queue 49 | for field in print_order: 50 | queue.append(field) if field in cfg else log.warning( 51 | f"Field '{field}' not found in config. Skipping '{field}' config printing..."
52 | ) 53 | 54 | # add all the other fields to queue (not specified in `print_order`) 55 | for field in cfg: 56 | if field not in queue: 57 | queue.append(field) 58 | 59 | # generate config tree from queue 60 | for field in queue: 61 | branch = tree.add(field, style=style, guide_style=style) 62 | 63 | config_group = cfg[field] 64 | if isinstance(config_group, DictConfig): 65 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 66 | else: 67 | branch_content = str(config_group) 68 | 69 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 70 | 71 | # print config tree 72 | rich.print(tree) 73 | 74 | # save config tree to file 75 | if save_to_file: 76 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 77 | rich.print(tree, file=file) 78 | 79 | 80 | @rank_zero_only 81 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 82 | """Prompts user to input tags from command line if no tags are provided in config.""" 83 | 84 | if not cfg.get("tags"): 85 | if "id" in HydraConfig().cfg.hydra.job: 86 | raise ValueError("Specify tags before launching a multirun!") 87 | 88 | log.warning("No tags provided in config. Prompting user to input tags...") 89 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 90 | tags = [t.strip() for t in tags.split(",") if t != ""] 91 | 92 | with open_dict(cfg): 93 | cfg.tags = tags 94 | 95 | log.info(f"Tags: {cfg.tags}") 96 | 97 | if save_to_file: 98 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 99 | rich.print(cfg.tags, file=file) 100 | 101 | 102 | if __name__ == "__main__": 103 | from hydra import compose, initialize 104 | 105 | with initialize(version_base="1.2", config_path="../../configs"): 106 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 107 | print_config_tree(cfg, resolve=False, save_to_file=False) 108 |
-------------------------------------------------------------------------------- /mart/utils/silent.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import logging 8 | from contextlib import ContextDecorator 9 | 10 | __all__ = ["silent"] 11 | 12 | 13 | class silent(ContextDecorator): 14 | """Suppress logging.""" 15 | 16 | DEFAULT_NAMES = [ 17 | "lightning.pytorch.utilities.rank_zero", 18 | "lightning.pytorch.accelerators.cuda", 19 | ] 20 | 21 | def __init__(self, names=None): 22 | if names is None: 23 | names = silent.DEFAULT_NAMES 24 | 25 | self.loggers = [logging.getLogger(name) for name in names] 26 | 27 | def __enter__(self): 28 | for logger in self.loggers: 29 | logger.propagate = False 30 | 31 | def __exit__(self, exc_type, exc_value, traceback): 32 | for logger in self.loggers: 33 | logger.propagate = True # restore propagation on exit 34 |
-------------------------------------------------------------------------------- /mart/utils/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | __all__ = [ 4 | "flatten_dict", 5 | ] 6 | 7 | 8 | def flatten_dict(d, delimiter="."): 9 | def get_dottedpath_items(d: dict, parent: Optional[str] = None): 10 | """Get pairs of the dotted path and the value from a nested dictionary.""" 11 | for name, value in d.items(): 12 | path = f"{parent}{delimiter}{name}" if parent else name 13 | if isinstance(value, dict): 14 | yield from get_dottedpath_items(value, parent=path) 15 | else: 16 | yield path,
value 17 | 18 | ret = {} 19 | for key, value in get_dottedpath_items(d): 20 | if key in ret: 21 | raise KeyError(f"Key collision when flattening a dictionary: {key}") 22 | ret[key] = value 23 | 24 | return ret 25 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mart" 3 | version = "0.7.0a0" 4 | description = "Modular Adversarial Robustness Toolkit" 5 | readme = "README.md" 6 | license = {file = "LICENSE"} 7 | authors = [ 8 | { name = "Intel Corporation", email = "weilin.xu@intel.com" }, 9 | ] 10 | 11 | requires-python = ">=3.9" 12 | 13 | dependencies = [ 14 | ] 15 | 16 | [project.urls] 17 | Source = "https://github.com/IntelLabs/MART" 18 | 19 | [project.scripts] 20 | mart = "mart.__main__:main" 21 | 22 | [project.optional-dependencies] 23 | 24 | # These are required dependencies, but we make it flexible for users to adjust. 25 | core = [ 26 | # --------- hydra --------- # 27 | "hydra-core ~= 1.2.0", 28 | "hydra-colorlog ~= 1.2.0", 29 | "hydra-optuna-sweeper ~= 1.2.0", 30 | 31 | # --------- basics --------- # 32 | "pyrootutils ~= 1.0.4", # standardizing the project root setup 33 | "rich ~= 13.5.2", # beautiful text formatting in terminal 34 | "fire ~= 0.5.0", # automated CLI 35 | 36 | # ---- PyTorch ecosystem --- # 37 | "torch >= 2.0.1", 38 | "lightning[extra] ~= 2.1.4", # Full functionality including TensorboardX. 39 | "torchmetrics == 1.0.1", 40 | ] 41 | 42 | vision = [ 43 | "torchvision >= 0.15.2", 44 | "timm ~= 0.6.11", # pytorch image models 45 | ] 46 | 47 | objdet = [ 48 | "pycocotools ~= 2.0.6", # data format for object detection. 49 | "fiftyone ~= 0.21.4", # visualization for object detection 50 | ] 51 | 52 | # Comment out loggers to avoid lengthy dependency resolution. 53 | # It is rare that users need more than one logger. 54 | # And lightning[extra] already includes TensorboardX. 
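# For example (an illustrative command, not from this repo), uncommenting one of
# them and reinstalling the extra would enable that logger:
#   pip install -e ".[core,loggers]"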
55 | loggers = [ 56 | # "wandb", 57 | # "neptune", 58 | # "mlflow", 59 | # "comet-ml", 60 | ] 61 | 62 | developer = [ 63 | "pre-commit ~= 4.2.0", # hooks for applying linters on commit 64 | "pytest ~= 7.2.0", # tests 65 | "sh ~= 1.14.3", # for running bash commands in some tests 66 | "wheel", # support setup.py 67 | "pytest-cov[toml]", 68 | "protobuf==3.20.0" 69 | ] 70 | 71 | full = [ 72 | "mart[core,vision,objdet,loggers,developer]", 73 | ] 74 | 75 | extras = [ 76 | ] 77 | 78 | [tool.setuptools] 79 | zip-safe = false 80 | 81 | [tool.setuptools.packages.find] 82 | include = ["mart*", "hydra_plugins*"] 83 | 84 | [tool.setuptools.package-data] 85 | "*" = ["*.yaml"] 86 | 87 | [tool.pytest.ini_options] 88 | addopts = [ 89 | "--color=yes", 90 | "--durations=0", 91 | "--strict-markers", 92 | "--doctest-modules", 93 | ] 94 | filterwarnings = [ 95 | "ignore::DeprecationWarning", 96 | "ignore::UserWarning", 97 | ] 98 | log_cli = "True" 99 | markers = [ 100 | "slow: slow tests", 101 | ] 102 | minversion = "6.0" 103 | testpaths = "tests/" 104 | 105 | [tool.coverage.report] 106 | exclude_lines = [ 107 | "pragma: nocover", 108 | "raise NotImplementedError", 109 | "raise NotImplementedError()", 110 | "if __name__ == .__main__.:", 111 | ] 112 |
-------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python -m mart trainer.max_epochs=5 logger=csv 6 | 7 | python -m mart trainer.max_epochs=10 logger=csv 8 |
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/tests/__init__.py
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import os 8 | from pathlib import Path 9 | 10 | import pyrootutils 11 | import pytest 12 | import torch 13 | from hydra import compose, initialize 14 | from hydra.core.global_hydra import GlobalHydra 15 | from omegaconf import DictConfig, open_dict 16 | 17 | from mart.utils.imports import _HAS_PYCOCOTOOLS, _HAS_TIMM, _HAS_TORCHVISION 18 | 19 | root = Path(os.getcwd()) 20 | pyrootutils.set_root(path=root, dotenv=True, pythonpath=True) 21 | 22 | experiments_require_torchvision = [ 23 | "CIFAR10_CNN", 24 | "CIFAR10_CNN_Adv", 25 | ] 26 | 27 | experiments_require_torchvision_pycocotools = [ 28 | "COCO_TorchvisionFasterRCNN", 29 | "COCO_TorchvisionFasterRCNN_Adv", 30 | "COCO_TorchvisionRetinaNet", 31 | ] 32 | 33 | experiments_require_torchvision_and_timm = [ 34 | "ImageNet_Timm", 35 | ] 36 | 37 | # Only test experiments whose required packages are installed in the local environment.
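# For instance (illustrative), an environment without timm and pycocotools only
# parametrizes the CIFAR10 experiments below.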
38 | experiments_names = [] 39 | if _HAS_TORCHVISION: 40 | experiments_names += experiments_require_torchvision 41 | if _HAS_TORCHVISION and _HAS_PYCOCOTOOLS: 42 | experiments_names += experiments_require_torchvision_pycocotools 43 | if _HAS_TIMM and _HAS_TORCHVISION: 44 | experiments_names += experiments_require_torchvision_and_timm 45 | 46 | 47 | # Loads the configuration file from a given experiment 48 | def get_cfg(experiment): 49 | with initialize(version_base="1.2", config_path="../mart/configs"): 50 | params = "experiment=" + experiment 51 | cfg = compose(config_name="lightning.yaml", return_hydra_config=True, overrides=[params]) 52 | return cfg 53 | 54 | 55 | @pytest.fixture(scope="function", params=experiments_names) 56 | def cfg_experiment(request) -> DictConfig: 57 | cfg = get_cfg(request.param) 58 | 59 | yield cfg 60 | 61 | GlobalHydra.instance().clear() 62 | 63 | 64 | @pytest.fixture(scope="function") 65 | def input_data(): 66 | image_size = (3, 32, 32) 67 | return torch.randint(0, 256, image_size, dtype=torch.float) 68 | 69 | 70 | @pytest.fixture(scope="function") 71 | def target_data(): 72 | image_size = (3, 32, 32) 73 | return {"perturbable_mask": torch.ones(*image_size), "file_name": "test.jpg"} 74 | 75 | 76 | @pytest.fixture(scope="function") 77 | def perturbation(): 78 | torch.manual_seed(0) 79 | perturbation = torch.randint(0, 256, (3, 32, 32), dtype=torch.float) 80 | return perturbation 81 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/MART/c86a059a9177b26e1fb0404216145059f826eddb/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import itertools 8 | import json 9 | import pathlib 10 | 11 | import PIL 12 | import torch 13 | 14 | 15 | def create_image_file(root, name, size, **kwargs): 16 | """Create an image file from random data. 17 | Reference: https://github.com/pytorch/vision/blob/7b8a6db7f450e70a1e0fb07e07b30dda6a7e6e1c/test/datasets_utils.py#L675 18 | 19 | Args: 20 | root (str): Root for the images that will be created. 21 | name (str): File name. 22 | size (Sequence[int]): Size of the image that represents the ``(num_channels, height, width)``. 23 | Returns: 24 | pathlib.Path: Path to the created image file. 25 | """ 26 | 27 | image = torch.randint(0, 256, size, dtype=torch.uint8) 28 | file = pathlib.Path(root) / name 29 | 30 | # torch (num_channels x height x width) -> PIL (width x height x num_channels) 31 | image = image.permute(2, 1, 0) 32 | PIL.Image.fromarray(image.numpy()).save(file, **kwargs) 33 | return file 34 | 35 | 36 | def create_image_folder(root, name, file_name_fn, num_examples, size, **kwargs): 37 | """Create a folder of random images. 38 | Reference: https://github.com/pytorch/vision/blob/7b8a6db7f450e70a1e0fb07e07b30dda6a7e6e1c/test/datasets_utils.py#L711 39 | 40 | Args: 41 | root (str): Root directory the image folder will be placed in. 42 | name (str): Name of the image folder. 43 | file_name_fn (Callable[[int], str]): Should return a file name if called with the file index. 44 | num_examples (int): Number of images to create. 45 | size (Sequence[int]): Size of the images. 
46 | Returns: 47 | List[pathlib.Path]: Paths to all created image files. 48 | """ 49 | root = pathlib.Path(root) / name 50 | pathlib.Path.mkdir(root, parents=True, exist_ok=True) 51 | 52 | created_files = [ 53 | create_image_file( 54 | root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs 55 | ) 56 | for idx in range(num_examples) 57 | ] 58 | 59 | return created_files 60 | 61 | 62 | def mask_to_rle(mask): 63 | """Convert mask to RLE. 64 | 65 | Args: 66 | mask (numpy.array): Binary mask. 67 | Returns: 68 | dict: RLE representation. 69 | """ 70 | rle = {"counts": [], "size": list(mask.shape)} 71 | counts = rle.get("counts") 72 | for i, (value, elements) in enumerate(itertools.groupby(mask.ravel(order="F"))): 73 | if i == 0 and value == 1: 74 | counts.append(0) 75 | counts.append(len(list(elements))) 76 | return rle 77 | 78 | 79 | def create_json(root, name, content): 80 | """Creates a JSON file with dataset annotations. 81 | 82 | Args: 83 | root (str): Directory where the annotations will be stored. 84 | name (str): The annotation's filename. 85 | content (dict): Dictionary with the annotation's content. 86 | """ 87 | file = pathlib.Path(root) / name 88 | with open(file, "w") as fh: 89 | json.dump(content, fh) 90 | return 91 | 92 | 93 | def combinations_grid(**kwargs): 94 | """Generates an array of parameter combinations. 95 | 96 | Reference: https://github.com/pytorch/vision/blob/7b8a6db7f450e70a1e0fb07e07b30dda6a7e6e1c/test/datasets_utils.py#L172 97 | """ 98 | return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] 99 |
-------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from lightning.fabric.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment.""" 9 | try: 10 | return pkg_resources.require(package_name) is not None 11 | except pkg_resources.DistributionNotFound: 12 | return False 13 | 14 | 15 | _TPU_AVAILABLE = TPUAccelerator.is_available() 16 | 17 | _IS_WINDOWS = platform.system() == "Windows" 18 | 19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 20 | 21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 23 | 24 | _WANDB_AVAILABLE = _package_available("wandb") 25 | _NEPTUNE_AVAILABLE = _package_available("neptune") 26 | _COMET_AVAILABLE = _package_available("comet_ml") 27 | _MLFLOW_AVAILABLE = _package_available("mlflow") 28 |
-------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]): 12 | """Default method for executing shell commands with pytest and the sh package.""" 13 | try: 14 | # Return stdout to help debug failed tests. 15 | return sh.python(command) 16 | except sh.ErrorReturnCode as e: 17 | msg = e.stderr.decode() 18 | # The error message could be empty.
19 | pytest.fail(msg=msg) 20 |
-------------------------------------------------------------------------------- /tests/test_adversary_connector.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from itertools import cycle 8 | from unittest.mock import Mock 9 | 10 | from lightning.pytorch import LightningModule, Trainer 11 | 12 | from mart.callbacks import AdversaryConnector 13 | from mart.transforms import TupleBatchC15n 14 | 15 | 16 | def test_adversary_connector_callback(input_data, target_data, perturbation): 17 | batch = (input_data, target_data) 18 | input_adv = input_data + perturbation 19 | 20 | def adversary_fit(input, target, *, model): 21 | model(input, target) 22 | 23 | adversary = Mock(return_value=(input_adv, target_data), fit=adversary_fit) 24 | batch_c15n = TupleBatchC15n() 25 | callback = AdversaryConnector(test_adversary=adversary, batch_c15n=batch_c15n) 26 | trainer = Trainer( 27 | accelerator="cpu", 28 | devices=1, 29 | limit_test_batches=1, 30 | callbacks=callback, 31 | num_sanity_val_steps=0, 32 | logger=[], 33 | enable_model_summary=False, 34 | enable_checkpointing=False, 35 | ) 36 | 37 | # Call attack_step() if defined, instead of training_step() 38 | # `model` must be a `LightningModule` 39 | model_attack = LightningModule() 40 | # Trick PL into thinking that test_step is overridden. 41 | model_attack.test_step = Mock(wraps=lambda *args: None) 42 | model_attack.attack_step = Mock(wraps=lambda *args: None) 43 | model_attack.training_step = Mock(wraps=lambda *args: None) 44 | trainer.test(model_attack, dataloaders=cycle([batch])) 45 | model_attack.attack_step.assert_called_once() 46 | # training_step is not called when there is attack_step. 47 | model_attack.training_step.assert_not_called() 48 | 49 | # Call training_step(), because attack_step() is not defined.
50 | model_training = LightningModule() 51 | model_training.test_step = Mock(wraps=lambda *args: None) 52 | model_training.training_step = Mock(wraps=lambda *args: None) 53 | trainer.test(model_training, dataloaders=cycle([batch])) 54 | model_training.training_step.assert_called_once() 55 | -------------------------------------------------------------------------------- /tests/test_composer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from unittest.mock import Mock 8 | 9 | import torch 10 | 11 | from mart.attack.composer import Additive, Composer, Mask, Overlay 12 | 13 | 14 | def test_additive_composer_forward(input_data, target_data, perturbation): 15 | perturber = Mock(return_value=perturbation) 16 | modules = {"additive": Additive()} 17 | sequence = {"seq010": {"additive": ["perturbation", "input"]}} 18 | composer = Composer(perturber=perturber, modules=modules, sequence=sequence) 19 | 20 | output = composer(input=input_data, target=target_data) 21 | expected_output = input_data + perturbation 22 | torch.testing.assert_close(output, expected_output, equal_nan=True) 23 | 24 | 25 | def test_overlay_composer_forward(input_data, target_data, perturbation): 26 | perturber = Mock(return_value=perturbation) 27 | modules = {"overlay": Overlay()} 28 | sequence = {"seq010": {"overlay": ["perturbation", "input", "target.perturbable_mask"]}} 29 | composer = Composer(perturber=perturber, modules=modules, sequence=sequence) 30 | 31 | output = composer(input=input_data, target=target_data) 32 | mask = target_data["perturbable_mask"] 33 | mask = mask.to(input_data) 34 | expected_output = input_data * (1 - mask) + perturbation 35 | torch.testing.assert_close(output, expected_output, equal_nan=True) 36 | 37 | 38 | def test_mask_additive_composer_forward(): 39 | input = torch.zeros((2, 2)) 40 | perturbation = torch.ones((2, 2)) 41 | target = {"perturbable_mask": torch.eye(2)} 42 | expected_output = torch.eye(2) 43 | 44 | perturber = Mock(return_value=perturbation) 45 | modules = {"mask": Mask(), "additive": Additive()} 46 | sequence = { 47 | "seq010": {"mask": ["perturbation", "target.perturbable_mask"]}, 48 | "seq020": {"additive": ["mask", "input"]}, 49 | } 50 | composer = Composer(perturber=perturber, modules=modules, sequence=sequence) 51 | 52 | output = composer(input=input, target=target) 53 | torch.testing.assert_close(output, expected_output, equal_nan=True) 54 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from unittest.mock import Mock, patch 8 | 9 | import hydra 10 | import pytest 11 | from hydra.core.hydra_config import HydraConfig 12 | from omegaconf import DictConfig 13 | 14 | 15 | @patch("torchvision.datasets.imagenet.ImageNet.__init__") 16 | @patch("mart.datamodules.coco.CocoDetection_.__init__") 17 | @patch("torchvision.datasets.CIFAR10.__init__") 18 | def test_experiment_config( 19 | mock_cifar10: Mock, mock_coco: Mock, mock_imagenet: Mock, cfg_experiment: DictConfig 20 | ): 21 | assert cfg_experiment 22 | assert cfg_experiment.datamodule 23 | assert cfg_experiment.model 24 | assert cfg_experiment.trainer 25 | 26 | # setup mocks 27 | mock_cifar10.return_value = None 28 | 
mock_imagenet.return_value = None 29 | mock_coco.return_value = None 30 | 31 | HydraConfig().set_config(cfg_experiment) 32 | 33 | hydra.utils.instantiate(cfg_experiment.datamodule) 34 | hydra.utils.instantiate(cfg_experiment.model) 35 | hydra.utils.instantiate(cfg_experiment.trainer) 36 |
-------------------------------------------------------------------------------- /tests/test_dependency.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import os 8 | 9 | from mart.utils.imports import ( 10 | _HAS_FIFTYONE, 11 | _HAS_LIGHTNING, 12 | _HAS_PYCOCOTOOLS, 13 | _HAS_TIMM, 14 | _HAS_TORCHVISION, 15 | ) 16 | 17 | 18 | def test_dependency_on_ci(): 19 | if os.getenv("CI") == "true": 20 | assert ( 21 | _HAS_FIFTYONE 22 | and _HAS_TIMM 23 | and _HAS_PYCOCOTOOLS 24 | and _HAS_TORCHVISION 25 | and _HAS_LIGHTNING 26 | ), "Dependencies are incomplete on CI, so some tests would be skipped." 27 |
-------------------------------------------------------------------------------- /tests/test_enforcer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import pytest 8 | import torch 9 | 10 | from mart.attack.enforcer import ConstraintViolated, Enforcer, Integer, Lp, Mask, Range 11 | 12 | 13 | def test_constraint_range(): 14 | input = torch.tensor([0, 0, 0]) 15 | target = None 16 | 17 | constraint = Range(min=0, max=255) 18 | 19 | perturbation = torch.tensor([0, 128, 255]) 20 | constraint(input + perturbation, input=input, target=target) 21 | 22 | with pytest.raises(ConstraintViolated): 23 | constraint(input + torch.tensor([0, -1, 255]), input=input, target=target) 24 | 25 | with pytest.raises(ConstraintViolated): 26 | constraint(input + torch.tensor([0, 1, 256]), input=input, target=target) 27 | 28 | 29 | def test_constraint_l2(): 30 | input = torch.zeros((3, 10, 10)) 31 | batch_input = torch.stack((input, input)) 32 | 33 | constraint = Lp(eps=17.33, p=2, dim=[-1, -2, -3]) 34 | target = None 35 | 36 | # (3*10*10)**0.5 = 17.3205 37 | perturbation = torch.ones((3, 10, 10)) 38 | constraint(input + perturbation, input=input, target=target) 39 | constraint(batch_input + perturbation, input=batch_input, target=target) 40 | 41 | for inp in (input, batch_input): 42 | with pytest.raises(ConstraintViolated): 43 | constraint(batch_input + perturbation * 2, input=inp, target=target) 44 | 45 | 46 | def test_constraint_integer(): 47 | input, target = None, None 48 | 49 | constraint = Integer() 50 | 51 | input_adv = torch.tensor([1.0, 2.0]) 52 | constraint(input_adv, input=input, target=target) 53 | 54 | input_adv = torch.tensor([1.0, 2.001]) 55 | with pytest.raises(ConstraintViolated): 56 | constraint(input_adv, input=input, target=target) 57 | 58 | 59 | def test_constraint_mask(): 60 | input = torch.zeros((3, 2, 2)) 61 | perturbation = torch.ones((3, 2, 2)) 62 | mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]]) 63 | target = {"perturbable_mask": mask} 64 | 65 | constraint = Mask() 66 | 67 | constraint(input + perturbation * mask, input=input, target=target) 68 | with pytest.raises(ConstraintViolated): 69 | constraint(input + perturbation, input=input, target=target) 70 | 71 | 72 | def test_enforcer_non_modality(): 73 | enforcer = Enforcer(constraints={"range":
Range(min=0, max=255)}) 74 | 75 | input = torch.tensor([0, 0, 0]) 76 | perturbation = torch.tensor([0, 128, 255]) 77 | input_adv = input + perturbation 78 | target = None 79 | 80 | # tensor input. 81 | enforcer(input_adv, input=input, target=target) 82 | # list of tensor input. 83 | enforcer([input_adv], input=[input], target=[target]) 84 | # tuple of tensor input. 85 | enforcer((input_adv,), input=(input,), target=(target,)) 86 | 87 | perturbation = torch.tensor([0, -1, 255]) 88 | input_adv = input + perturbation 89 | 90 | with pytest.raises(ConstraintViolated): 91 | enforcer(input_adv, input=input, target=target) 92 | 93 | with pytest.raises(ConstraintViolated): 94 | enforcer([input_adv], input=[input], target=[target]) 95 | 96 | with pytest.raises(ConstraintViolated): 97 | enforcer((input_adv,), input=(input,), target=(target,)) 98 | 99 | 100 | # def test_enforcer_modality(): 101 | # # Assume a rgb modality. 102 | # enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) 103 | # 104 | # input = torch.tensor([0, 0, 0]) 105 | # perturbation = torch.tensor([0, 128, 255]) 106 | # input_adv = input + perturbation 107 | # target = None 108 | # 109 | # # Dictionary input. 110 | # enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) 111 | # # List of dictionary input. 112 | # enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) 113 | # # Tuple of dictionary input. 114 | # enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) 115 | # 116 | # perturbation = torch.tensor([0, -1, 255]) 117 | # input_adv = input + perturbation 118 | # 119 | # with pytest.raises(ConstraintViolated): 120 | # enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) 121 | # 122 | # with pytest.raises(ConstraintViolated): 123 | # enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) 124 | # 125 | # with pytest.raises(ConstraintViolated): 126 | # enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) 127 | -------------------------------------------------------------------------------- /tests/test_gradient.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import pytest 8 | import torch 9 | 10 | from mart.attack.gradient_modifier import LpNormalizer, Sign 11 | 12 | 13 | def test_gradient_sign(input_data): 14 | # Don't share input_data with other tests, because the gradient would be changed. 15 | input_data = torch.tensor([1.0, 2.0, 3.0]) 16 | input_data.grad = torch.tensor([-1.0, 3.0, 0.0]) 17 | 18 | grad_modifier = Sign() 19 | grad_modifier(input_data) 20 | expected_grad = torch.tensor([-1.0, 1.0, 0.0]) 21 | torch.testing.assert_close(input_data.grad, expected_grad) 22 | 23 | 24 | def test_gradient_lp_normalizer(): 25 | # Don't share input_data with other tests, because the gradient would be changed. 
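# Worked example: with p=1 the modifier divides the gradient by its L1 norm,
# |-1| + |3| + |0| = 4, so [-1.0, 3.0, 0.0] becomes [-0.25, 0.75, 0.0] below.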
26 | input_data = torch.tensor([1.0, 2.0, 3.0]) 27 | input_data.grad = torch.tensor([-1.0, 3.0, 0.0]) 28 | 29 | p = 1 30 | grad_modifier = LpNormalizer(p) 31 | grad_modifier(input_data) 32 | expected_grad = torch.tensor([-0.25, 0.75, 0.0]) 33 | torch.testing.assert_close(input_data.grad, expected_grad) 34 | -------------------------------------------------------------------------------- /tests/test_initializer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import pytest 8 | import torch 9 | 10 | from mart.attack.initializer import Constant, Uniform, UniformLp 11 | 12 | 13 | def test_constant_initializer(perturbation): 14 | constant = 1 15 | initializer = Constant(constant) 16 | initializer(perturbation) 17 | expected_perturbation = torch.ones(perturbation.shape) 18 | torch.testing.assert_close(perturbation, expected_perturbation) 19 | 20 | 21 | def test_uniform_initializer(perturbation): 22 | min = 0 23 | max = 100 24 | 25 | initializer = Uniform(min, max) 26 | initializer(perturbation) 27 | 28 | assert torch.max(perturbation) <= max 29 | assert torch.min(perturbation) >= min 30 | 31 | 32 | @pytest.mark.parametrize("p", [1, torch.inf]) 33 | def test_uniform_lp_initializer(p, perturbation): 34 | eps = 10 35 | 36 | initializer = UniformLp(eps, p) 37 | initializer(perturbation) 38 | 39 | assert torch.max(perturbation) <= eps 40 | assert torch.min(perturbation) >= -eps 41 | -------------------------------------------------------------------------------- /tests/test_perturber.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | from unittest.mock import Mock 8 | 9 | import pytest 10 | 11 | from mart.attack import Perturber 12 | 13 | 14 | def test_forward(input_data, target_data): 15 | initializer = Mock() 16 | projector = Mock() 17 | 18 | perturber = Perturber(initializer=initializer, projector=projector) 19 | 20 | perturber.configure_perturbation(input_data) 21 | 22 | output = perturber(input=input_data, target=target_data) 23 | 24 | initializer.assert_called_once() 25 | projector.assert_called_once() 26 | 27 | 28 | def test_misconfiguration(input_data, target_data): 29 | initializer = Mock() 30 | projector = Mock() 31 | 32 | perturber = Perturber(initializer=initializer, projector=projector) 33 | 34 | with pytest.raises(RuntimeError): 35 | perturber(input=input_data, target=target_data) 36 | 37 | with pytest.raises(RuntimeError): 38 | perturber.parameters() 39 | 40 | 41 | def test_configure_perturbation(input_data, target_data): 42 | initializer = Mock() 43 | projector = Mock() 44 | 45 | perturber = Perturber(initializer=initializer, projector=projector) 46 | 47 | perturber.configure_perturbation(input_data) 48 | perturber.configure_perturbation(input_data) 49 | perturber.configure_perturbation(input_data[:, :16, :16]) 50 | 51 | # Each call to configure_perturbation should re-initialize the perturbation 52 | assert initializer.call_count == 3 53 | 54 | 55 | def test_parameters(input_data, target_data): 56 | initializer = Mock() 57 | projector = Mock() 58 | 59 | perturber = Perturber(initializer=initializer, projector=projector) 60 | 61 | perturber.configure_perturbation(input_data) 62 | 63 | # Make sure each parameter in optimizer requires a gradient 64 | for param in perturber.parameters(): 65 | 
assert param.requires_grad 66 |
-------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import pytest 8 | 9 | from mart.utils import flatten_dict 10 | 11 | 12 | def test_flatten_dict(): 13 | d = {"a": 1, "b": {"c": 2, "d": 3}, "b.e": 4} 14 | assert flatten_dict(d) == {"a": 1, "b.c": 2, "b.d": 3, "b.e": 4} 15 | 16 | 17 | def test_flatten_dict_key_collision(): 18 | d = {"a": 1, "b": {"c": 2, "d": 3}, "b.c": 4} 19 | with pytest.raises(KeyError): 20 | flatten_dict(d) 21 |
-------------------------------------------------------------------------------- /tests/test_visualizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2022 Intel Corporation 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | 7 | import pytest 8 | 9 | from mart.utils.imports import _HAS_TORCHVISION 10 | 11 | if not _HAS_TORCHVISION: 12 | pytest.skip("test requires that torchvision is installed", allow_module_level=True) 13 | 14 | from unittest.mock import Mock 15 | 16 | from PIL import Image, ImageChops 17 | from torchvision.transforms import ToPILImage 18 | 19 | from mart.attack import Adversary 20 | from mart.callbacks import PerturbedImageVisualizer 21 | 22 | 23 | def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path): 24 | folder = tmp_path / "test" 25 | input_list = [input_data] 26 | target_list = [target_data] 27 | 28 | # simulate an additive perturbation 29 | def perturb(input, target): 30 | result = [sample + perturbation for sample in input] 31 | return result, target 32 | 33 | adversary = Mock(spec=Adversary, side_effect=perturb) 34 | trainer = Mock() 35 | outputs = Mock() 36 | target_model = Mock() 37 | 38 | # Canonical batch in Adversary. 39 | batch = (input_list, target_list, target_model) 40 | 41 | visualizer = PerturbedImageVisualizer(folder) 42 | visualizer.on_train_batch_end(trainer, adversary, outputs, batch, 0) 43 | visualizer.on_train_end(trainer, adversary) 44 | 45 | # verify that the visualizer created the JPG file 46 | expected_output_path = folder / target_data["file_name"] 47 | assert expected_output_path.exists() 48 | 49 | # verify image file content 50 | perturbed_img = input_data + perturbation 51 | converter = ToPILImage() 52 | expected_img = converter(perturbed_img / 255) 53 | expected_img.save(folder / "test_expected.jpg") 54 | 55 | stored_img = Image.open(expected_output_path) 56 | expected_stored_img = Image.open(folder / "test_expected.jpg") 57 | diff = ImageChops.difference(expected_stored_img, stored_img) 58 | assert not diff.getbbox() 59 |
--------------------------------------------------------------------------------