├── .github └── workflows │ ├── build.yaml │ ├── codespell.yml │ └── release-pypi.yml ├── .gitignore ├── LICENSE ├── Makefile ├── Notebooks ├── AROS.ipynb └── Ablation_Study.ipynb ├── README.md ├── aros_node ├── __init__.py ├── data_loader.py ├── evaluate.py ├── stability_loss_function.py ├── utils.py └── version.py ├── main.py ├── pyproject.toml ├── reinstall.sh ├── requirements.txt ├── setup.cfg └── tests └── test_dataloaders.py /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | fail-fast: true 15 | matrix: 16 | os: [ubuntu-latest] 17 | python-version: ["3.12"] 18 | torch-version: ["2.4.0"] 19 | include: 20 | - os: windows-latest 21 | torch-version: 2.4.0 22 | python-version: "3.12" 23 | 24 | runs-on: ${{ matrix.os }} 25 | steps: 26 | - name: Cache dependencies 27 | id: pip-cache 28 | uses: actions/cache@v3 29 | with: 30 | path: ~/.cache/pip 31 | key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} 32 | 33 | - name: Checkout code 34 | uses: actions/checkout@v2 35 | 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v4 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | 41 | - name: Install package 42 | run: | 43 | python -m pip install git+https://github.com/RobustBench/robustbench.git 44 | python -m pip install --upgrade pip setuptools wheel 45 | python -m pip install packaging==24.2 46 | python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu 47 | pip install '.[dev]' 48 | 49 | - name: Run pytest tests 50 | timeout-minutes: 10 51 | run: | 52 | pip install pytest 53 | python -m pytest 54 | 55 | - name: Build package 56 | run: | 57 | make build 58 | 59 | - name: Check reinstall script 60 | 
timeout-minutes: 3 61 | run: | 62 | ./reinstall.sh 63 | -------------------------------------------------------------------------------- /.github/workflows/codespell.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Codespell 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | jobs: 11 | codespell: 12 | name: Check for spelling errors 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | - name: Codespell 19 | uses: codespell-project/actions-codespell@v1 20 | with: 21 | ignore_words_list: aros, fpr, tpr, idx, fpr95 -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | pull_request: 8 | branches: 9 | - main 10 | types: 11 | - labeled 12 | - opened 13 | - edited 14 | - synchronize 15 | - reopened 16 | 17 | jobs: 18 | release: 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - name: Cache dependencies 23 | id: pip-cache 24 | uses: actions/cache@v3 25 | with: 26 | path: ~/.cache/pip 27 | key: ${{ runner.os }}-pip 28 | restore-keys: | 29 | ${{ runner.os }}-pip 30 | 31 | - name: Install dependencies 32 | run: | 33 | pip install --upgrade pip 34 | pip install wheel 35 | # see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669 36 | pip install "packaging>=24.2" 37 | 38 | - name: Checkout code 39 | uses: actions/checkout@v3 40 | 41 | 42 | - name: Build and publish to PyPI 43 | if: ${{ github.event_name == 'push' }} 44 | env: 45 | TWINE_USERNAME: __token__ 46 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 47 | run: | 48 | make dist 49 | ls dist/ 50 | tar tvf dist/aros_node-*.tar.gz 51 | python3 -m twine upload dist/* 52 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.tar.gz 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | AROS_VERSION := 0.0.1 2 | 3 | dist: 4 | python3 -m pip install virtualenv 5 | python3 -m pip install --upgrade build twine 6 | python3 -m build --wheel --sdist 7 | 8 | build: dist 9 | 10 | archlinux: 11 | mkdir -p dist/arch 12 | cp PKGBUILD dist/arch 13 | cp dist/aros_node-${AROS_VERSION}.tar.gz dist/arch 14 | (cd dist/arch; makepkg --skipchecksums -f) 15 | -------------------------------------------------------------------------------- /Notebooks/Ablation_Study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "G1Ues10_fww5" 17 | }, 18 | "source": [ 19 | "## AROS, Ablation Study" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | 
"metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "!git clone https://github.com/AdaptiveMotorControlLab/AROS" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "!pip install -r ./AROS/requirements.txt\n", 38 | "cd ./AROS/AROS" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import argparse\n", 48 | "import torch\n", 49 | "import torch.nn as nn\n", 50 | "from tqdm.notebook import tqdm" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "id": "S2YKR1ps79o3", 58 | "outputId": "499e2580-c01b-4f2d-a59d-890f798b3295" 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Defaulting to user installation because normal site-packages is not writeable\n", 66 | "Collecting git+https://github.com/RobustBench/robustbench.git (from -r requirements.txt (line 3))\n", 67 | " Cloning https://github.com/RobustBench/robustbench.git to /tmp/pip-req-build-cdsd2hhb\n", 68 | " Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git /tmp/pip-req-build-cdsd2hhb\n", 69 | " Resolved https://github.com/RobustBench/robustbench.git to commit 776bc95bb4167827fb102a32ac5aea62e46cfaab\n", 70 | " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", 71 | "\u001b[?25hRequirement already satisfied: geotorch in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.3.0)\n", 72 | "Requirement already satisfied: torchdiffeq in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.2.4)\n", 73 | "Requirement already satisfied: torch>=1.9 in /usr/local/lib/python3.10/dist-packages (from geotorch->-r requirements.txt (line 1)) (2.4.1)\n", 74 | "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq->-r requirements.txt (line 2)) (1.14.1)\n", 75 | "Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack\n", 76 | " Using cached autoattack-0.1-py3-none-any.whl\n", 77 | "Requirement already satisfied: Jinja2~=3.1.2 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (3.1.4)\n", 78 | "Requirement already satisfied: gdown==5.1.0 in /home/hossein/.local/lib/python3.10/site-packages (from robustbench==1.1->-r requirements.txt (line 3)) (5.1.0)\n", 79 | "Requirement already satisfied: numpy>=1.19.4 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.1.2)\n", 80 | "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", 81 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (6.0.2)\n", 82 | "Requirement already satisfied: requests>=2.25.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (2.32.3)\n", 83 | "Requirement already satisfied: timm>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (1.0.9)\n", 84 | "Requirement already satisfied: torchvision>=0.8.2 in 
/usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (0.19.1)\n", 85 | "Requirement already satisfied: tqdm>=4.56.1 in /usr/local/lib/python3.10/dist-packages (from robustbench==1.1->-r requirements.txt (line 3)) (4.66.5)\n", 86 | "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (4.12.3)\n", 87 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (3.16.1)\n", 88 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2~=3.1.2->robustbench==1.1->-r requirements.txt (line 3)) (2.1.5)\n", 89 | "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2.9.0.post0)\n", 90 | "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", 91 | "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (2024.2)\n", 92 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2.2.3)\n", 93 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.3.2)\n", 94 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (2024.8.30)\n", 95 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from 
requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (3.10)\n", 96 | "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.4.5)\n", 97 | "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (0.25.2)\n", 98 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", 99 | "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.0.106)\n", 100 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", 101 | "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2.20.5)\n", 102 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", 103 | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (4.12.2)\n", 104 | "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.3.1)\n", 105 | "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (10.3.2.106)\n", 106 | "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (9.1.0.70)\n", 107 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.3)\n", 108 | "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.0.2.54)\n", 109 | "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (3.0.0)\n", 110 | "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.1.105)\n", 111 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.13.3)\n", 112 | "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (11.4.5.107)\n", 113 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->geotorch->-r requirements.txt (line 1)) (2024.9.0)\n", 114 | "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.9->geotorch->-r requirements.txt (line 1)) (12.6.77)\n", 115 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision>=0.8.2->robustbench==1.1->-r requirements.txt (line 3)) (10.4.0)\n", 116 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas>=1.3.5->robustbench==1.1->-r requirements.txt (line 3)) (1.16.0)\n", 117 | "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages 
(from beautifulsoup4->gdown==5.1.0->robustbench==1.1->-r requirements.txt (line 3)) (2.6)\n", 118 | "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm>=0.9.0->robustbench==1.1->-r requirements.txt (line 3)) (24.1)\n", 119 | "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->robustbench==1.1->-r requirements.txt (line 3)) (1.7.1)\n", 120 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.9->geotorch->-r requirements.txt (line 1)) (1.3.0)\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "from evaluate import *\n", 126 | "from utils import *\n", 127 | "from tqdm.notebook import tqdm\n", 128 | "from data_loader import *\n", 129 | "from stability_loss_function import *" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": { 136 | "id": "d7vxLb0179fa" 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "parser = argparse.ArgumentParser(description=\"Hyperparameters for the script\")\n", 141 | "\n", 142 | " \n", 143 | "parser.add_argument('--in_dataset', type=str, default='cifar100', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n", 144 | "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n", 145 | "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n", 146 | "parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')\n", 147 | "parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')\n", 148 | "parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')\n", 149 | 
"\n", 150 | "args = parser.parse_args('')\n", 151 | "\n", 152 | "# Set the default model name based on the selected dataset\n", 153 | "if args.in_dataset == 'cifar10':\n", 154 | " default_model_name = 'Rebuffi2021Fixing_70_16_cutmix_extra'\n", 155 | "elif args.in_dataset == 'cifar100':\n", 156 | " default_model_name = 'Wang2023Better_WRN-70-16'\n", 157 | "\n", 158 | "parser.add_argument('--model_name', type=str, default=default_model_name, choices=['Rebuffi2021Fixing_70_16_cutmix_extra', 'Wang2023Better_WRN-70-16'], help='The pre-trained model to be used for feature extraction')\n", 159 | "\n", 160 | "# Re-parse arguments to include model_name selection based on the dataset\n", 161 | "args = parser.parse_args('')\n", 162 | "num_classes = 10 if args.in_dataset == 'cifar10' else 100\n", 163 | "\n", 164 | "trainloader, testloader,test_set, ID_OOD_loader = get_loaders(in_dataset=args.in_dataset)\n", 165 | "\n", 166 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "colab": { 174 | "referenced_widgets": [ 175 | "59296a90b8c84b1c94648a4c5d68a43b", 176 | "ad54c341af6e400280d000b3725f08ee" 177 | ] 178 | }, 179 | "id": "g2TltXvg7MfF", 180 | "outputId": "4df864e7-e14b-4db4-e1ae-06e33c9b11be" 181 | }, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "application/vnd.jupyter.widget-view+json": { 186 | "model_id": "59296a90b8c84b1c94648a4c5d68a43b", 187 | "version_major": 2, 188 | "version_minor": 0 189 | }, 190 | "text/plain": [ 191 | " 0%| | 0/1250 [00:00 1+tol or torch.min(images) < 0-tol: 150 | raise ValueError('Input must have a range [0, 1] (max: {}, min: {})'.format( 151 | torch.max(images), torch.min(images))) 152 | return images 153 | 154 | def _check_outputs(self, images): 155 | if self._normalization_applied: 156 | images = self.normalize(images) 157 | return images 158 | 159 | @wrapper_method 160 | def set_model(self, model): 
def normalize(self, inputs):
    r"""
    Standardize ``inputs`` with the registered statistics.

    Arguments:
        inputs (torch.Tensor): tensor broadcastable against the stored
            ``mean``/``std`` (stored as ``(1, C, 1, 1)`` by
            ``set_normalization_used``).

    Returns:
        ``(inputs - mean) / std`` on the device of ``inputs``.
    """
    mean = self.normalization_used['mean']
    std = self.normalization_used['std']
    # Follow the dtype of floating-point inputs: stats built from float64
    # data would otherwise promote float32 images to float64.
    dtype = inputs.dtype if inputs.is_floating_point() else mean.dtype
    mean = mean.to(device=inputs.device, dtype=dtype)
    std = std.to(device=inputs.device, dtype=dtype)
    return (inputs - mean) / std


def inverse_normalize(self, inputs):
    r"""
    Undo :meth:`normalize`.

    Arguments:
        inputs (torch.Tensor): tensor broadcastable against the stored
            ``mean``/``std``.

    Returns:
        ``inputs * std + mean`` on the device of ``inputs``.
    """
    mean = self.normalization_used['mean']
    std = self.normalization_used['std']
    # Same dtype alignment as in ``normalize``.
    dtype = inputs.dtype if inputs.is_floating_point() else mean.dtype
    mean = mean.to(device=inputs.device, dtype=dtype)
    std = std.to(device=inputs.device, dtype=dtype)
    return inputs*std + mean
214 | 215 | """ 216 | return self.attack_mode 217 | 218 | @wrapper_method 219 | def set_mode_default(self): 220 | r""" 221 | Set attack mode as default mode. 222 | 223 | """ 224 | self.attack_mode = 'default' 225 | self.targeted = False 226 | print("Attack mode is changed to 'default.'") 227 | 228 | @wrapper_method 229 | def _set_mode_targeted(self, mode, quiet): 230 | if "targeted" not in self.supported_mode: 231 | raise ValueError("Targeted mode is not supported.") 232 | self.targeted = True 233 | self.attack_mode = mode 234 | if not quiet: 235 | print("Attack mode is changed to '%s'." % mode) 236 | 237 | @wrapper_method 238 | def set_mode_targeted_by_function(self, target_map_function, quiet=False): 239 | r""" 240 | Set attack mode as targeted. 241 | 242 | Arguments: 243 | target_map_function (function): Label mapping function. 244 | e.g. lambda inputs, labels:(labels+1)%10. 245 | None for using input labels as targeted labels. (Default) 246 | 247 | """ 248 | self._set_mode_targeted('targeted(custom)', quiet) 249 | self._target_map_function = target_map_function 250 | 251 | @wrapper_method 252 | def set_mode_targeted_random(self, quiet=False): 253 | r""" 254 | Set attack mode as targeted with random labels. 255 | 256 | Arguments: 257 | num_classses (str): number of classes. 258 | 259 | """ 260 | self._set_mode_targeted('targeted(random)', quiet) 261 | self._target_map_function = self.get_random_target_label 262 | 263 | @wrapper_method 264 | def set_mode_targeted_least_likely(self, kth_min=1, quiet=False): 265 | r""" 266 | Set attack mode as targeted with least likely labels. 267 | 268 | Arguments: 269 | kth_min (str): label with the k-th smallest probability used as target labels. 
(Default: 1) 270 | 271 | """ 272 | self._set_mode_targeted('targeted(least-likely)', quiet) 273 | assert (kth_min > 0) 274 | self._kth_min = kth_min 275 | self._target_map_function = self.get_least_likely_label 276 | 277 | @wrapper_method 278 | def set_mode_targeted_by_label(self, quiet=False): 279 | r""" 280 | Set attack mode as targeted. 281 | 282 | .. note:: 283 | Use user-supplied labels as target labels. 284 | """ 285 | self._set_mode_targeted('targeted(label)', quiet) 286 | self._target_map_function = 'function is a string' 287 | 288 | @wrapper_method 289 | def set_model_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False): 290 | r""" 291 | Set training mode during attack process. 292 | 293 | Arguments: 294 | model_training (bool): True for using training mode for the entire model during attack process. 295 | batchnorm_training (bool): True for using training mode for batchnorms during attack process. 296 | dropout_training (bool): True for using training mode for dropouts during attack process. 297 | 298 | .. note:: 299 | For RNN-based models, we cannot calculate gradients with eval mode. 300 | Thus, it should be changed to the training mode during the attack. 
def save(self, data_loader, save_path=None, verbose=True, return_verbose=False,
         save_predictions=False, save_clean_inputs=False, save_type='float'):
    r"""
    Attack every batch of ``data_loader`` and optionally save the results.

    Arguments:
        data_loader (torch.utils.data.DataLoader): yields ``(inputs, labels)``.
        save_path (str): destination passed to ``torch.save``; nothing is
            written when None. (Default: None)
        verbose (bool): True for printing progress after each batch. (Default: True)
        return_verbose (bool): True for returning
            ``(rob_acc, l2, elapsed_time)``. (Default: False)
        save_predictions (bool): True for saving predicted labels. (Default: False)
        save_clean_inputs (bool): True for saving clean inputs. (Default: False)
        save_type (str): ``'float'`` or ``'int'``; with ``'int'`` the saved
            images go through ``self.to_type(..., 'int')``. (Default: 'float')

    .. note::
        The file is rewritten after every batch with everything collected
        so far, so an interrupted run still leaves a usable file.
    """
    saving = save_path is not None
    report = verbose or return_verbose
    # Predictions are needed for the report AND for ``save_predictions``;
    # the latter used to fail with NameError when reporting was off.
    need_pred = report or (saving and save_predictions)

    adv_input_list, label_list, pred_list, input_list = [], [], [], []
    correct = 0
    total = 0
    l2_distance = []
    # Defined up front so an empty loader cannot hit undefined names below.
    progress = rob_acc = l2 = elapsed_time = 0.0

    total_batch = len(data_loader)
    given_training = self.model.training

    for step, (inputs, labels) in enumerate(data_loader):
        start = time.time()
        adv_inputs = self(inputs, labels)
        batch_size = len(inputs)

        if need_pred:
            with torch.no_grad():
                outputs = self.get_output_with_eval_nograd(adv_inputs)
                _, pred = torch.max(outputs.data, 1)

        if report:
            with torch.no_grad():
                # Running robust accuracy
                total += labels.size(0)
                right_idx = (pred == labels.to(self.device))
                correct += right_idx.sum()
                rob_acc = 100 * float(correct) / total

                # Mean L2 distance over the misclassified samples seen so far
                delta = (adv_inputs - inputs.to(self.device)).view(batch_size, -1)
                l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))
                l2 = torch.cat(l2_distance).mean().item()

            progress = (step+1)/total_batch*100
            elapsed_time = time.time() - start

            if verbose:
                self._save_print(progress, rob_acc, l2, elapsed_time, end='\r')

        if saving:
            adv_input_list.append(adv_inputs.detach().cpu())
            label_list.append(labels.detach().cpu())

            save_dict = {'adv_inputs': torch.cat(adv_input_list, 0),
                         'labels': torch.cat(label_list, 0)}

            if save_predictions:
                pred_list.append(pred.detach().cpu())
                save_dict['preds'] = torch.cat(pred_list, 0)

            if save_clean_inputs:
                input_list.append(inputs.detach().cpu())
                save_dict['clean_inputs'] = torch.cat(input_list, 0)

            # Truthiness instead of ``is not None``: an empty stats dict
            # means no normalization and must not reach inverse_normalize.
            if self.normalization_used:
                save_dict['adv_inputs'] = self.inverse_normalize(save_dict['adv_inputs'])
                if save_clean_inputs:
                    save_dict['clean_inputs'] = self.inverse_normalize(save_dict['clean_inputs'])

            if save_type == 'int':
                save_dict['adv_inputs'] = self.to_type(save_dict['adv_inputs'], 'int')
                if save_clean_inputs:
                    save_dict['clean_inputs'] = self.to_type(save_dict['clean_inputs'], 'int')

            save_dict['save_type'] = save_type
            torch.save(save_dict, save_path)

    # To avoid erasing the printed information.
    if verbose:
        self._save_print(progress, rob_acc, l2, elapsed_time, end='\n')

    if given_training:
        self.model.train()

    if return_verbose:
        return rob_acc, l2, elapsed_time
@torch.no_grad()
def get_output_with_eval_nograd(self, inputs):
    r"""
    Return ``self.get_logits(inputs)`` computed in eval mode without gradients.

    The model is switched to eval mode only if it was training, and is put
    back into training mode afterwards, including when ``get_logits`` raises.
    """
    given_training = self.model.training
    if given_training:
        self.model.eval()
    try:
        return self.get_logits(inputs)
    finally:
        # Restore even on failure so an exception does not leave a
        # training model silently stuck in eval mode.
        if given_training:
            self.model.train()
@torch.no_grad()
def get_least_likely_label(self, inputs, labels=None):
    r"""
    Pick, per sample, the class with the ``self._kth_min``-th smallest logit
    among all classes except the sample's own label.

    Arguments:
        inputs (torch.Tensor): batch passed to ``get_output_with_eval_nograd``.
        labels (torch.Tensor): labels to exclude; when None the model's own
            argmax predictions are used. (Default: None)

    Returns:
        ``torch.long`` tensor of target labels on ``self.device``.

    Raises:
        ValueError: if a label is outside ``[0, n_classes)``.
    """
    outputs = self.get_output_with_eval_nograd(inputs)
    if labels is None:
        _, labels = torch.max(outputs, dim=1)
    n_classes = outputs.shape[-1]
    n_samples = labels.shape[0]

    labels = labels.to(outputs.device).view(-1)
    if ((labels < 0) | (labels >= n_classes)).any():
        # The per-sample ``list.remove`` this replaces raised ValueError too.
        raise ValueError('labels must lie in [0, {})'.format(n_classes))

    # Drop each row's own label, leaving n_classes - 1 candidates per row.
    class_ids = torch.arange(n_classes, device=outputs.device).expand(n_samples, n_classes)
    keep = class_ids != labels.view(-1, 1)
    candidates = class_ids[keep].view(n_samples, n_classes - 1)
    candidate_logits = outputs[keep].view(n_samples, n_classes - 1)

    # k-th smallest logit among the candidates, mapped back to class ids.
    _, position = torch.kthvalue(candidate_logits, self._kth_min, dim=1)
    target_labels = candidates.gather(1, position.view(-1, 1)).view(-1)

    return target_labels.long().to(self.device)
def get_auc_adversarial(model, test_loader, test_attack, device, num_classes):
    """Score OOD detection under attack using the negated max softmax probability.

    Each batch is perturbed with ``test_attack(data, target)``. The anomaly
    score of a sample is minus the largest softmax probability over the first
    ``num_classes`` outputs, and a sample is a positive (OOD) when its label
    equals ``num_classes``. FPR95, AUROC and AUPR are printed.

    Args:
        model: classifier called on the adversarial batch.
        test_loader: iterable of ``(data, target)`` batches.
        test_attack: callable ``(data, target) -> adversarial data``.
        device: device the batches are moved to.
        num_classes: number of in-distribution classes.

    Returns:
        The value of ``roc_auc_score(test_labels, anomaly_scores)``.
    """
    is_train = model.training
    model.eval()

    anomaly_scores = []
    test_labels = []
    try:
        with tqdm(test_loader, unit="batch") as tepoch:
            torch.cuda.empty_cache()
            for data, target in tepoch:
                data, target = data.to(device), target.to(device)
                adv_data = test_attack(data, target)
                # Scoring needs no gradients; only the attack does.
                with torch.no_grad():
                    probs = torch.softmax(model(adv_data), dim=1)
                max_probabilities, _ = torch.max(probs[:, :num_classes], dim=1)
                # The unused per-batch ``preds`` list is gone: its
                # ``.squeeze()`` made a scalar for a batch of one and
                # ``list += scalar`` raised TypeError.
                anomaly_scores.extend((-max_probabilities).cpu().tolist())
                test_labels.extend((target == num_classes).cpu().tolist())
    finally:
        # Restore the caller's mode even if the attack or model raises.
        model.train(is_train)

    auc = roc_auc_score(test_labels, anomaly_scores)
    fpr95 = compute_fpr95(test_labels, anomaly_scores)
    auroc = compute_auroc(test_labels, anomaly_scores)
    aupr = compute_aupr(test_labels, anomaly_scores)

    print(f"FPR95: {fpr95}")
    print(f"AUROC is: {auroc}")
    print(f"AUPR: {aupr}")

    return auc
def auc_MSP_adversarial(model, test_loader, test_attack, device, num_classes):
    """Return and print the AUROC of max-softmax OOD detection under attack.

    Each batch is perturbed with ``test_attack(data, target)``. The anomaly
    score of a sample is minus the largest softmax probability over the first
    ``num_classes`` outputs, and a sample is a positive (OOD) when its label
    equals ``num_classes``.

    Args:
        model: classifier called on the adversarial batch.
        test_loader: iterable of ``(data, target)`` batches.
        test_attack: callable ``(data, target) -> adversarial data``.
        device: device the batches are moved to.
        num_classes: number of in-distribution classes.

    Returns:
        The value of ``roc_auc_score(test_labels, anomaly_scores)``.
    """
    is_train = model.training
    model.eval()

    anomaly_scores = []
    test_labels = []
    try:
        with tqdm(test_loader, unit="batch") as tepoch:
            torch.cuda.empty_cache()
            for data, target in tepoch:
                data, target = data.to(device), target.to(device)
                adv_data = test_attack(data, target)
                # Scoring needs no gradients; only the attack does.
                with torch.no_grad():
                    probs = torch.softmax(model(adv_data), dim=1)
                max_probabilities, _ = torch.max(probs[:, :num_classes], dim=1)
                # The unused per-batch ``preds`` list is gone: its
                # ``.squeeze()`` made a scalar for a batch of one and
                # ``list += scalar`` raised TypeError.
                anomaly_scores.extend((-max_probabilities).cpu().tolist())
                test_labels.extend((target == num_classes).cpu().tolist())
    finally:
        # Restore the caller's mode even if the attack or model raises.
        model.train(is_train)

    auc = roc_auc_score(test_labels, anomaly_scores)
    print(auc)
    return auc
(Default: True) 778 | 779 | Shape: 780 | - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1]. 781 | - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`. 782 | - output: :math:`(N, C, H, W)`. 783 | 784 | Examples:: 785 | >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=10, random_start=True) 786 | >>> adv_images = attack(images, labels) 787 | 788 | """ 789 | 790 | def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True, num_classes=10): 791 | super().__init__("PGD", model) 792 | self.eps = eps 793 | self.alpha = alpha 794 | self.steps = steps 795 | self.random_start = random_start 796 | self.supported_mode = ['default', 'targeted'] 797 | self.num_classes = num_classes 798 | 799 | def forward(self, images, labels): 800 | r""" 801 | Overridden. 802 | """ 803 | 804 | images = images.clone().detach().to(self.device) 805 | labels = labels.clone().detach().to(self.device) 806 | 807 | softmax = nn.Softmax(dim=1) 808 | adv_images = images.clone().detach() 809 | 810 | if self.random_start: 811 | # Starting at a uniformly random point 812 | adv_images = adv_images + \ 813 | torch.empty_like(adv_images).uniform_(-self.eps, self.eps) 814 | adv_images = torch.clamp(adv_images, min=0, max=1).detach() 815 | 816 | ones = torch.ones_like(labels) 817 | multipliers = -1*(ones - 2 * ones * (labels == self.num_classes)) 818 | 819 | for _ in range(self.steps): 820 | adv_images.requires_grad = True 821 | outputs = self.get_logits(adv_images) 822 | final_outputs = softmax(outputs) 823 | # choose the value of probability of ood 824 | 825 | max_probabilities = torch.max(final_outputs , dim=1)[0] 826 | 827 | 828 | 829 | cost = torch.sum(max_probabilities * multipliers) 830 | 831 | 832 | 833 | # Update adversarial images 834 | grad = torch.autograd.grad(cost, adv_images, 835 | retain_graph=False, 
def auc_MSP(model, test_loader , device, num_classes):
    """Return and print the AUROC of max-softmax OOD detection on clean data.

    The anomaly score of a sample is minus the largest softmax probability
    over the first ``num_classes`` outputs, and a sample is a positive (OOD)
    when its label equals ``num_classes``.

    Args:
        model: classifier called on each batch.
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to.
        num_classes: number of in-distribution classes.

    Returns:
        The value of ``roc_auc_score(test_labels, anomaly_scores)``.
    """
    is_train = model.training
    model.eval()

    anomaly_scores = []
    test_labels = []
    try:
        with torch.no_grad():
            with tqdm(test_loader, unit="batch") as tepoch:
                torch.cuda.empty_cache()
                for data, target in tepoch:
                    data, target = data.to(device), target.to(device)
                    probs = torch.softmax(model(data), dim=1)
                    max_probabilities, _ = torch.max(probs[:, :num_classes], dim=1)
                    # The unused per-batch ``preds`` list is gone: its
                    # ``.squeeze()`` made a scalar for a batch of one and
                    # ``list += scalar`` raised TypeError.
                    anomaly_scores.extend((-max_probabilities).cpu().tolist())
                    test_labels.extend((target == num_classes).cpu().tolist())
    finally:
        # ``is_train`` used to be recorded but never applied, leaving a
        # training model in eval mode after the call.
        model.train(is_train)

    auc = roc_auc_score(test_labels, anomaly_scores)
    print(auc)

    return auc
def stability_loss_function_(trainloader,testloader,robust_backbone,class_numbers,fake_loader,last_layer,args):
    """Train the three-stage pipeline and return the assembled classifier.

    Stage 1 trains ``MLP_OUT_ORTH1024`` on top of a frozen backbone (with a
    frozen ``MLP_OUT_BALL`` head) for ``args.epoch1`` epochs and, after each
    epoch, writes a checkpoint and the extracted features to disk.
    Stage 2 trains ``ODEfunc_mlp`` on the saved training features using only
    the three regularizers defined below (no classification loss).
    Stage 3 wraps the same ``odefunc`` in an ``ODEBlock``, adds a fresh
    ``MLP_OUT_LINEAR`` head and trains both with cross-entropy on
    ``trainloader`` for ``args.epoch3`` epochs.

    Args:
        trainloader: Batches of ``(inputs, targets)`` used in stages 1 and 3.
        testloader: Batches of ``(inputs, targets)`` evaluated in stage 1.
        robust_backbone: Overwritten on the first line by a freshly loaded
            model; the passed value is not used.
        class_numbers: Output size of the classification heads.
        fake_loader: Forwarded to ``save_training_feature`` as
            ``fake_embeddings_loader``.
        last_layer: Overwritten by the last child of the freshly loaded
            backbone; the passed value is not used.
        args: Namespace read for ``model_name``, ``in_dataset``,
            ``threat_model``, ``epoch1``, ``epoch2`` and ``epoch3``.

    Returns:
        ``nn.Sequential(robust_backbone, robust_backbone_fc_features,
        ODE_FCmodel)`` after stage 3.
    """

    # NOTE(review): this reloads the backbone and shadows the `robust_backbone`
    # and `last_layer` arguments, so whatever the caller passed is ignored.
    robust_backbone = load_model(model_name=args.model_name, dataset=args.in_dataset, threat_model=args.threat_model).to(device)


    # Replace the backbone's last child with an identity so it emits features;
    # the removed layer is kept only for its `in_features`.
    last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
    setattr(robust_backbone, last_layer_name, nn.Identity())




    robust_backbone_fc_features = MLP_OUT_ORTH1024(last_layer.in_features)

    fc_layers_phase1 = MLP_OUT_BALL(class_numbers)

    # The stage-1 head is never trained.
    for param in fc_layers_phase1.parameters():
        param.requires_grad = False

    net_save_robustfeature = nn.Sequential(robust_backbone, robust_backbone_fc_features, fc_layers_phase1).to(device)

    # Freeze the backbone: only `robust_backbone_fc_features` keeps gradients.
    for param in robust_backbone.parameters():
        param.requires_grad = False




    print(net_save_robustfeature)
    net_save_robustfeature = net_save_robustfeature.to(device)
    # `data_gen` and `batches_per_epoch` are rebound before stage 2 and are
    # not read in stage 1.
    data_gen = inf_generator(trainloader)
    batches_per_epoch = len(trainloader)
    optimizer1 = torch.optim.Adam(net_save_robustfeature.parameters(), lr=5e-3, eps=1e-2, amsgrad=True)
    scheduler = StepLR(optimizer1, step_size=1, gamma=0.5) # Adjust step_size and gamma as needed


    def train_save_robustfeature(epoch):
        """Run one stage-1 training epoch with cross-entropy and `optimizer1`."""
        best_acc = 0  # never read in this function
        criterion = nn.CrossEntropyLoss()
        print('\nEpoch: %d' % epoch)
        net_save_robustfeature.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer1.zero_grad()
            x = inputs
            outputs = net_save_robustfeature(x)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer1.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        # `scheduler` is resolved at call time; during stage 1 it is the StepLR
        # built on `optimizer1` above.
        # NOTE(review): assumed to run once per epoch, not per batch - confirm.
        scheduler.step()



    def test_save_robustfeature(epoch):
        """Evaluate on `testloader`, then save a checkpoint and the features."""
        # NOTE(review): `best_acc` restarts at 0 on every call, so the branch
        # below runs whenever the accuracy is above zero.
        best_acc=0
        net_save_robustfeature.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                x = inputs
                outputs = net_save_robustfeature(x)
                # `criterion` is not defined in this function or its enclosing
                # one, so it resolves to a module-level name.
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        acc = 100.*correct/total
        if acc > best_acc:
            print('Saving..')
            state = {'net_save_robustfeature': net_save_robustfeature.state_dict(),'acc': acc,'epoch': epoch}
            torch.save(state, robust_feature_savefolder+'/ckpt.pth')
            best_acc = acc
            # NOTE(review): the two feature dumps are assumed to belong to this
            # branch - confirm.
            save_training_feature(net_save_robustfeature, trainloader,fake_embeddings_loader=fake_loader)
            print('----')
            save_testing_feature(net_save_robustfeature, testloader)
            print('------------')

    # ---- Stage 1: train the feature projection and dump features ----
    makedirs(robust_feature_savefolder)
    for epoch in range(0, args.epoch1):
        train_save_robustfeature(epoch)
        test_save_robustfeature(epoch)
    print('save robust feature to ' + robust_feature_savefolder)


    def df_dz_regularizer(odefunc, z):
        """Penalise the Jacobian of `odefunc` at up to `numm` random rows of `z`.

        Returns a diagonal term, ``exp(exponent * (diag(J) + trans))``, and an
        off-diagonal term built from ``|J|`` summed along dim 0 with the
        diagonal entries counted negatively. Both are summed over the sampled
        rows and divided by ``numm``.
        """
        regu_diag = 0.
        regu_offdiag = 0.0
        for ii in np.random.choice(z.shape[0], min(numm,z.shape[0]),replace=False):
            batchijacobian = torch.autograd.functional.jacobian(lambda x: odefunc(torch.tensor(time_df).to(device), x), z[ii:ii+1,...], create_graph=True)
            batchijacobian = batchijacobian.view(z.shape[1],-1)
            if batchijacobian.shape[0]!=batchijacobian.shape[1]:
                raise Exception("wrong dim in jacobian")

            tempdiag = torch.diagonal(batchijacobian, 0)
            regu_diag += torch.exp(exponent*(tempdiag+trans))
            # (-eye + 0.5) * 2 is -1 on the diagonal and +1 everywhere else.
            offdiat = torch.sum(torch.abs(batchijacobian)*((-1*torch.eye(batchijacobian.shape[0]).to(device)+0.5)*2), dim=0)
            off_diagtemp = torch.exp(exponent_off*(offdiat+transoffdig))
            regu_offdiag += off_diagtemp

        # Divides by `numm` even when fewer than `numm` rows were sampled.
        return regu_diag/numm, regu_offdiag/numm


    def f_regularizer(odefunc, z):
        """Return ``(exponent_f * |odefunc(time_df, z)|) ** 2`` element-wise."""
        tempf = torch.abs(odefunc(torch.tensor(time_df).to(device), z))
        regu_f = torch.pow(exponent_f*tempf,2)

        return regu_f





    # ---- Stage 2: train the ODE function on the saved features ----
    makedirs(ODE_FC_save_folder)

    odefunc = ODEfunc_mlp(0)
    feature_layers = ODEBlocktemp(odefunc)
    fc_layers = MLP_OUT_LINEAR(class_numbers)
    for param in fc_layers.parameters():
        param.requires_grad = False
    ODE_FCmodel = nn.Sequential(feature_layers, fc_layers).to(device)

    train_loader_ODE = DataLoader(DenseDatasetTrain(),batch_size=ODE_FC_odebatch,shuffle=True, num_workers=2)
    # `test_loader_ODE` is built but not read anywhere below.
    test_loader_ODE = DataLoader(DenseDatasetTest(),batch_size=ODE_FC_odebatch,shuffle=True, num_workers=2)
    data_gen = inf_generator(train_loader_ODE)
    batches_per_epoch = len(train_loader_ODE)




    optimizer2 = torch.optim.Adam(ODE_FCmodel.parameters(), lr=1e-2, eps=1e-3, amsgrad=True)


    scheduler = StepLR(optimizer2, step_size=1, gamma=0.5) # Adjust step_size and gamma as needed
    # NOTE(review): the inner loop already runs `args.epoch2 * batches_per_epoch`
    # iterations and is itself repeated `args.epoch2` times - confirm intended.
    for epoch in range(args.epoch2):
        with tqdm(total=args.epoch2 * batches_per_epoch, desc="Training ODE block with loss function") as pbar:
            for itr in range(args.epoch2 * batches_per_epoch):
                optimizer2.zero_grad()
                x, y = data_gen.__next__()
                x = x.to(device)

                modulelist = list(ODE_FCmodel)
                y0 = x
                # The block output (`x`, `y1`) is computed but not used in the
                # loss; the regularizers are evaluated at the input `y0`.
                x = modulelist[0](x)
                y1 = x

                y00 = y0
                regu1, regu2 = df_dz_regularizer(odefunc, y00)
                regu1 = regu1.mean()
                regu2 = regu2.mean()

                regu3 = f_regularizer(odefunc, y00)
                regu3 = regu3.mean()

                loss = weight_f*regu3 + weight_diag*regu1 + weight_offdiag*regu2

                loss.backward()
                optimizer2.step()
                torch.cuda.empty_cache()

                # Set postfix to update progress bar with current loss
                pbar.set_postfix({"Loss": loss.item()})
                pbar.update(1)
        # NOTE(review): the print and scheduler step are assumed to run once
        # per outer epoch - confirm.
        print("Loss", loss.item())
        scheduler.step() # Update the learning rate

        current_lr = optimizer2.param_groups[0]['lr']
        tqdm.write(f"Epoch {epoch+1}, Learning Rate: {current_lr}")








    # ---- Stage 3: fine-tune the ODE function and a fresh linear head ----
    # The trained `odefunc` is reused inside an integrating `ODEBlock`.
    feature_layers = ODEBlock(odefunc)
    fc_layers = MLP_OUT_LINEAR(class_numbers)
    ODE_FCmodel = nn.Sequential(feature_layers, fc_layers).to(device)

    for param in odefunc.parameters():
        param.requires_grad = True
    for param in robust_backbone_fc_features.parameters():
        param.requires_grad = False
    for param in robust_backbone.parameters():
        param.requires_grad = False

    new_model_full = nn.Sequential(robust_backbone, robust_backbone_fc_features, ODE_FCmodel).to(device)
    # Separate learning rates: a small one for the ODE function, a larger one
    # for the new head.
    optimizer3 = torch.optim.Adam([{'params': odefunc.parameters(), 'lr': 1e-5, 'eps':1e-6,},{'params': fc_layers.parameters(), 'lr': 5e-3, 'eps':1e-4,}], amsgrad=True)


    def train(net, epoch):
        """Run one stage-3 epoch of cross-entropy training with `optimizer3`."""
        criterion = nn.CrossEntropyLoss()
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer3.zero_grad()
            x = inputs
            outputs = net(x)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer3.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    for epoch in range(0, args.epoch3):
        train(new_model_full, epoch)
    return new_model_full
class ConcatFC(nn.Module):
    """Linear layer with an ODE-style ``(t, x)`` call signature.

    The time argument is accepted and ignored; only ``x`` reaches the layer.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)

    def forward(self, t, x):
        projected = self._layer(x)
        return projected


class ODEfunc_mlp(nn.Module):
    """ODE right-hand side ``sin(-fc1(x))`` over 128-dimensional inputs.

    ``dim`` is accepted but not used: the layer is always 128 -> 128.
    ``nfe`` counts how many times ``forward`` has been called.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = ConcatFC(128, 128)
        self.act1 = torch.sin
        self.nfe = 0

    def forward(self, t, x):
        # Count the evaluation first, as the solver-facing counter expects.
        self.nfe = self.nfe + 1
        return self.act1(-1 * self.fc1(t, x))
def init_params(net):
    """Re-initialise the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0.
    * ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Parameters that are ``None`` (layers built without bias, or batch norm
    with ``affine=False``) are skipped. Other module types are untouched.

    Args:
        net: Module whose sub-modules (as returned by ``net.modules()``)
            are visited.
    """
    # The initialisers are reached through ``nn.init``; a bare ``init`` name
    # is not available in this module.
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def inf_generator(iterable):
    """Yield the items of ``iterable`` forever, restarting it when exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    Args:
        iterable: Object that can be iterated repeatedly (for example a
            ``DataLoader``).

    Raises:
        ValueError: If a full pass over ``iterable`` produces no item, which
            would otherwise make the generator spin forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("inf_generator() needs a non-empty, re-iterable input")
def one_hot(x, K):
    """Return an integer one-hot encoding of the label array ``x``.

    Entry ``[i, k]`` is 1 when ``x[i] == k`` and 0 otherwise, for ``k`` in
    ``range(K)``; a label outside that range gives an all-zero row.
    """
    class_ids = np.arange(K)[np.newaxis, :]
    is_match = x[:, np.newaxis] == class_ids
    return is_match.astype(int)
def train_save_robustfeature(epoch,net_save_robustfeature,trainloader,optimizer,criterion):
    """Train ``net_save_robustfeature`` for one pass over ``trainloader``.

    For every batch the loss from ``criterion`` is back-propagated and
    ``optimizer`` takes a step; the running mean loss and accuracy are
    reported through ``progress_bar``.
    """
    print('\nEpoch: %d' % epoch)
    net_save_robustfeature.train()

    num_batches = len(trainloader)
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    for step, batch in enumerate(trainloader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net_save_robustfeature(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += predicted.eq(targets).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            loss_sum/(step+1), 100.*num_correct/num_seen, num_correct, num_seen)
        progress_bar(step, num_batches, status)
def df_dz_regularizer(odefunc, z):
    """Jacobian penalties of ``odefunc`` at randomly chosen rows of ``z``.

    Up to ``numm`` rows are drawn without replacement. For each one the
    Jacobian of ``odefunc`` (evaluated at time ``time_df``) with respect to
    that row is computed, and two terms are accumulated:

    * diagonal: ``exp(exponent * (diag(J) + trans))``
    * off-diagonal: ``exp(exponent_off * (s + transoffdig))``, where ``s``
      sums ``|J|`` along dim 0 with the diagonal entries counted negatively.

    Returns:
        ``(diagonal_sum / numm, off_diagonal_sum / numm)``.

    Raises:
        Exception: If the reshaped Jacobian is not square.
    """
    diag_total = 0.
    offdiag_total = 0.0
    num_rows = min(numm, z.shape[0])
    chosen_rows = np.random.choice(z.shape[0], num_rows, replace=False)
    for row in chosen_rows:
        sample = z[row:row+1, ...]
        jac = torch.autograd.functional.jacobian(
            lambda x: odefunc(torch.tensor(time_df).to(device), x),
            sample,
            create_graph=True,
        )
        jac = jac.view(z.shape[1], -1)
        if jac.shape[0] != jac.shape[1]:
            raise Exception("wrong dim in jacobian")

        diag_total += torch.exp(exponent*(torch.diagonal(jac, 0)+trans))

        # -1 on the diagonal, +1 elsewhere.
        sign_mask = (-1*torch.eye(jac.shape[0]).to(device)+0.5)*2
        signed_sum = torch.sum(torch.abs(jac)*sign_mask, dim=0)
        offdiag_total += torch.exp(exponent_off*(signed_sum+transoffdig))

    return diag_total/numm, offdiag_total/numm
def accuracy(model, dataset_loader):
    """Return the fraction of samples that ``model`` classifies correctly.

    The predicted class is the argmax of the model output along axis 1 and
    is compared directly with the integer labels, so the number of classes
    is not limited to 10.

    Args:
        model: Callable mapping a batch on ``device`` to per-class scores.
        dataset_loader: Iterable of ``(x, y)`` batches that also exposes
            ``.dataset`` (its length is the denominator).

    Returns:
        Number of correct predictions divided by ``len(dataset_loader.dataset)``.
    """
    total_correct = 0
    with torch.no_grad():
        for x, y in dataset_loader:
            scores = model(x.to(device)).cpu().detach().numpy()
            predicted_class = np.argmax(scores, axis=1)
            # Compare with the labels themselves: a one-hot round trip with a
            # fixed width maps every label beyond that width to class 0.
            target_class = np.asarray(y.numpy())
            total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)
def main():
    """Run the pipeline: embed, craft fake embeddings, train, evaluate.

    Steps:
        1. Build the argument namespace and the data loaders.
        2. Load the backbone, replace its last child with ``nn.Identity`` and
           embed the whole training set.
        3. Build ``fake_loader`` either from noisy training embeddings
           (``args.fast``) or from low-likelihood samples of per-class
           Gaussian mixtures. Fake samples are labelled ``num_classes``.
        4. Train with ``stability_loss_function_`` and evaluate the clean and
           adversarial AUC on ``ID_OOD_loader``.
    """
    parser = argparse.ArgumentParser(description="Hyperparameters for the script")

    # NOTE(review): argparse applies ``type=bool`` as ``bool(text)``, so any
    # non-empty value would parse as True.
    parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')
    parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')
    parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')
    parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')
    parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')
    parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')
    parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')
    parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')
    parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')
    parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')

    # NOTE(review): an empty argument list is parsed here and below, so the
    # defaults are always used and the process command line is ignored.
    args = parser.parse_args('')

    # Set the default model name based on the selected dataset
    if args.in_dataset == 'cifar10':
        default_model_name = 'Rebuffi2021Fixing_70_16_cutmix_extra'
    elif args.in_dataset == 'cifar100':
        default_model_name = 'Wang2023Better_WRN-70-16'

    parser.add_argument('--model_name', type=str, default=default_model_name, choices=['Rebuffi2021Fixing_70_16_cutmix_extra', 'Wang2023Better_WRN-70-16'], help='The pre-trained model to be used for feature extraction')

    # Re-parse arguments to include model_name selection based on the dataset
    args = parser.parse_args('')
    num_classes = 10 if args.in_dataset == 'cifar10' else 100

    trainloader, testloader, test_set, ID_OOD_loader = get_loaders(in_dataset=args.in_dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    robust_backbone = load_model(model_name=args.model_name, dataset=args.in_dataset, threat_model=args.threat_model).to(device)
    last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
    setattr(robust_backbone, last_layer_name, nn.Identity())
    fake_loader = None

    # Fake-sample budget: the training-set size divided by the class count.
    num_fake_samples = len(trainloader.dataset) // num_classes

    embeddings, labels = [], []
    with torch.no_grad():
        for imgs, lbls in trainloader:
            imgs = imgs.to(device, non_blocking=True)
            embed = robust_backbone(imgs).cpu()  # move to CPU only once per batch
            embeddings.append(embed)
            labels.append(lbls)
    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()

    print("embedding computed...")

    if not args.fast:
        # Fit one single-component Gaussian per class on its embeddings.
        gmm_dict = {}
        for cls in np.unique(labels):
            cls_embed = embeddings[labels == cls]
            gmm_dict[cls] = GaussianMixture(n_components=1, covariance_type='full').fit(cls_embed)

        print("fake crafing...")

        fake_data = []
        for cls, gmm in gmm_dict.items():
            samples = []
            num_collected = 0
            while num_collected < num_fake_samples:
                s = gmm.sample(100)[0]
                likelihood = gmm.score_samples(s)
                # Keep only draws in the extreme low-likelihood tail.
                kept = s[likelihood < np.quantile(likelihood, 0.001)]
                samples.append(kept)
                num_collected += len(kept)
            fake_data.append(np.vstack(samples)[:num_fake_samples])

        fake_data = np.vstack(fake_data)
        fake_data = torch.tensor(fake_data).float()
        fake_data = F.normalize(fake_data, p=2, dim=1)
        # Same extra label as the fast branch, instead of a fixed 10.
        fake_labels = torch.full((fake_data.shape[0],), num_classes)
        fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)
    else:
        clean = torch.tensor(embeddings)
        noisy_embeddings = clean + args.noise_std * torch.randn_like(clean)
        # Normalize Noisy Embeddings
        noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:num_fake_samples]
        # Convert to DataLoader if needed
        fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)[:num_fake_samples]
        fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)

    final_model = stability_loss_function_(trainloader, testloader, robust_backbone, num_classes, fake_loader, last_layer, args)

    test_attack = PGD_AUC(final_model, eps=args.attack_eps, steps=args.attack_steps, alpha=args.attack_alpha, num_classes=num_classes)
    get_clean_AUC(final_model, ID_OOD_loader, device, num_classes)
    adv_auc = get_auc_adversarial(model=final_model, test_loader=ID_OOD_loader, test_attack=test_attack, device=device, num_classes=num_classes)
10 | 11 | set -e # abort on error 12 | 13 | pip uninstall -y aros-node 14 | 15 | # Get version 16 | VERSION=0.0.1 17 | echo "Upgrading to AROS v${VERSION}" 18 | 19 | # Upgrade the build system (PEP517/518 compatible) 20 | python3 -m pip install virtualenv 21 | python3 -m pip install --upgrade build 22 | python3 -m build --sdist --wheel . 23 | 24 | # Reinstall the package with most recent version 25 | pip install --upgrade --no-cache-dir "dist/aros_node-${VERSION}-py3-none-any.whl" 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | geotorch 2 | torch 3 | torchdiffeq 4 | timm==1.0.9 5 | robustbench 6 | numpy 7 | scikit-learn 8 | scipy 9 | tqdm -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = aros-node 3 | version = attr: aros_node.version.__version__ 4 | author = Hossein Mirzaei, Mackenzie Mathis 5 | author_email = mackenzie@post.harvard.edu 6 | description = AROS: Adversarially Robust Out-of-Distribution Detection through Stability 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/AdaptiveMotorControlLab/AROS 10 | 11 | [options] 12 | packages = find: 13 | include_package_data = True 14 | python_requires = >=3.10 15 | install_requires = file: requirements.txt 16 | 17 | [options.extras_require] 18 | dev = 19 | pylint 20 | toml 21 | yapf 22 | black 23 | pytest -------------------------------------------------------------------------------- /tests/test_dataloaders.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.utils.data import DataLoader, Subset 4 | from torchvision.datasets import CIFAR10, CIFAR100 5 | from torchvision.transforms 
import ToTensor
from aros_node import (
    LabelChangedDataset,
    get_subsampled_subset,
    get_loaders,
)

# Set up transformations and datasets for tests
transform_tensor = ToTensor()

# Batch size the loader tests expect from the `test_loader` returned by get_loaders.
EXPECTED_TEST_BATCH_SIZE = 16


def _load_train_and_test(dataset_cls):
    """Build the train and test splits of a torchvision dataset class.

    Args:
        dataset_cls: Dataset class to instantiate (here ``CIFAR10`` or ``CIFAR100``).

    Returns:
        A ``(trainset, testset)`` tuple, both rooted at ``./data`` and using
        ``transform_tensor`` as their transform.
    """
    trainset = dataset_cls(root='./data', train=True, download=True, transform=transform_tensor)
    testset = dataset_cls(root='./data', train=False, download=True, transform=transform_tensor)
    return trainset, testset


def _assert_loaders_valid(dataset_name):
    """Call ``get_loaders(dataset_name)`` and check loader types and test batch size.

    Shared by the CIFAR-10 and CIFAR-100 loader tests so that both run exactly
    the same checks.
    """
    train_loader, test_loader, _test_set, test_loader_vs_other = get_loaders(dataset_name)

    assert isinstance(train_loader, DataLoader)
    assert isinstance(test_loader, DataLoader)
    assert isinstance(test_loader_vs_other, DataLoader)

    # Fetch the first batch explicitly. A `for ...: assert ...; break` loop never
    # reaches the assertion when the loader is empty, so the test would pass
    # without having checked anything.
    first_batch = next(iter(test_loader), None)
    assert first_batch is not None, "Test loader should yield at least one batch"
    images, _labels = first_batch
    assert images.shape[0] == EXPECTED_TEST_BATCH_SIZE, "Test loader batch size should be 16"


# Module scope: the tests only read from these datasets, so there is no need to
# rebuild them for every test function.
@pytest.fixture(scope="module")
def cifar10_datasets():
    """Provide the ``(trainset, testset)`` pair for CIFAR-10."""
    return _load_train_and_test(CIFAR10)


@pytest.fixture(scope="module")
def cifar100_datasets():
    """Provide the ``(trainset, testset)`` pair for CIFAR-100."""
    return _load_train_and_test(CIFAR100)


def test_label_changed_dataset(cifar10_datasets):
    """LabelChangedDataset keeps the dataset length and reports the new label for every item."""
    _, testset = cifar10_datasets
    new_label = 99
    relabeled_dataset = LabelChangedDataset(testset, new_label)

    assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length"

    for _img, label in relabeled_dataset:
        assert label == new_label, "All labels should be changed to the new label"


def test_get_subsampled_subset(cifar10_datasets):
    """get_subsampled_subset returns ``int(len(dataset) * subset_ratio)`` items."""
    trainset, _ = cifar10_datasets
    subset_ratio = 0.1
    subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio)

    expected_size = int(len(trainset) * subset_ratio)
    assert len(subset) == expected_size, f"Subset size should be {expected_size}"


def test_get_loaders_cifar10(cifar10_datasets):
    """get_loaders('cifar10') returns DataLoaders and a test batch of the expected size."""
    # The fixture is requested only so the dataset is constructed (and
    # downloaded if needed) before get_loaders runs; its value is not used.
    _assert_loaders_valid('cifar10')


def test_get_loaders_cifar100(cifar100_datasets):
    """get_loaders('cifar100') returns DataLoaders and a test batch of the expected size."""
    # Fixture requested for its side effect only, as in the CIFAR-10 test.
    _assert_loaders_valid('cifar100')


def test_get_loaders_invalid_dataset():
    """get_loaders raises ValueError with the expected message for an unknown dataset name."""
    # `match` is applied as a regular expression, so the final period is
    # escaped to require a literal "." instead of any character.
    with pytest.raises(ValueError, match=r"Dataset 'invalid_dataset' is not supported\."):
        get_loaders('invalid_dataset')