├── .gitignore ├── README.md ├── datasets ├── COCO2014 │ ├── README.md │ └── download_coco2014.sh └── VOC2007 │ └── README.md ├── environment.yml ├── examples ├── Intro Teaser.png └── Segmentation Comparisons.png ├── notebooks ├── compute_score.ipynb ├── error_measures-imports.ipynb ├── error_measures.ipynb ├── evaluate_class_scores.ipynb ├── image-plot.ipynb ├── image-plots-classes.ipynb ├── image_plot_topk.ipynb └── produce_tables.ipynb └── src ├── config_files ├── classifier_training │ ├── classifier_resnet50_coco_training.cfg │ ├── classifier_resnet50_voc_training.cfg │ ├── classifier_vgg16_coco_training.cfg │ └── classifier_vgg16_voc_training.cfg ├── explainer_training │ ├── explainer_resnet50_coco_training.cfg │ ├── explainer_resnet50_voc_training.cfg │ ├── explainer_vgg16_coco_training.cfg │ └── explainer_vgg16_voc_training.cfg └── testing_and_mask_saving │ ├── explainer_resnet50_coco_test_and_save_masks.cfg │ ├── explainer_resnet50_voc_test_and_save_masks.cfg │ ├── explainer_vgg16_coco_test_and_save_masks.cfg │ └── explainer_vgg16_voc_test_and_save_masks.cfg ├── data ├── dataloader.py └── dataset.py ├── evaluation ├── IGOS_pp │ ├── LICENSE │ ├── args.py │ ├── igos_utils.py │ ├── main.py │ ├── methods.py │ └── methods_helper.py ├── eval_utils │ ├── assessment_metrics.py │ ├── compute_masks.py │ └── compute_scores.py ├── evaluate.py ├── evaluate_segmentations.py ├── explainer.py ├── extremal_perturbations.py ├── generate_coco_segmentations.py ├── grad_cam.py ├── guided_backprop.py ├── print_mean_scores.py ├── rise.py └── rt_saliency.py ├── main.py ├── models ├── classifier.py ├── explainer.py ├── explainer_classifier.py ├── explainer_classifier_rtsal.py └── interpretable_fcnn.py └── utils ├── argparser.py ├── helper.py ├── image_utils.py ├── loss.py ├── metrics.py ├── rtsal_helper.py └── rtsal_loss.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/macos,python,R,linux,vim,emacs,visualstudiocode,intellij 3 | # Edit at https://www.gitignore.io/?templates=macos,python,R,linux,vim,emacs,visualstudiocode,intellij 4 | 5 | /.idea/* 6 | /.vscode 7 | /MaskedTestImages/* 8 | /annotations/ 9 | /src/tb_logs/* 10 | /CUB_200_2011/* 11 | /VOCData/VOCtrainval* 12 | /VOCData/VOCtest* 13 | /VOCData/VOCdevkit/* 14 | src/checkpoints/* 15 | checkpoints/* 16 | /src/test.py 17 | /src/testing.sh 18 | /src/script_conda.sh 19 | /src/script_virtualenv.sh 20 | /src/benchmark/rise_masks/ 21 | /src/benchmark/coco_segmentations/ 22 | /src/benchmark/rise_masks_coco/ 23 | /src/benchmark/rise_masks_voc/ 24 | /datasets/VOC2007/VOCdevkit/ 25 | /datasets/VOC2007/VOC* 26 | /datasets/COCO2014/annotations/ 27 | /datasets/COCO2014/val2014/ 28 | /datasets/COCO2014/train2014/ 29 | /notebooks/ATTENUATED* 30 | *lightning_logs/ 31 | *.sh 32 | !/datasets/COCO2014/download_coco2014.sh 33 | *.jpg 34 | *.png 35 | *.pdf 36 | *.npz 37 | *.json 38 | *.txt 39 | *.zip 40 | 41 | ### Emacs ### 42 | # -*- mode: gitignore; -*- 43 | *~ 44 | \#*\# 45 | /.emacs.desktop 46 | /.emacs.desktop.lock 47 | *.elc 48 | auto-save-list 49 | tramp 50 | .\#* 51 | 52 | # Org-mode 53 | .org-id-locations 54 | *_archive 55 | 56 | # flymake-mode 57 | *_flymake.* 58 | 59 | # eshell files 60 | /eshell/history 61 | /eshell/lastdir 62 | 63 | # elpa packages 64 | /elpa/ 65 | 66 | # reftex files 67 | *.rel 68 | 69 | # AUCTeX auto folder 70 | /auto/ 71 | 72 | # cask packages 73 | .cask/ 74 | dist/ 75 | 76 | # Flycheck 77 | flycheck_*.el 78 | 79 | # server auth directory 80 
| /server/ 81 | 82 | # projectiles files 83 | .projectile 84 | 85 | # directory configuration 86 | .dir-locals.el 87 | 88 | # network security 89 | /network-security.data 90 | 91 | 92 | ### Intellij ### 93 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 94 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 95 | 96 | # User-specific stuff 97 | .idea/**/workspace.xml 98 | .idea/**/tasks.xml 99 | .idea/**/usage.statistics.xml 100 | .idea/**/dictionaries 101 | .idea/**/shelf 102 | 103 | # Generated files 104 | .idea/**/contentModel.xml 105 | 106 | # Sensitive or high-churn files 107 | .idea/**/dataSources/ 108 | .idea/**/dataSources.ids 109 | .idea/**/dataSources.local.xml 110 | .idea/**/sqlDataSources.xml 111 | .idea/**/dynamic.xml 112 | .idea/**/uiDesigner.xml 113 | .idea/**/dbnavigator.xml 114 | 115 | # Gradle 116 | .idea/**/gradle.xml 117 | .idea/**/libraries 118 | 119 | # Gradle and Maven with auto-import 120 | # When using Gradle or Maven with auto-import, you should exclude module files, 121 | # since they will be recreated, and may cause churn. Uncomment if using 122 | # auto-import. 123 | # .idea/modules.xml 124 | # .idea/*.iml 125 | # .idea/modules 126 | 127 | # CMake 128 | cmake-build-*/ 129 | 130 | # Mongo Explorer plugin 131 | .idea/**/mongoSettings.xml 132 | 133 | # File-based project format 134 | *.iws 135 | 136 | # IntelliJ 137 | out/ 138 | 139 | # mpeltonen/sbt-idea plugin 140 | .idea_modules/ 141 | 142 | # JIRA plugin 143 | atlassian-ide-plugin.xml 144 | 145 | # Cursive Clojure plugin 146 | .idea/replstate.xml 147 | 148 | # Crashlytics plugin (for Android Studio and IntelliJ) 149 | com_crashlytics_export_strings.xml 150 | crashlytics.properties 151 | crashlytics-build.properties 152 | fabric.properties 153 | 154 | # Editor-based Rest Client 155 | .idea/httpRequests 156 | 157 | # Android studio 3.1+ serialized cache file 158 | .idea/caches/build_file_checksums.ser 159 | 160 | # JetBrains templates 161 | **___jb_tmp___ 162 | 163 | ### Intellij Patch ### 164 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 165 | 166 | # *.iml 167 | # modules.xml 168 | # .idea/misc.xml 169 | # *.ipr 170 | 171 | # Sonarlint plugin 172 | .idea/sonarlint 173 | 174 | ### Linux ### 175 | 176 | # temporary files which can be created if a process still has a handle open of a deleted file 177 | .fuse_hidden* 178 | 179 | # KDE directory preferences 180 | .directory 181 | 182 | # Linux trash folder which might appear on any partition or disk 183 | .Trash-* 184 | 185 | # .nfs files are created when an open file is removed but is still being accessed 186 | .nfs* 187 | 188 | ### macOS ### 189 | # General 190 | .DS_Store 191 | .AppleDouble 192 | .LSOverride 193 | 194 | # Icon must end with two \r 195 | Icon 196 | 197 | # Thumbnails 198 | ._* 199 | 200 | # Files that might appear in the root of a volume 201 | .DocumentRevisions-V100 202 | .fseventsd 203 | .Spotlight-V100 204 | .TemporaryItems 205 | .Trashes 206 | .VolumeIcon.icns 207 | .com.apple.timemachine.donotpresent 208 | 209 | # Directories potentially created on remote AFP share 210 | .AppleDB 211 | .AppleDesktop 212 | Network Trash Folder 213 | Temporary Items 214 | .apdisk 215 | 216 | ### Python ### 217 | # Byte-compiled / optimized / DLL files 218 | __pycache__/ 219 | *.py[cod] 220 | *$py.class 221 | 222 | # C extensions 223 | *.so 224 | 225 | # Distribution / packaging 226 | .Python 227 | build/ 228 | develop-eggs/ 
229 | downloads/ 230 | eggs/ 231 | .eggs/ 232 | lib/ 233 | lib64/ 234 | parts/ 235 | sdist/ 236 | var/ 237 | wheels/ 238 | pip-wheel-metadata/ 239 | share/python-wheels/ 240 | *.egg-info/ 241 | .installed.cfg 242 | *.egg 243 | MANIFEST 244 | 245 | # PyInstaller 246 | # Usually these files are written by a python script from a template 247 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 248 | *.manifest 249 | *.spec 250 | 251 | # Installer logs 252 | pip-log.txt 253 | pip-delete-this-directory.txt 254 | 255 | # Unit test / coverage reports 256 | htmlcov/ 257 | .tox/ 258 | .nox/ 259 | .coverage 260 | .coverage.* 261 | .cache 262 | nosetests.xml 263 | coverage.xml 264 | *.cover 265 | .hypothesis/ 266 | .pytest_cache/ 267 | 268 | # Translations 269 | *.mo 270 | *.pot 271 | 272 | # Django stuff: 273 | *.log 274 | local_settings.py 275 | db.sqlite3 276 | 277 | # Flask stuff: 278 | instance/ 279 | .webassets-cache 280 | 281 | # Scrapy stuff: 282 | .scrapy 283 | 284 | # Sphinx documentation 285 | docs/_build/ 286 | 287 | # PyBuilder 288 | target/ 289 | 290 | # Jupyter Notebook 291 | .ipynb_checkpoints 292 | 293 | # IPython 294 | profile_default/ 295 | ipython_config.py 296 | 297 | # pyenv 298 | .python-version 299 | 300 | # pipenv 301 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 302 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 303 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 304 | # install all needed dependencies. 305 | #Pipfile.lock 306 | 307 | # celery beat schedule file 308 | celerybeat-schedule 309 | 310 | # SageMath parsed files 311 | *.sage.py 312 | 313 | # Environments 314 | .env 315 | .venv 316 | env/ 317 | venv/ 318 | ENV/ 319 | env.bak/ 320 | venv.bak/ 321 | 322 | # Spyder project settings 323 | .spyderproject 324 | .spyproject 325 | 326 | # Rope project settings 327 | .ropeproject 328 | 329 | # mkdocs documentation 330 | /site 331 | 332 | # mypy 333 | .mypy_cache/ 334 | .dmypy.json 335 | dmypy.json 336 | 337 | # Pyre type checker 338 | .pyre/ 339 | 340 | ### Vim ### 341 | # Swap 342 | [._]*.s[a-v][a-z] 343 | [._]*.sw[a-p] 344 | [._]s[a-rt-v][a-z] 345 | [._]ss[a-gi-z] 346 | [._]sw[a-p] 347 | 348 | # Session 349 | Session.vim 350 | 351 | # Temporary 352 | .netrwhist 353 | # Auto-generated tag files 354 | tags 355 | # Persistent undo 356 | [._]*.un~ 357 | 358 | ### VisualStudioCode ### 359 | .vscode/* 360 | !.vscode/settings.json 361 | !.vscode/tasks.json 362 | !.vscode/launch.json 363 | !.vscode/extensions.json 364 | 365 | ### VisualStudioCode Patch ### 366 | # Ignore all local history of files 367 | .history 368 | 369 | # End of https://www.gitignore.io/api/macos,python,R,linux,vim,emacs,visualstudiocode,intellij 370 | 371 | # Renku 372 | .renku.lock 373 | .renku/tmp 374 | .renku/cache 375 | -------------------------------------------------------------------------------- /datasets/COCO2014/README.md: -------------------------------------------------------------------------------- 1 | # How to use the COCO dataset 2 | 3 | Execute the script `download_coco2014.sh` to download the COCO2014 dataset. This will take quite some time to execute since the dataset is relatively large. Note that we have used the val2014 dataset as a testing dataset and therefore needed to do a custom train/val split on the train2014 dataset to do validation after each training epoch. 
The corresponding annotations for the split will be downloaded to the `annotations` directory and loaded automatically in the code. 4 | -------------------------------------------------------------------------------- /datasets/COCO2014/download_coco2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://images.cocodataset.org/zips/train2014.zip 4 | wget http://images.cocodataset.org/zips/val2014.zip 5 | 6 | unzip -q train2014.zip 7 | unzip -q val2014.zip 8 | 9 | rm train2014.zip 10 | rm val2014.zip 11 | 12 | mkdir -p annotations 13 | (cd annotations ; 14 | wget --content-disposition https://polybox.ethz.ch/index.php/s/a7x8TgZFE3R1vcY/download 15 | wget --content-disposition https://polybox.ethz.ch/index.php/s/Qu4238P6BJp6LBi/download 16 | wget --content-disposition https://polybox.ethz.ch/index.php/s/sfrdSn6pBamlDGf/download) 17 | -------------------------------------------------------------------------------- /datasets/VOC2007/README.md: -------------------------------------------------------------------------------- 1 | # How to use the VOC2007 dataset 2 | 3 | Contrary to the COCO2014 dataset, which you have to first download with our script in the respective folder, the VOC2007 dataset should be downloaded automatically when you execute the `main.py` script with the `--dataset="VOC"` option (default). Sometimes, the download servers are down, in which case you would need to get the dataset yourself and store it in this directory or try another time when the servers are up again. 4 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_gpu 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_llvm 9 | - absl-py=0.11.0=py38h578d9bd_0 10 | - aiohttp=3.7.3=py38h497a2fe_0 11 | - args=0.1.0=py38h32f6830_1003 12 | - asttokens=2.0.5=pyhd8ed1ab_0 13 | - async-timeout=3.0.1=py_1000 14 | - attrs=20.3.0=pyhd3deb0d_0 15 | - backcall=0.2.0=pyh9f0ad1d_0 16 | - backports=1.0=py_2 17 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 18 | - blas=1.0=mkl 19 | - blinker=1.4=py_1 20 | - bottleneck=1.3.4=py38hce1f21e_0 21 | - brotlipy=0.7.0=py38h8df0ef7_1001 22 | - bzip2=1.0.8=h7b6447c_0 23 | - c-ares=1.17.1=h36c2ea0_0 24 | - ca-certificates=2022.3.29=h06a4308_1 25 | - cachetools=4.1.1=py_0 26 | - certifi=2021.10.8=py38h06a4308_2 27 | - cffi=1.14.4=py38ha65f79e_1 28 | - chardet=3.0.4=py38h924ce5b_1008 29 | - click=7.1.2=pyh9f0ad1d_0 30 | - clint=0.5.1=py_1 31 | - cryptography=3.3.1=py38h2b97feb_0 32 | - cudatoolkit=11.3.1=h2bc3f7f_2 33 | - cython=0.29.23=py38h709712a_1 34 | - debugpy=1.5.1=py38h295c915_0 35 | - decorator=5.1.1=pyhd8ed1ab_0 36 | - entrypoints=0.4=pyhd8ed1ab_0 37 | - executing=0.8.3=pyhd8ed1ab_0 38 | - ffmpeg=4.3=hf484d3e_0 39 | - freetype=2.10.4=h7ca028e_0 40 | - fsspec=0.8.5=pyhd8ed1ab_0 41 | - future=0.18.2=py38h578d9bd_2 42 | - gmp=6.2.1=h2531618_2 43 | - gnutls=3.6.15=he1e5248_0 44 | - google-auth=1.24.0=pyhd3deb0d_0 45 | - google-auth-oauthlib=0.4.1=py_2 46 | - grpcio=1.34.0=py38hdd6454d_0 47 | - idna=2.10=pyh9f0ad1d_0 48 | - importlib-metadata=3.3.0=py38h578d9bd_2 49 | - ipykernel=6.10.0=py38h7f3c49e_0 50 | - ipython=8.2.0=py38h578d9bd_0 51 | - jedi=0.18.1=py38h578d9bd_0 52 | - jpeg=9b=h024ee3a_2 53 | - jupyter_client=7.1.2=pyhd8ed1ab_0 54 | - jupyter_core=4.9.2=py38h578d9bd_0 55 | - 
lame=3.100=h7b6447c_0 56 | - lcms2=2.11=h396b838_0 57 | - ld_impl_linux-64=2.35.1=hea4e1c9_1 58 | - libffi=3.3=h58526e2_2 59 | - libgcc-ng=9.3.0=h5dbcf3e_17 60 | - libiconv=1.15=h63c8f33_5 61 | - libidn2=2.3.1=h27cfd23_0 62 | - libpng=1.6.37=h21135ba_2 63 | - libprotobuf=3.14.0=h780b84a_0 64 | - libsodium=1.0.18=h36c2ea0_1 65 | - libstdcxx-ng=9.3.0=h2ae2ef3_17 66 | - libtasn1=4.16.0=h27cfd23_0 67 | - libtiff=4.1.0=h2733197_1 68 | - libunistring=0.9.10=h27cfd23_0 69 | - libuv=1.40.0=h7f98852_0 70 | - llvm-openmp=11.0.0=hfc4b9b4_1 71 | - lz4-c=1.9.2=he1b5a44_3 72 | - markdown=3.3.3=pyh9f0ad1d_0 73 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 74 | - mkl=2020.4=h726a3e6_304 75 | - mkl-service=2.3.0=py38h1e0a361_2 76 | - mkl_fft=1.2.0=py38hab2c0dc_1 77 | - mkl_random=1.2.0=py38hc5bc63f_1 78 | - multidict=5.1.0=py38h497a2fe_0 79 | - ncurses=6.2=h58526e2_4 80 | - nest-asyncio=1.5.4=pyhd8ed1ab_0 81 | - nettle=3.7.2=hbbd107a_1 82 | - ninja=1.10.2=h4bd325d_0 83 | - numexpr=2.7.3=py38hb2eb853_0 84 | - numpy=1.19.2=py38h54aff64_0 85 | - numpy-base=1.19.2=py38hfa32c7d_0 86 | - oauthlib=3.0.1=py_0 87 | - olefile=0.46=pyh9f0ad1d_1 88 | - openh264=2.1.0=hd408876_0 89 | - openssl=1.1.1n=h7f8727e_0 90 | - packaging=20.8=pyhd3deb0d_0 91 | - parso=0.8.3=pyhd8ed1ab_0 92 | - pexpect=4.8.0=pyh9f0ad1d_2 93 | - pickleshare=0.7.5=py_1003 94 | - pillow=8.0.1=py38he98fc37_0 95 | - pip=20.3.3=pyhd8ed1ab_0 96 | - prompt-toolkit=3.0.27=pyha770c72_0 97 | - protobuf=3.14.0=py38h709712a_0 98 | - psutil=5.8.0=py38h27cfd23_1 99 | - ptyprocess=0.7.0=pyhd3deb0d_0 100 | - pure_eval=0.2.2=pyhd8ed1ab_0 101 | - pyasn1=0.4.8=py_0 102 | - pyasn1-modules=0.2.7=py_0 103 | - pycparser=2.20=pyh9f0ad1d_2 104 | - pygments=2.11.2=pyhd8ed1ab_0 105 | - pyjwt=1.7.1=py_0 106 | - pyopenssl=20.0.1=pyhd8ed1ab_0 107 | - pyparsing=2.4.7=pyh9f0ad1d_0 108 | - pysocks=1.7.1=py38h924ce5b_2 109 | - python=3.8.6=hffdb5ce_2_cpython 110 | - python_abi=3.8=1_cp38 111 | - pytorch=1.10.2=py3.8_cuda11.3_cudnn8.2.0_0 112 | - pytorch-lightning=1.1.2=pyhd8ed1ab_0 113 | - pytorch-mutex=1.0=cuda 114 | - pyyaml=5.3.1=py38h8df0ef7_1 115 | - pyzmq=19.0.2=py38ha71036d_2 116 | - readline=8.0=he28a2e2_2 117 | - requests=2.25.1=pyhd3deb0d_0 118 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 119 | - rsa=4.6=pyh9f0ad1d_0 120 | - setuptools=49.6.0=py38h924ce5b_2 121 | - six=1.15.0=pyh9f0ad1d_0 122 | - sqlite=3.34.0=h74cdb3f_0 123 | - stack_data=0.2.0=pyhd8ed1ab_0 124 | - tensorboard=2.8.0=pyhd8ed1ab_1 125 | - tensorboard-data-server=0.6.0=py38hca6d32c_0 126 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 127 | - tk=8.6.10=h21135ba_1 128 | - torchaudio=0.10.2=py38_cu113 129 | - torchvision=0.11.3=py38_cu113 130 | - tornado=6.1=py38h497a2fe_1 131 | - tqdm=4.54.1=pyhd8ed1ab_1 132 | - traitlets=5.1.1=pyhd8ed1ab_0 133 | - typing-extensions=3.7.4.3=0 134 | - typing_extensions=3.7.4.3=py_0 135 | - urllib3=1.26.2=pyhd8ed1ab_0 136 | - wcwidth=0.2.5=pyh9f0ad1d_2 137 | - werkzeug=1.0.1=pyh9f0ad1d_0 138 | - wheel=0.36.2=pyhd3deb0d_0 139 | - xz=5.2.5=h516909a_1 140 | - yaml=0.2.5=h516909a_0 141 | - yarl=1.6.3=py38h497a2fe_0 142 | - zeromq=4.3.4=h9c3ff4c_0 143 | - zipp=3.4.0=py_0 144 | - zlib=1.2.11=h516909a_1010 145 | - zstd=1.4.5=h6597ccf_2 146 | - pip: 147 | - argon2-cffi==21.3.0 148 | - argon2-cffi-bindings==21.2.0 149 | - argparse==1.4.0 150 | - beautifulsoup4==4.10.0 151 | - bleach==4.1.0 152 | - configargparse==1.5.3 153 | - cycler==0.10.0 154 | - cython==0.29.23 155 | - dataclasses==0.6 156 | - defusedxml==0.7.1 157 | - funcy==1.16 158 | - grad-cam==1.3.1 159 | - imageio==2.9.0 160 | - 
importlib-resources==5.2.2 161 | - ipython-genutils==0.2.0 162 | - ipywidgets==7.7.0 163 | - jinja2==3.1.1 164 | - joblib==1.0.1 165 | - jsonschema==4.4.0 166 | - jupyterlab-pygments==0.1.2 167 | - jupyterlab-widgets==1.1.0 168 | - kiwisolver==1.3.1 169 | - markupsafe==2.1.1 170 | - matplotlib==3.4.1 171 | - mistune==0.8.4 172 | - nbclient==0.5.13 173 | - nbconvert==6.4.5 174 | - nbformat==5.2.0 175 | - notebook==6.4.10 176 | - opencv-python==4.5.3.56 177 | - pandas==1.1.5 178 | - pandocfilters==1.5.0 179 | - prometheus-client==0.13.1 180 | - pycocotools==2.0.0 181 | - pymongo==3.12.0 182 | - pyrsistent==0.18.1 183 | - python-dateutil==2.8.1 184 | - pytz==2020.4 185 | - scikit-learn==0.24.2 186 | - scipy==1.6.3 187 | - seaborn==0.11.2 188 | - send2trash==1.8.0 189 | - sklearn==0.0 190 | - soupsieve==2.3.1 191 | - terminado==0.13.3 192 | - test-tube==0.7.5 193 | - testpath==0.6.0 194 | - threadpoolctl==2.2.0 195 | - torch-tb-profiler==0.4.0 196 | - torchmetrics==0.3.2 197 | - torchray==1.0.0.2 198 | - ttach==0.0.3 199 | - webencodings==0.5.1 200 | - widgetsnbextension==3.6.0 201 | prefix: /home/steven/miniconda3/envs/pytorch_gpu 202 | -------------------------------------------------------------------------------- /examples/Intro Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenstalder/NN-Explainer/0a744c939cb46be1ded8a7364bddce04357631d1/examples/Intro Teaser.png -------------------------------------------------------------------------------- /examples/Segmentation Comparisons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenstalder/NN-Explainer/0a744c939cb46be1ded8a7364bddce04357631d1/examples/Segmentation Comparisons.png -------------------------------------------------------------------------------- /notebooks/error_measures-imports.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os, sys \n", 11 | " \n", 12 | "sys.path.append(\"../src\")\n", 13 | "\n", 14 | "from evaluation.eval_utils.assessment_metrics import prob_sparsity, prob_entropy, saliency, continuous_IOU, discrete_IOU\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Create data\n", 24 | "nx = 100\n", 25 | "nc = 10\n", 26 | "c = 0\n", 27 | "mask1 = np.random.rand(nx,nx)\n", 28 | "seg1 = (np.random.rand(nx,nx)>0.5)*1\n", 29 | "\n", 30 | "mask2 = np.ones([nx,nx])\n", 31 | "mask3 = np.zeros([nx,nx])\n", 32 | "mask4 = seg1.copy()\n", 33 | "\n", 34 | "pvec1 = np.random.rand(nc)\n", 35 | "pvec1 = pvec1 / np.sum(pvec1)\n", 36 | "\n", 37 | "pvec2 = np.zeros(nc)\n", 38 | "pvec2[c] = 1\n", 39 | "\n", 40 | "pvec3 = np.ones(nc)/nc\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def prob_sparsity(pvec):\n", 50 | " \"\"\"Sparsity measure.\n", 51 | " \n", 52 | " For pvec of the masked image, we want this to be low.\n", 53 | " For pvec of the inverse masked image, we want this to be high.\n", 54 | " \"\"\"\n", 55 | " return np.sum(pvec**2)\n", 56 | "\n", 57 | "def prob_entropy(pvec):\n", 58 | " \"\"\"Sparsity measure.\n", 59 | " \n", 60 | " For pvec of the masked image, we want this 
to be low.\n", 61 | " For pvec of the inverse masked image, we want this to be high.\n", 62 | " \"\"\"\n", 63 | " return -np.sum(pvec * np.log(np.maximum(pvec, 1e-15)))\n", 64 | "\n", 65 | "\n", 66 | "def saliency(pvec, c, mask):\n", 67 | " \"\"\"\n", 68 | " Continuous saliency measure. \n", 69 | " \n", 70 | " Adaptation from \"Real Time Image Saliency for Black Box Classifiers\n", 71 | " Piotr\", Dabkowski and Gal.\n", 72 | "\n", 73 | " For pvec of the masked image, the lower the better for the masked image.\n", 74 | " \n", 75 | " This measure does not make sense for the inverse masked image.\n", 76 | " \"\"\"\n", 77 | " a = np.maximum(np.mean(mask), 0.05)\n", 78 | " pclass = pvec[c]\n", 79 | " return np.log(a) - np.log(pclass)\n", 80 | "\n", 81 | "def continuous_IOU(mask, seg):\n", 82 | " ### this is no longer the IoU but 1 + the Soergel distance (which is 1- this ratio below)\n", 83 | " #intersection = np.sum(mask * seg)\n", 84 | " #union = np.sum(mask + seg)/2\n", 85 | " #union = np.sum(mask + seg) - intersection\n", 86 | " intersection = np.sum(np.minimum(mask, seg))\n", 87 | " union = np.sum(np.maximum(mask, seg))\n", 88 | " IOU = intersection/(union + 1e-15)\n", 89 | " return IOU\n", 90 | " \n", 91 | "def discrete_IOU(mask, seg, thresh=0.5):\n", 92 | " mask = mask>thresh\n", 93 | " seg = seg>thresh\n", 94 | " intersection = np.sum(np.logical_and(mask, seg)*1.)\n", 95 | " union = np.sum(np.logical_or(mask, seg)*1.)\n", 96 | " IOU = intersection/(union + 1e-15)\n", 97 | " return IOU\n", 98 | " \n", 99 | " " 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "prob_sparsity(pvec1), prob_sparsity(pvec2), prob_sparsity(pvec3)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "prob_entropy(pvec1), prob_entropy(pvec2), prob_entropy(pvec3)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "saliency(pvec1, c, mask1), saliency(pvec2, c, mask1), saliency(pvec3, c, mask1)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "continuous_IOU(mask1, seg1), continuous_IOU(mask4, seg1), continuous_IOU(mask2, seg1), continuous_IOU(mask3, seg1), " 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "discrete_IOU(mask1, seg1), discrete_IOU(mask4, seg1), discrete_IOU(mask2, seg1), discrete_IOU(mask3, seg1), \n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "pytorch_gpu", 158 | "language": "python", 159 | "name": "pytorch_gpu" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.8.6" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 5 176 | } 177 | -------------------------------------------------------------------------------- /notebooks/error_measures.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "import os\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "sys.path.insert(0, os.path.abspath(\"../src/\"))\n", 30 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\"\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward, Perturbation\n", 42 | "from torchray.benchmark import get_example_data, plot_example\n", 43 | "from torchray.utils import get_device\n", 44 | "import torch\n", 45 | "from torchray.attribution.guided_backprop import guided_backprop\n", 46 | "\n", 47 | "from pathlib import Path\n", 48 | "from evaluation.eval_utils.compute_masks import compute_and_save_masks, vedaldi2019, compute_mask\n", 49 | "from evaluation.eval_utils.compute_scores import get_model_and_data\n", 50 | "from utils.helper import get_target_dictionary\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Create data\n", 60 | "nx = 100\n", 61 | "nc = 10\n", 62 | "c = 0\n", 63 | "mask1 = np.random.rand(nx,nx)\n", 64 | "seg1 = (np.random.rand(nx,nx)>0.5)*1\n", 65 | "\n", 66 | "mask2 = np.ones([nx,nx])\n", 67 | "mask3 = np.zeros([nx,nx])\n", 68 | "mask4 = seg1.copy()\n", 69 | "\n", 70 | "pvec1 = np.random.rand(nc)\n", 71 | "pvec1 = pvec1 / np.sum(pvec1)\n", 72 | "\n", 73 | "pvec2 = np.zeros(nc)\n", 74 | "pvec2[c] = 1\n", 75 | "\n", 76 | "pvec3 = np.ones(nc)/nc\n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from evaluation.eval_utils.assessment_metrics import prob_entropy, prob_entropy, saliency, continuous_IOU, discrete_IOU, prob_sparsity" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "prob_sparsity(pvec1), prob_sparsity(pvec2), prob_sparsity(pvec3)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "prob_entropy(pvec1), prob_entropy(pvec2), prob_entropy(pvec3)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "saliency(pvec1, c, mask1), saliency(pvec2, c, mask1), saliency(pvec3, c, mask1)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "continuous_IOU(mask1, seg1), continuous_IOU(mask4, seg1), continuous_IOU(mask2, seg1), continuous_IOU(mask3, seg1)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "discrete_IOU(mask1, seg1), discrete_IOU(mask4, seg1), discrete_IOU(mask2, 
seg1), discrete_IOU(mask3, seg1), \n" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "# Metrics using torchray" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "from torchray.attribution.grad_cam import grad_cam\n", 147 | "from torchray.benchmark import get_example_data, plot_example\n", 148 | "from torchray.attribution.rise import rise\n", 149 | "\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "path_masks = Path(\"../evaluation/\")\n", 159 | "data_path = Path(\"../datasets/VOC2007/\")\n", 160 | "dataset_name = \"VOC\"\n", 161 | "model_name = \"vgg16\"\n", 162 | "model_path = \"../src/checkpoints/pretrained_classifiers/vgg16_voc.ckpt\"\n", 163 | "\n", 164 | "if dataset_name==\"COCO\":\n", 165 | " path_segmentation = path_masks / Path('coco_segmentations')\n", 166 | "elif dataset_name==\"VOC\":\n", 167 | " path_segmentation = Path('../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/')\n", 168 | "\n", 169 | " " 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "model, data_module = get_model_and_data(data_path, dataset_name, model_name, model_path)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "device = get_device()\n", 188 | "model.to(device);\n" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "method = \"guided_backprop\"\n", 198 | "\n", 199 | "path_outputs = path_masks / Path('{}_{}_{}/'.format(dataset_name, model_name, method))\n", 200 | "\n", 201 | "count = compute_and_save_masks(model, data_module, path_segmentation, path_outputs, method=method)\n", 202 | "print(count)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "it = data_module.test_dataloader()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "\n", 221 | "# # Obtain example data.\n", 222 | "# model, x, category_id, _ = get_example_data()\n", 223 | "c = 36\n", 224 | "for s in it:\n", 225 | " img, meta = s\n", 226 | " x = img\n", 227 | " if dataset_name == \"COCO\":\n", 228 | " category_id = meta[0][\"targets\"]\n", 229 | " elif dataset_name == \"VOC\":\n", 230 | " target_dict = get_target_dictionary(include_background_class=False)\n", 231 | " objects = meta[0]['annotation']['object']\n", 232 | " category_id = [target_dict[e[\"name\"]] for e in objects]\n", 233 | " c -= 1\n", 234 | " if c<0:\n", 235 | " break\n", 236 | "x = x.to(device)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "# Grad-CAM backprop.\n", 246 | "if model_name == \"vgg16\":\n", 247 | " feature_extractor = model.feature_extractor[-1]\n", 248 | "elif model_name == \"resnet50\":\n", 249 | " feature_extractor = model.feature_extractor[-2][-1]\n", 250 | " \n", 251 | "saliency = grad_cam(model, x, 1, saliency_layer = feature_extractor, resize=True)\n", 252 | "\n", 253 | "# Plots.\n", 254 | 
"plot_example(x, saliency, 'grad-cam backprop', category_id)\n", 255 | "np.max(saliency.detach().cpu().numpy()), np.min(saliency.detach().cpu().numpy())" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [] 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "pytorch_gpu", 269 | "language": "python", 270 | "name": "pytorch_gpu" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 3 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | "pygments_lexer": "ipython3", 282 | "version": "3.8.6" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 5 287 | } 288 | -------------------------------------------------------------------------------- /notebooks/image-plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "71765b74", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "f593ba32", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import matplotlib.pyplot as plt\n", 22 | "import numpy as np\n", 23 | "from pathlib import Path\n", 24 | "from PIL import Image\n", 25 | "import torchvision.transforms as transforms\n", 26 | "import traceback\n", 27 | "import random" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "b07de6ae", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import cv2\n", 38 | "\n", 39 | "def show_cam_on_image(img: np.ndarray,\n", 40 | " mask: np.ndarray,\n", 41 | " use_rgb: bool = False,\n", 42 | " colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n", 43 | " \"\"\" This function overlays the cam mask on the image as an heatmap.\n", 44 | " By default the heatmap is in BGR format.\n", 45 | " :param img: The base image in RGB or BGR format.\n", 46 | " :param mask: The cam mask.\n", 47 | " :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.\n", 48 | " :param colormap: The OpenCV colormap to be used.\n", 49 | " :returns: The default image with the cam overlay.\n", 50 | " \"\"\"\n", 51 | " heatmap = cv2.applyColorMap(np.uint8(255 * (1-mask)), colormap)\n", 52 | " if use_rgb:\n", 53 | " heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n", 54 | " heatmap = np.float32(heatmap) / 255\n", 55 | "\n", 56 | " if np.max(img) > 1:\n", 57 | " raise Exception(\n", 58 | " \"The input image should np.float32 in the range [0, 1]\")\n", 59 | "\n", 60 | " cam = heatmap + img\n", 61 | " cam = cam / np.max(cam)\n", 62 | " return np.uint8(255 * cam)\n", 63 | "\n", 64 | "def open_image(filename):\n", 65 | " transformer = transforms.Compose([transforms.Resize((224, 224))])\n", 66 | " im = Image.open(filename)\n", 67 | " im = transformer(im)\n", 68 | " im = np.array(im) / 255.0\n", 69 | " return im" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "ceeb5c0b", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def get_path_mask(path_masks, dataset_name, model_name, method):\n", 80 | " return path_masks / Path('{}_{}_{}/'.format(dataset_name, model_name, method))\n", 81 | "\n", 82 | "def 
load_original_image(path_images, filename):\n", 83 | " jpg_name = Path(str(filename)[:-4] + \".jpg\")\n", 84 | " x = open_image(path_images / jpg_name)\n", 85 | " return x" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "c3b47b5f", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "methods = [ \"grad_cam\", \"rise\", \"extremal_perturbations\", \"igos_pp\", \"rt_saliency\", \"explainer\"]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "5ad6e4f7", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "path_masks = Path(\"../src/evaluation/masks/\")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "90336a49", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "def plot_filename(filename, dataset_name, model_name, methods):\n", 116 | " if dataset_name==\"COCO\":\n", 117 | " path_images = Path(\"../datasets/COCO2014/val2014/\")\n", 118 | " else:\n", 119 | " path_images = Path(\"../datasets/VOC2007/VOCdevkit/VOC2007/JPEGImages/\")\n", 120 | " try:\n", 121 | " x = load_original_image(path_images, filename)\n", 122 | " masks = []\n", 123 | " for method in methods: \n", 124 | " p = get_path_mask(path_masks, dataset_name, model_name, method)\n", 125 | " # try:\n", 126 | " npz_name = Path(str(filename)[:-4] + \".npz\")\n", 127 | " m_ = np.load(p / npz_name, dataset_name)[\"arr_0\"]\n", 128 | " m_ /= m_.max()\n", 129 | " \n", 130 | " masks.append(m_)\n", 131 | " # except:\n", 132 | " # jpg_name = Path(str(filename)[:-4] + \".jpg\")\n", 133 | " # masks.append(load_original_image(p, jpg_name)) \n", 134 | " except:\n", 135 | " traceback.print_exc()\n", 136 | " \n", 137 | " \n", 138 | " fig = plt.figure(figsize=(15, 4))\n", 139 | " n_methods = len(methods)\n", 140 | " plt.subplot(1,n_methods+1, 1)\n", 141 | " plt.imshow(x, vmin=0, vmax= 1)\n", 142 | " plt.axis(\"off\")\n", 143 | "\n", 144 | "\n", 145 | " for i, (mask,m) in enumerate(zip(masks, methods)):\n", 146 | " plt.subplot(1,n_methods+1, i + 2)\n", 147 | " plt.imshow(show_cam_on_image(x, mask), vmin=0, vmax= 1)\n", 148 | " # plt.imshow( mask, vmin=0, vmax= 1, cmap=plt.cm.gray_r)\n", 149 | " plt.axis(\"off\")\n", 150 | " \n", 151 | " fig.tight_layout()\n", 152 | " outfolder = Path(\"ATTENUATED_\" + dataset_name + \"_\" + model_name)\n", 153 | " outfolder.mkdir(exist_ok=True, parents=True)\n", 154 | " pdf_name = str(filename)[:-4] + \".pdf\"\n", 155 | " plt.savefig(outfolder / Path(pdf_name), bbox_inches='tight')\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "44b9c213", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "n_images = 5\n", 166 | "dataset_name = \"VOC\"\n", 167 | "model_name = \"vgg16\"\n", 168 | "p_explainer = get_path_mask(path_masks, dataset_name, model_name, \"explainer\")\n", 169 | "image_list = list(enumerate(p_explainer.glob(\"*.png\")))\n", 170 | "random.shuffle(image_list)\n", 171 | "count = 0\n", 172 | "for i, p in image_list:\n", 173 | " filename = p.name\n", 174 | " plot_filename(filename, dataset_name, model_name, methods)\n", 175 | " count += 1\n", 176 | " if count>=n_images:\n", 177 | " break" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "9c8a6035", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "pytorch_gpu", 192 | "language": "python", 
193 | "name": "pytorch_gpu" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.8.6" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 5 210 | } 211 | -------------------------------------------------------------------------------- /notebooks/image_plot_topk.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "sys.path.append('../src')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import torch\n", 30 | "import cv2\n", 31 | "import random\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "import numpy as np\n", 34 | "\n", 35 | "from torchray.utils import get_device\n", 36 | "from pathlib import Path\n", 37 | "\n", 38 | "from data.dataloader import VOCDataModule, COCODataModule\n", 39 | "from utils.image_utils import get_unnormalized_image\n", 40 | "from utils.helper import get_targets_from_annotations, get_filename_from_annotations, extract_masks\n", 41 | "from models.explainer_classifier import ExplainerClassifierModel" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def show_cam_on_image(img: np.ndarray,\n", 51 | " mask: np.ndarray,\n", 52 | " use_rgb: bool = False,\n", 53 | " colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n", 54 | " \"\"\" This function overlays the cam mask on the image as an heatmap.\n", 55 | " By default the heatmap is in BGR format.\n", 56 | " :param img: The base image in RGB or BGR format.\n", 57 | " :param mask: The cam mask.\n", 58 | " :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.\n", 59 | " :param colormap: The OpenCV colormap to be used.\n", 60 | " :returns: The default image with the cam overlay.\n", 61 | " \"\"\"\n", 62 | " heatmap = cv2.applyColorMap(np.uint8(255 * (1-mask)), colormap)\n", 63 | " if use_rgb:\n", 64 | " heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n", 65 | " heatmap = np.float32(heatmap) / 255\n", 66 | "\n", 67 | " if np.max(img) > 1:\n", 68 | " raise Exception(\n", 69 | " \"The input image should np.float32 in the range [0, 1]\")\n", 70 | "\n", 71 | " cam = heatmap + img\n", 72 | " cam = cam / np.max(cam)\n", 73 | " return np.uint8(255 * cam)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def get_class_mask(explainer, image, class_id):\n", 83 | " class_mask = explainer(image)[0][class_id].sigmoid()\n", 84 | "\n", 85 | " return class_mask.cpu().numpy()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def get_topk_classes(explainer, image, k=5):\n", 95 | " class_masks = explainer(image)[0].sigmoid()\n", 96 | " class_mask_means = class_masks.mean(dim=(1,2))\n", 97 | 
"\n", 98 | " values, topk_classes = class_mask_means.topk(k)\n", 99 | " return values.cpu().numpy(), topk_classes.cpu().numpy()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def get_class_scores(explainer, classifier, image, class_id):\n", 109 | " class_mask = explainer(image)[0][class_id].sigmoid()\n", 110 | " masked_image = class_mask.unsqueeze(0).unsqueeze(0) * image\n", 111 | "\n", 112 | " unmasked_logits = classifier(image)[0]\n", 113 | " masked_logits = classifier(masked_image)[0]\n", 114 | "\n", 115 | " unmasked_class_prob = unmasked_logits.sigmoid()[class_id]\n", 116 | " masked_class_prob = masked_logits.sigmoid()[class_id]\n", 117 | "\n", 118 | " return unmasked_class_prob, masked_class_prob" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def get_test_dataloader(dataset, data_path):\n", 128 | " if dataset == \"VOC\":\n", 129 | " data_module = VOCDataModule(data_path, test_batch_size=1)\n", 130 | " elif dataset == \"COCO\":\n", 131 | " data_module = COCODataModule(data_path, test_batch_size=1)\n", 132 | " \n", 133 | " data_module.setup(stage=\"test\")\n", 134 | " test_dataloader = data_module.test_dataloader()\n", 135 | "\n", 136 | " return test_dataloader" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "dataset = \"VOC\"\n", 146 | "num_classes = 20\n", 147 | "data_path = Path(\"../datasets/VOC2007/\")\n", 148 | "classifier_type = \"resnet50\"\n", 149 | "explainer_classifier_path = Path(\"../src/checkpoints/explainer_resnet50_voc.ckpt\")\n", 150 | "output_dir = Path(f\"./topk_attributions/{classifier_type}\")\n", 151 | "\n", 152 | "explainer_classifier = ExplainerClassifierModel.load_from_checkpoint(explainer_classifier_path, \n", 153 | " num_classes=num_classes, \n", 154 | " dataset=dataset, \n", 155 | " classifier_type=classifier_type)\n", 156 | " \n", 157 | "device = get_device()\n", 158 | "explainer = explainer_classifier.explainer.to(device)\n", 159 | "explainer.freeze()\n", 160 | "classifier = explainer_classifier.classifier.to(device)\n", 161 | "classifier.freeze()\n", 162 | "\n", 163 | "dataloader = get_test_dataloader(dataset, data_path)\n", 164 | "image_list = list(enumerate(dataloader))" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "output_dir.mkdir(parents=True, exist_ok=True)\n", 174 | "\n", 175 | "random.shuffle(image_list)\n", 176 | "n_images = 8\n", 177 | "count = 0\n", 178 | "for i, item in image_list:\n", 179 | " image, annotation = item\n", 180 | " image = image.to(device)\n", 181 | " filename = get_filename_from_annotations(annotation, dataset=dataset)[:-4]\n", 182 | " targets = get_targets_from_annotations(annotation, dataset=dataset)\n", 183 | " target_classes = [i for i, val in enumerate(targets[0]) if val == 1.0]\n", 184 | " topk_values, topk_classes = get_topk_classes(explainer, image, k=5)\n", 185 | "\n", 186 | " fig = plt.figure(figsize=(25, 5))\n", 187 | " original_image = np.transpose(get_unnormalized_image(image).cpu().numpy().squeeze(), (1, 2, 0))\n", 188 | " plt.imsave(output_dir / f\"{filename}_original.png\", original_image, format=\"png\")\n", 189 | " plt.subplot(1, 7, 1)\n", 190 | " plt.imshow(original_image)\n", 191 | " 
plt.axis(\"off\")\n", 192 | "\n", 193 | " segmentations = explainer(image)\n", 194 | " aggregated_mask, _ = extract_masks(segmentations, targets)\n", 195 | " aggregated_mask = aggregated_mask[0].cpu().numpy()\n", 196 | " aggregated_attribution = show_cam_on_image(original_image, aggregated_mask)\n", 197 | " plt.imsave(output_dir / f\"{filename}_aggregated.png\", aggregated_attribution, format=\"png\")\n", 198 | " plt.subplot(1, 7, 2)\n", 199 | " plt.imshow(aggregated_attribution)\n", 200 | " plt.axis(\"off\")\n", 201 | " for j, class_id in enumerate(topk_classes):\n", 202 | " unmasked_class_prob, masked_class_prob = get_class_scores(explainer, classifier, image, class_id)\n", 203 | " class_mask = get_class_mask(explainer, image, class_id)\n", 204 | " attribution = show_cam_on_image(original_image, class_mask)\n", 205 | " plt.imsave(output_dir / f\"{filename}_rank_{j}_class_{class_id+1}.png\", attribution, format=\"png\")\n", 206 | " plt.subplot(1, 7, j+3)\n", 207 | " plt.imshow(attribution, vmin=0, vmax=1)\n", 208 | " class_title = f\"{class_id+1}**\" if class_id in target_classes else f\"{class_id+1}\"\n", 209 | " plt.title(f\"{class_title}: CLS={unmasked_class_prob*100:.2f}, MASK={topk_values[j]*100:.2f}\")\n", 210 | " plt.axis(\"off\")\n", 211 | "\n", 212 | " count += 1\n", 213 | " if count >= n_images:\n", 214 | " break" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [] 223 | } 224 | ], 225 | "metadata": { 226 | "interpreter": { 227 | "hash": "7ca055faf3ae29c3c66579db75b878dd7b19a8a2f3a26672c87b0ac03bac584e" 228 | }, 229 | "kernelspec": { 230 | "display_name": "Python 3.8.6 ('pytorch_gpu')", 231 | "language": "python", 232 | "name": "python3" 233 | }, 234 | "language_info": { 235 | "codemirror_mode": { 236 | "name": "ipython", 237 | "version": 3 238 | }, 239 | "file_extension": ".py", 240 | "mimetype": "text/x-python", 241 | "name": "python", 242 | "nbconvert_exporter": "python", 243 | "pygments_lexer": "ipython3", 244 | "version": "3.8.6" 245 | }, 246 | "orig_nbformat": 4 247 | }, 248 | "nbformat": 4, 249 | "nbformat_minor": 2 250 | } 251 | -------------------------------------------------------------------------------- /notebooks/produce_tables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "4bb9b98d-79ed-4c39-8929-7fc543540292", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "a309e1be-7182-4ca8-9b48-8b21e1e5751a", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import pandas as pd" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "f19b35ec-62b8-4330-ba7e-5b5b5c46318f", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "results = np.load(\"../src/evaluation/results.npz\", allow_pickle=True)[\"results\"].item()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "dc2c2342-099c-4cb7-b18c-e8b4adc57191", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# def make_table()\n", 44 | "database_name = \"COCO\"\n", 45 | "model_names = [\"vgg16\", \"resnet50\"]\n", 46 | "methods = [\"0\", \"0.5\", \"1\", \"perfect\", \"grad_cam\", 
\"rise\", \"extremal_perturbations\", \"igos_pp\", \"rt_saliency\", \"explainer\"]\n", 47 | "methods_names = [\"0\", \"0.5\", \"1\", \"Ideal\", \"GradCam\", \"RISE\", \"Fong 19\", \"iGOS++\", \"Dabkowski 17\", \"Explainer\", ]\n", 48 | "metrics = [\"d_f1\", \"c_f1\", \"acc\", \"d_IOUs\", \"c_IOU\" ,\"aucs\",'sal', 'over', 'background_c', 'mask_c']\n", 49 | "metrics_names = [\"Mean F1\", \"Cont F1\", \"Accuracy\", \"Discr. IOU\", \"Cont. IOU\", \"AUC\", 'Saliency', 'Accuracy', 'Background cov.', 'Object cov.']\n", 50 | "\n", 51 | "indexes = pd.MultiIndex.from_product([metrics_names, model_names], names=[\"Metric\", \"Network\"])\n", 52 | "\n", 53 | "m = np.zeros([len(metrics)*2, len(methods)])\n", 54 | "for i,metric in enumerate(metrics):\n", 55 | " for j, method in enumerate(methods):\n", 56 | " for k in [0,1]:\n", 57 | " try:\n", 58 | " v = results[database_name][model_names[k]][method][metric]\n", 59 | " m[2*i+k,j] = \"{:.2f}\".format(np.mean(np.array([e for e in v if e is not None])))\n", 60 | " except:\n", 61 | " m[2*i+k,j] = np.nan\n", 62 | "df = pd.DataFrame(m, index=indexes, columns=methods_names)\n", 63 | "\n", 64 | "\n", 65 | "print(df.to_latex(index=True, escape=False))\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "c58aec0f-63d3-4d8a-91a0-7aae7cc8b39b", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "model_name = \"vgg16\"\n", 76 | "plt.figure(figsize=(6, 4))\n", 77 | "x = np.arange(0.1, 1, .1)\n", 78 | "for method, method_name in zip(methods, methods_names):\n", 79 | " if method_name in [\"0.5\", \"Ideal\"]: continue\n", 80 | " if method_name in [\"0\", \"1\"]:\n", 81 | " plt.plot(x, np.mean(np.array(results[database_name][model_name][method][\"a_f1s\"]), axis=0), '--', label=method_name)\n", 82 | " else:\n", 83 | " plt.plot(x, np.mean(np.array(results[database_name][model_name][method][\"a_f1s\"]), axis=0), label=method_name)\n", 84 | "plt.legend(bbox_to_anchor=(-0.03, 1.0, 1, 0.2), loc=\"lower left\", ncol=4)\n", 85 | "plt.xlabel(\"Threshold\")\n", 86 | "plt.ylabel(\"F1 score\")\n", 87 | "plt.savefig(\"f1s_{}_{}.pdf\".format(database_name, model_name))" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "0224e1cf-7fea-4c57-9aea-98d332e0e96e", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.8.6" 116 | } 117 | }, 118 | "nbformat": 4, 119 | "nbformat_minor": 5 120 | } 121 | -------------------------------------------------------------------------------- /src/config_files/classifier_training/classifier_resnet50_coco_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = resnet50 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | learning_rate = 1e-4 5 | metrics_threshold = 0.0 6 | model_to_train = classifier 7 | -------------------------------------------------------------------------------- /src/config_files/classifier_training/classifier_resnet50_voc_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = 
resnet50 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | metrics_threshold = 0.0 5 | model_to_train = classifier 6 | learning_rate=1e-4 7 | -------------------------------------------------------------------------------- /src/config_files/classifier_training/classifier_vgg16_coco_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | metrics_threshold = 0.0 5 | model_to_train = classifier 6 | -------------------------------------------------------------------------------- /src/config_files/classifier_training/classifier_vgg16_voc_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | metrics_threshold = 0.0 5 | model_to_train = classifier 6 | -------------------------------------------------------------------------------- /src/config_files/explainer_training/explainer_resnet50_coco_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = resnet50 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | early_stop_patience=2 5 | entropy_regularizer = 3.0 6 | learning_rate = 1e-05 7 | classifier_checkpoint = 8 | mask_total_area_regularizer = 0.01 9 | mask_variation_regularizer = 0.05 10 | ncmask_total_area_regularizer = 0.03 11 | mask_area_constraint_regularizer = 0.1 12 | model_to_train = explainer 13 | 14 | -------------------------------------------------------------------------------- /src/config_files/explainer_training/explainer_resnet50_voc_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = resnet50 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | entropy_regularizer = 3.0 5 | learning_rate = 1e-04 6 | classifier_checkpoint = 7 | mask_area_constraint_regularizer = 0.8 8 | mask_total_area_regularizer = 0.01 9 | model_to_train = explainer 10 | -------------------------------------------------------------------------------- /src/config_files/explainer_training/explainer_vgg16_coco_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | early_stop_patience = 2 5 | learning_rate = 1e-05 6 | classifier_checkpoint = 7 | model_to_train = explainer 8 | mask_variation_regularizer = 0.3 9 | mask_area_constraint_regularizer = 0.3 10 | mask_total_area_regularizer = 0.03 11 | ncmask_total_area_regularizer = 0.1 12 | -------------------------------------------------------------------------------- /src/config_files/explainer_training/explainer_vgg16_voc_training.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | learning_rate = 1e-05 5 | classifier_checkpoint = 6 | model_to_train = explainer 7 | -------------------------------------------------------------------------------- /src/config_files/testing_and_mask_saving/explainer_resnet50_coco_test_and_save_masks.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = resnet50 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | explainer_classifier_checkpoint = 5 | model_to_train = explainer 6 | save_masks = True 7 | save_masked_images = True 8 | save_path = ./results/ 9 | 
test_batch_size = 1 10 | train_model = False 11 | -------------------------------------------------------------------------------- /src/config_files/testing_and_mask_saving/explainer_resnet50_voc_test_and_save_masks.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = resnet50 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | explainer_classifier_checkpoint = 5 | model_to_train = explainer 6 | save_masks = True 7 | save_masked_images = True 8 | save_path = ./results/ 9 | test_batch_size = 1 10 | train_model = False 11 | -------------------------------------------------------------------------------- /src/config_files/testing_and_mask_saving/explainer_vgg16_coco_test_and_save_masks.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = COCO 4 | explainer_classifier_checkpoint = 5 | model_to_train = explainer 6 | save_masks = True 7 | save_masked_images = True 8 | save_path = ./results/ 9 | test_batch_size = 1 10 | train_model = False 11 | -------------------------------------------------------------------------------- /src/config_files/testing_and_mask_saving/explainer_vgg16_voc_test_and_save_masks.cfg: -------------------------------------------------------------------------------- 1 | classifier_type = vgg16 2 | data_base_path = ../datasets/ 3 | dataset = VOC 4 | explainer_classifier_checkpoint = 5 | model_to_train = explainer 6 | save_masks = True 7 | save_masked_images = True 8 | save_path = ./results/ 9 | test_batch_size = 1 10 | train_model = False 11 | -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | import pytorch_lightning as pl 6 | import torchvision.transforms as T 7 | 8 | from torchvision.datasets import VOCDetection 9 | from torch.utils.data import DataLoader 10 | from typing import Optional 11 | from pathlib import Path 12 | 13 | from data.dataset import COCODataset, CUB200Dataset 14 | 15 | class VOCDataModule(pl.LightningDataModule): 16 | 17 | def __init__(self, data_path, train_batch_size=16, val_batch_size=16, test_batch_size=16, use_data_augmentation=False): 18 | super().__init__() 19 | 20 | self.data_path = Path(data_path) 21 | 22 | if os.path.exists(self.data_path) and len(os.listdir(self.data_path)) > 2: 23 | self.download = False 24 | else: 25 | self.download = True 26 | 27 | self.train_transformer = get_training_image_transformer(use_data_augmentation) 28 | self.test_transformer = get_testing_image_transformer() 29 | 30 | self.train_batch_size = train_batch_size 31 | self.val_batch_size = val_batch_size 32 | self.test_batch_size = test_batch_size 33 | 34 | def prepare_data(self): 35 | pass 36 | 37 | def setup(self, stage: Optional[str] = None): 38 | if stage == "fit" or stage is None: 39 | self.train = VOCDetection(self.data_path, year="2007", image_set="train", 40 | download=self.download, transform=self.train_transformer) 41 | 42 | self.val = VOCDetection(self.data_path, year="2007", image_set="val", 43 | download=self.download, transform=self.test_transformer) 44 | 45 | if stage == "test" or stage is None: 46 | self.test = VOCDetection(self.data_path, year="2007", image_set="test", 47 | download=self.download, transform=self.test_transformer) 48 | 49 | def train_dataloader(self): 50 | return 
DataLoader(self.train, batch_size=self.train_batch_size, collate_fn=collate_fn, 51 | shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) 52 | 53 | def val_dataloader(self): 54 | return DataLoader(self.val, batch_size=self.val_batch_size, collate_fn=collate_fn, 55 | num_workers=4, pin_memory=torch.cuda.is_available()) 56 | 57 | def test_dataloader(self): 58 | return DataLoader(self.test, batch_size=self.test_batch_size, collate_fn=collate_fn, 59 | num_workers=4, pin_memory=torch.cuda.is_available()) 60 | 61 | class COCODataModule(pl.LightningDataModule): 62 | 63 | def __init__(self, data_path, train_batch_size=16, val_batch_size=16, test_batch_size=16, use_data_augmentation=False): 64 | super().__init__() 65 | 66 | self.data_path = Path(data_path) 67 | self.annotations_path = self.data_path / 'annotations' 68 | 69 | self.train_transformer = get_training_image_transformer(use_data_augmentation) 70 | self.test_transformer = get_testing_image_transformer() 71 | 72 | self.train_batch_size = train_batch_size 73 | self.val_batch_size = val_batch_size 74 | self.test_batch_size = test_batch_size 75 | 76 | def prepare_data(self): 77 | pass 78 | 79 | def setup(self, stage: Optional[str] = None): 80 | if stage == "fit" or stage is None: 81 | self.train = COCODataset(root=self.data_path / 'train2014', 82 | annotation=self.annotations_path / 'train2014_train_split.json', 83 | transform_fn=self.train_transformer) 84 | 85 | self.val = COCODataset(root=self.data_path / 'train2014', 86 | annotation=self.annotations_path / 'train2014_val_split.json', 87 | transform_fn=self.test_transformer) 88 | 89 | if stage == "test" or stage is None: 90 | self.test = COCODataset(root=self.data_path / 'val2014', 91 | annotation=self.annotations_path / 'instances_val2014.json', 92 | transform_fn=self.test_transformer) 93 | 94 | def train_dataloader(self): 95 | return DataLoader(self.train, batch_size=self.train_batch_size, collate_fn=collate_fn, 96 | shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) 97 | 98 | def val_dataloader(self): 99 | return DataLoader(self.val, batch_size=self.val_batch_size, collate_fn=collate_fn, 100 | num_workers=4, pin_memory=torch.cuda.is_available()) 101 | 102 | def test_dataloader(self): 103 | return DataLoader(self.test, batch_size=self.test_batch_size, collate_fn=collate_fn, 104 | num_workers=4, pin_memory=torch.cuda.is_available()) 105 | 106 | class CUB200DataModule(pl.LightningDataModule): 107 | 108 | def __init__(self, data_path, train_batch_size=16, val_batch_size=16, test_batch_size=16, use_data_augmentation=False): 109 | super().__init__() 110 | 111 | self.data_path = Path(data_path) 112 | self.annotations_path = self.data_path / 'annotations' 113 | 114 | self.train_transformer = get_training_image_transformer(use_data_augmentation) 115 | self.test_transformer = get_testing_image_transformer() 116 | 117 | self.train_batch_size = train_batch_size 118 | self.val_batch_size = val_batch_size 119 | self.test_batch_size = test_batch_size 120 | 121 | def prepare_data(self): 122 | pass 123 | 124 | def setup(self, stage: Optional[str] = None): 125 | if stage == "fit" or stage is None: 126 | self.train = CUB200Dataset(root=self.data_path / 'train', 127 | annotations=self.annotations_path / 'train.txt', 128 | transform_fn=self.train_transformer) 129 | 130 | self.val = CUB200Dataset(root=self.data_path / 'val', 131 | annotations=self.annotations_path / 'val.txt', 132 | transform_fn=self.test_transformer) 133 | 134 | if stage == "test" or stage is None: 135 | 
self.test = CUB200Dataset(root=self.data_path / 'test', 136 | annotations=self.annotations_path / 'test.txt', 137 | transform_fn=self.test_transformer) 138 | 139 | def train_dataloader(self): 140 | return DataLoader(self.train, batch_size=self.train_batch_size, collate_fn=collate_fn, 141 | shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) 142 | 143 | def val_dataloader(self): 144 | return DataLoader(self.val, batch_size=self.val_batch_size, collate_fn=collate_fn, 145 | num_workers=4, pin_memory=torch.cuda.is_available()) 146 | 147 | def test_dataloader(self): 148 | return DataLoader(self.test, batch_size=self.test_batch_size, collate_fn=collate_fn, 149 | num_workers=4, pin_memory=torch.cuda.is_available()) 150 | 151 | 152 | def get_training_image_transformer(use_data_augmentation=False): 153 | if use_data_augmentation: 154 | transformer = T.Compose([ T.RandomHorizontalFlip(), 155 | T.RandomRotation(10), 156 | T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), 157 | #T.Resize(256), 158 | #T.CenterCrop(224), 159 | T.Resize(size=(224,224)), 160 | T.ToTensor(), 161 | T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])]) 162 | else: 163 | transformer = T.Compose([ T.Resize(size=(224,224)), 164 | # T.Resize(256), 165 | # T.CenterCrop(224), 166 | T.ToTensor(), 167 | T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])]) 168 | 169 | return transformer 170 | 171 | def get_testing_image_transformer(): 172 | transformer = T.Compose([ T.Resize(size=(224,224)), 173 | T.ToTensor(), 174 | T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])]) 175 | 176 | return transformer 177 | 178 | def collate_fn(batch): 179 | data = torch.stack([item[0] for item in batch]) 180 | target = [item[1] for item in batch] 181 | return data, target 182 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pathlib 4 | import torch 5 | from torch.utils.data import Dataset 6 | import torchvision 7 | from PIL import Image 8 | from pycocotools.coco import COCO 9 | 10 | class COCODataset(Dataset): 11 | def __init__(self, root, annotation, transform_fn=None): 12 | self.root = root 13 | self.transform = transform_fn 14 | self.coco = COCO(annotation) 15 | self.ids = list(sorted(self.coco.imgs.keys())) 16 | 17 | def __getitem__(self, index): 18 | coco = self.coco 19 | img_id = self.ids[index] 20 | ann_ids = coco.getAnnIds(imgIds=img_id) 21 | coco_annotation = coco.loadAnns(ann_ids) 22 | path = coco.loadImgs(img_id)[0]['file_name'] 23 | img = Image.open(os.path.join(self.root, path)).convert("RGB") 24 | 25 | num_objects = len(coco_annotation) 26 | cat_ids = [] 27 | for i in range(num_objects): 28 | cat_ids.append(coco_annotation[i]['category_id']) 29 | 30 | targets = coco.getCatIds(catIds=cat_ids) 31 | 32 | my_annotation = {} 33 | my_annotation["targets"] = targets 34 | my_annotation["image_id"] = img_id 35 | my_annotation["filename"] = path 36 | 37 | if self.transform is not None: 38 | img = self.transform(img) 39 | 40 | return img, my_annotation 41 | 42 | def __len__(self): 43 | return len(self.ids) 44 | 45 | class CUB200Dataset(Dataset): 46 | def __init__(self, root, annotations, transform_fn=None): 47 | self.root = root 48 | self.transform = transform_fn 49 | with open(annotations, 'rb') as fp: 50 | self.annotations = pickle.load(fp) 51 | 52 | def __getitem__(self, 
index): 53 | annotation = self.annotations[index] 54 | filename = pathlib.PurePath(annotation['filename']).name 55 | 56 | img = Image.open(os.path.join(self.root, filename)).convert("RGB") 57 | class_label = annotation['class']['label'] 58 | 59 | my_annotation = {} 60 | my_annotation["target"] = class_label 61 | my_annotation["filename"] = filename 62 | 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | 66 | return img, my_annotation 67 | 68 | def __len__(self): 69 | return len(self.annotations) 70 | -------------------------------------------------------------------------------- /src/evaluation/IGOS_pp/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def init_args(): 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Generate explanations using I-GOS and iGOS++.' 8 | ) 9 | 10 | parser.add_argument( 11 | '--shuffle', 12 | action='store_true', 13 | help='Shuffle the dataset.') 14 | 15 | parser.add_argument( 16 | '--size', 17 | type=int, 18 | default=28, 19 | help='The resolution of mask to be generated.') 20 | 21 | parser.add_argument( 22 | '--input_size', 23 | type=int, 24 | default=224, 25 | help='The input size to the network.') 26 | 27 | parser.add_argument( 28 | '--num_samples', 29 | type=int, 30 | default=5000, 31 | help='The number of samples to run explanation on.') 32 | 33 | parser.add_argument( 34 | '--batch_size', 35 | type=int, 36 | default=1, 37 | help='The number of images to generate explanations at once.') 38 | 39 | parser.add_argument( 40 | '--manual_seed', 41 | type=int, 42 | default=63, 43 | help='The manual seed for experiments.') 44 | 45 | parser.add_argument( 46 | '--L1', 47 | type=float, 48 | default=1 49 | ) 50 | 51 | parser.add_argument( 52 | '--L2', 53 | type=float, 54 | default=20 55 | ) 56 | 57 | parser.add_argument( 58 | '--ig_iter', 59 | type=int, 60 | default=20) 61 | 62 | parser.add_argument( 63 | '--iterations', 64 | type=int, 65 | default=15 66 | ) 67 | 68 | parser.add_argument( 69 | '--alpha', 70 | type=float, 71 | default=1000 72 | ) 73 | 74 | return parser.parse_args() 75 | -------------------------------------------------------------------------------- /src/evaluation/IGOS_pp/igos_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to use for logging results and explanations. 3 | © copyright Tyler Lawson, Saeed khorram. https://github.com/saeed-khorram/IGOS 4 | """ 5 | 6 | import torch 7 | import os 8 | import cv2 9 | import time 10 | import sys 11 | import requests 12 | 13 | import matplotlib.pyplot as plt 14 | import torchvision.transforms as transforms 15 | import seaborn as sns 16 | import numpy as np 17 | 18 | from PIL import Image 19 | 20 | 21 | # mean and standard deviation for the imagenet dataset 22 | mean = torch.tensor([0.485, 0.456, 0.406]) 23 | std = torch.tensor([0.229, 0.224, 0.225]) 24 | 25 | 26 | def init_sns(): 27 | """ 28 | Sets up desired configs for plotting with Seaborn. 29 | 30 | :return: 31 | """ 32 | sns.set() 33 | sns.despine(offset=10, trim=True) 34 | sns.set(font='serif') 35 | sns.set_style("darkgrid", {"font.family": "serif", "font.serif": ["Times"]}) 36 | sns.set_context("paper", rc={"font.size":10,"axes.titlesize":14,"axes.labelsize":14}) 37 | 38 | 39 | def init_logger(args): 40 | """ 41 | Initializes output directory to save the results and log the arguments. 
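The directory is created as Output/<method>_<timestamp>/, the parsed arguments are dumped to an args.txt file inside it, and the path is returned so results can be stored next to the log.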
42 | 43 | :param args: 44 | :return: 45 | """ 46 | # make output directoty 47 | out_dir = os.path.join('Output', f"{args.method}_{time.strftime('%m_%d_%Y-%H:%M:%S')}") 48 | if not os.path.exists(out_dir): 49 | os.makedirs(out_dir) 50 | eprint(f'Output Directory: {out_dir}\n') 51 | 52 | # save args into text file 53 | with open(os.path.join(out_dir, 'args.txt'), 'w') as file: 54 | file.write(str(args.__dict__)) 55 | 56 | return out_dir 57 | 58 | 59 | def eprint(*args, **kwargs): 60 | """ 61 | Prints to the std.err 62 | 63 | :param args: 64 | :param kwargs: 65 | :return: 66 | """ 67 | print(*args, file=sys.stderr, **kwargs) 68 | 69 | 70 | def get_imagenet_classes(labels_url='https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'): 71 | """ 72 | downloads the label file for imagenet 73 | 74 | :param labels_url: 75 | :return: 76 | """ 77 | labels = requests.get(labels_url) 78 | return {int(key): value[1] for key, value in labels.json().items()} 79 | 80 | 81 | class ImageSet(torch.utils.data.Dataset): 82 | def __init__(self, root_dir, transform=None, blur=False): 83 | """ 84 | Loads data given a path (root_dir) and preprocess them (transforms, blur) 85 | 86 | :param root_dir: 87 | :param transform: 88 | :param blur: 89 | """ 90 | self.root_dir = root_dir 91 | self.transform = transform 92 | self.blur = blur 93 | self.transform = transforms.Compose( 94 | [transforms.Resize((224, 224)), 95 | transforms.ToTensor(), 96 | transforms.Normalize(mean,std) 97 | ] 98 | ) 99 | 100 | eprint(f"\nLoading filenames from '{root_dir}' directory...") 101 | (_, _, self.filenames) = next(os.walk(root_dir)) 102 | self.filenames = sorted(self.filenames) 103 | eprint(f"{len(self.filenames)} file(s) loaded.\n") 104 | 105 | def __getitem__(self, idx): 106 | if torch.is_tensor(idx): 107 | idx = idx.tolist() 108 | 109 | img_name = os.path.join(self.root_dir, self.filenames[idx]) 110 | image = Image.open(img_name).convert('RGB') 111 | 112 | if self.blur: 113 | resized_image = image.resize((224, 224)) 114 | blurred = cv2.GaussianBlur(np.asarray(resized_image), (51, 51), sigmaX=50) 115 | blurred = Image.fromarray(blurred.astype(np.uint8)) 116 | 117 | if self.transform: 118 | image = self.transform(image) 119 | 120 | if self.blur: 121 | blurred = self.transform(blurred) 122 | 123 | returns = [image] 124 | 125 | if self.blur: 126 | returns.append(blurred) 127 | return (*returns,) 128 | 129 | def __len__(self): 130 | return len(self.filenames) 131 | 132 | 133 | # Added function 134 | def get_blurred_image(img_name): 135 | transform_fn = transforms.Compose( 136 | [transforms.Resize((224, 224)), 137 | transforms.ToTensor(), 138 | transforms.Normalize(mean,std) 139 | ] 140 | ) 141 | 142 | image = Image.open(img_name).convert('RGB') 143 | resized_image = image.resize((224, 224)) 144 | blurred = cv2.GaussianBlur(np.asarray(resized_image), (51, 51), sigmaX=50) 145 | blurred = Image.fromarray(blurred.astype(np.uint8)) 146 | blurred = transform_fn(blurred) 147 | 148 | return blurred 149 | 150 | 151 | def save_heatmaps(masks, images, size, index, outdir, out=224): 152 | """ 153 | Save masks and corresponding overlay 154 | 155 | :param masks: 156 | :param images: 157 | :param size: 158 | :param index: 159 | :param outdir: 160 | :param out: 161 | :return: 162 | """ 163 | masks = masks.view(-1, 1, size, size) 164 | up = torch.nn.UpsamplingBilinear2d(size=(out, out)).cuda() 165 | 166 | u_mask = up(masks) 167 | u_mask = u_mask.permute((0,2, 3, 1)) 168 | 169 | # Normalize the mask 170 | u_mask = (u_mask - 
torch.min(u_mask)) / (torch.max(u_mask) - torch.min(u_mask)) 171 | u_mask = u_mask.cpu().detach().numpy() 172 | 173 | # deprocess images 174 | images = images.cpu().detach().permute((0, 2, 3, 1)) * std + mean 175 | images = images.numpy() 176 | 177 | for i, (image, u_mask) in enumerate(zip(images, u_mask)): 178 | 179 | # get the color map and normalize to 0-1 180 | heatmap = cv2.applyColorMap(np.uint8(255 * u_mask), cv2.COLORMAP_JET) 181 | heatmap = np.float32(heatmap / 255) 182 | # overlay the mask over the image 183 | overlay = (u_mask ** 0.8) * image + (1 - u_mask ** 0.8) * heatmap 184 | 185 | plt.imsave(os.path.join(outdir, f'{index+i+1}_heatmap.jpg'), heatmap) 186 | plt.imsave(os.path.join(outdir, f'{index+i+1}_overlay.jpg'), overlay) 187 | 188 | 189 | def save_masks(masks, index, categories, mask_name, outdir): 190 | """ 191 | Saves the generated masks as numpy.ndarrays. 192 | 193 | :param masks: 194 | :param index: 195 | :param categories: 196 | :param mask_name: 197 | :param outdir: 198 | :return: 199 | """ 200 | masks = masks.cpu().detach().numpy() 201 | for i, (mask, category) in enumerate(zip(masks, categories), start=index): 202 | np.save(os.path.join(outdir, f'{mask_name}_{i+1}_mask_{category}.npy'), mask) 203 | 204 | 205 | def save_curves(del_curve, ins_curve, index_curve, index, outdir): 206 | """ 207 | Save the deletion/insertion curves for the generated masks. 208 | 209 | :param del_curve: 210 | :param ins_curve: 211 | :param index_curve: 212 | :param index: 213 | :param outdir: 214 | :return: 215 | """ 216 | for i in range(len(del_curve)): 217 | fig, (ax, ax1) = plt.subplots(2, 1) 218 | ax.plot(index_curve, del_curve[i], color='r', label='deletion') 219 | ax.fill_between(index_curve, del_curve[i], facecolor='maroon', alpha=0.4) 220 | ax.set_ylim([-0.05, 1.05]) 221 | ax.tick_params(labelsize=14) 222 | ax.set_yticks(np.arange(0, 1.01, 1)) 223 | ax.legend(['Deletion'], fontsize='x-large') 224 | ax.text(0.5, 0.5, 'AUC: {:.4f}'.format(auc(del_curve[i])), fontsize=14, horizontalalignment='center', verticalalignment='center') 225 | 226 | ax1.plot(index_curve, ins_curve[i], color='b', label='Insertion') 227 | ax1.fill_between(index_curve, ins_curve[i], facecolor='darkblue', alpha=0.4) 228 | ax1.set_ylim([-0.05, 1.05]) 229 | ax1.tick_params(labelsize=14) 230 | ax1.set_yticks(np.arange(0, 1.01, 1)) 231 | ax1.legend(['Insertion'], fontsize='x-large') 232 | ax1.text(0.5, 0.5, 'AUC: {:.4f}'.format(auc(ins_curve[i])), fontsize=14, horizontalalignment='center', verticalalignment='center') 233 | 234 | # save the plot 235 | plt.savefig(os.path.join(outdir, f'{index+i+1}_curves.jpg'), bbox_inches='tight', pad_inches = 0) 236 | plt.close() 237 | 238 | 239 | def save_images(images, index, outdir, classes, labels): 240 | """ 241 | saves original images into output directory 242 | 243 | :param images: 244 | :param index: 245 | :param outdir: 246 | :param classes: 247 | :param labels: 248 | :return: 249 | """ 250 | images_ = images.cpu().detach().permute((0, 2, 3, 1)) * std + mean 251 | for i, image in enumerate(images_): 252 | plt.imsave(os.path.join(outdir, f'{index+i+1}_image_{classes[labels[i].item()]}.jpg'), image.numpy()) 253 | 254 | 255 | def load_image(path): 256 | """ 257 | loades an image given a path 258 | 259 | :param path: 260 | :return: 261 | """ 262 | mask = Image.open(path).convert('RGB') 263 | mask = np.array(mask, dtype=np.float32) 264 | return mask / 255 265 | 266 | 267 | def auc(array): 268 | """ 269 | calculates area under the curve (AUC) 270 | 271 | :param array: 272 | 
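Implemented as (sum(array) - array[0]/2 - array[-1]/2) / len(array), i.e. the mean of the curve values with the two endpoints half-weighted; e.g. auc([0.0, 0.5, 1.0]) = (1.5 - 0.0 - 0.5) / 3 = 1/3 (illustrative input).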
:return: 273 | """ 274 | return (sum(array) - array[0]/2 - array[-1]/2)/len(array) 275 | 276 | 277 | -------------------------------------------------------------------------------- /src/evaluation/IGOS_pp/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | main file to call the explanations methods and run experiments, given a pre-trained 3 | model and a data loader. 4 | © copyright Tyler Lawson, Saeed khorram. https://github.com/saeed-khorram/IGOS 5 | """ 6 | import torch 7 | import os 8 | import sys 9 | 10 | sys.path.insert(0, os.path.abspath("../..")) 11 | 12 | from pathlib import Path 13 | from torchray.utils import get_device 14 | from tqdm import tqdm 15 | from timeit import default_timer 16 | 17 | from args import init_args 18 | from igos_utils import * 19 | from methods_helper import * 20 | from methods import iGOS_pp 21 | 22 | 23 | from data.dataloader import VOCDataModule, COCODataModule 24 | from models.classifier import VGG16ClassifierModel, Resnet50ClassifierModel 25 | from utils.helper import * 26 | from utils.image_utils import save_mask 27 | 28 | ##################### Change to your settings ################################ 29 | dataset = "VOC" # one of: ["VOC", "COCO"] 30 | data_base_path = "../../../datasets/" 31 | classifier_type = "vgg16" # one of: ["vgg16", "resnet50"] 32 | classifier_checkpoint = "../../checkpoints/pretrained_classifiers/vgg16_voc.ckpt" 33 | VOC_segmentations_path = '../../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 34 | COCO_segmentations_path = '../coco_segmentations/' 35 | ############################################################################## 36 | 37 | def gen_explanations(model, dataloader, args): 38 | 39 | device = get_device() 40 | 41 | model.eval() 42 | 43 | method = iGOS_pp 44 | 45 | save_path = Path('../masks/{}_{}_{}/'.format(dataset, classifier_type, "igos_pp")) 46 | if not os.path.isdir(save_path): 47 | os.makedirs(save_path) 48 | 49 | eprint(f'Size is {args.size}x{args.size}') 50 | 51 | total_time = 0.0 52 | for data in tqdm(dataloader): 53 | 54 | image, annotations = data 55 | 56 | filename = get_filename_from_annotations(annotations, dataset=dataset) 57 | if dataset == "VOC": 58 | num_classes = 20 59 | segmentation_filename = Path(VOC_segmentations_path) / (os.path.splitext(filename)[0] + '.png') 60 | data_path = Path(data_base_path) / 'VOC2007' / 'VOCdevkit' / 'VOC2007' / 'JPEGImages' 61 | elif dataset == "COCO": 62 | num_classes = 91 63 | segmentation_filename = Path(COCO_segmentations_path) / (os.path.splitext(filename)[0] + '.png') 64 | data_path = Path(data_base_path) / 'COCO2014' / 'val2014' 65 | if not os.path.exists(segmentation_filename): 66 | continue 67 | 68 | img_path = data_path / filename 69 | blurred_image = get_blurred_image(img_path) 70 | 71 | image = image.to(device) 72 | blurred_image = blurred_image.to(device).unsqueeze(0) 73 | 74 | targets = get_targets_from_annotations(annotations, dataset=dataset)[0] 75 | 76 | masks = torch.zeros(num_classes, 224, 224) 77 | 78 | start_time = default_timer() 79 | for i in range(num_classes): 80 | if targets[i] == 1.0: 81 | # generate masks 82 | label = torch.tensor([i], dtype=torch.int64, device=device) 83 | mask = method( 84 | model, 85 | images=image.detach(), 86 | baselines=blurred_image.detach(), 87 | labels=label, 88 | size=args.size, 89 | iterations=args.ig_iter, 90 | ig_iter=args.iterations, 91 | L1=args.L1, 92 | L2=args.L2, 93 | alpha=args.alpha, 94 | ) 95 | 96 | mask = mask.view(-1, 1, args.size, 
args.size) 97 | up = torch.nn.UpsamplingBilinear2d(size=(224, 224)).to(device) 98 | 99 | u_mask = up(mask) 100 | u_mask = u_mask.permute((0, 2, 3, 1)) 101 | 102 | # Normalize the mask 103 | u_mask = (u_mask - torch.min(u_mask)) / (torch.max(u_mask) - torch.min(u_mask)) 104 | 105 | masks[i] = torch.ones(size=(224, 224), device=device) - u_mask.squeeze() 106 | 107 | aggregated_mask = masks.amax(dim=0) 108 | end_time = default_timer() 109 | total_time += end_time - start_time 110 | 111 | save_mask(aggregated_mask, save_path / filename) 112 | 113 | print("Total time for masking process of iGOS++ with dataset {} and classifier {}: {} seconds".format(dataset, classifier_type, total_time)) 114 | model.train() 115 | 116 | 117 | if __name__ == "__main__": 118 | 119 | args = init_args() 120 | eprint(f"args:\n {args}") 121 | 122 | torch.manual_seed(args.manual_seed) 123 | 124 | init(args.input_size) 125 | init_sns() 126 | 127 | # Set up data module 128 | if dataset == "VOC": 129 | num_classes = 20 130 | data_path = Path(data_base_path) / "VOC2007" 131 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 132 | elif dataset == "COCO": 133 | num_classes = 91 134 | data_path = Path(data_base_path) / "COCO2014" 135 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 136 | else: 137 | raise Exception("Unknown dataset " + dataset) 138 | 139 | data_module.setup(stage = "test") 140 | 141 | 142 | eprint("Loading the model...") 143 | 144 | if classifier_type == "vgg16": 145 | classifier = VGG16ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 146 | elif classifier_type == "resnet50": 147 | classifier = Resnet50ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 148 | else: 149 | raise Exception("Unknown classifier type: " + classifier_type) 150 | 151 | for param in classifier.parameters(): 152 | param.requires_grad_(False) 153 | 154 | device = get_device() 155 | classifier.to(device) 156 | classifier.eval() 157 | 158 | eprint(f"Model({classifier_type}) successfully loaded!\n") 159 | 160 | 161 | gen_explanations(classifier, data_module.test_dataloader(), args) -------------------------------------------------------------------------------- /src/evaluation/IGOS_pp/methods_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper function for the IGOS explanation methods. 3 | © copyright Tyler Lawson, Saeed khorram. https://github.com/saeed-khorram/IGOS 4 | """ 5 | 6 | import torch 7 | 8 | 9 | # Initializes the upsampling function for the upscale function 10 | def init(out_size): 11 | """ 12 | Initializes output size for the upsampler. 13 | :param out_size: 14 | :return: 15 | """ 16 | upscale.out_size = out_size 17 | upscale.up = torch.nn.UpsamplingBilinear2d(size=(out_size, out_size)).cuda() 18 | 19 | 20 | def tv_norm(image, beta=2): 21 | """ 22 | Calculates the total variation. 23 | :param image: 24 | :param beta: 25 | :return: 26 | """ 27 | image = image[:, 0, :, :] 28 | a = torch.mean(torch.abs((image[:, :-1, :] - image[:, 1:, :]).view(image.shape[0], -1)).pow(beta), dim=1) 29 | b = torch.mean(torch.abs((image[:, :, :-1] - image[:, :, 1:]).view(image.shape[0], -1)).pow(beta), dim=1) 30 | return a + b 31 | 32 | 33 | def bilateral_tv_norm(image, mask, tv_beta=2, sigma=1): 34 | """ 35 | Calculates the bilateral total variation. 
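Combines the plain TV of the low-resolution mask with an image-aware term in which neighbouring mask differences are weighted by exp(-(image difference)^2 / sigma), so the mask is penalised less for changing where the image itself has strong edges.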
36 | 37 | :param image: 38 | :param mask: 39 | :param tv_beta: 40 | :param sigma: 41 | :return: 42 | """ 43 | # tv term 44 | mask_ = mask[:, 0, :] 45 | a = torch.mean(torch.abs((mask_[:, :-1, :] - mask_[:, 1:, :]).view(mask.shape[0], -1)).pow(tv_beta), dim=1) 46 | b = torch.mean(torch.abs((mask_[:, :, :-1] - mask_[:, :, 1:]).view(mask.shape[0], -1)).pow(tv_beta), dim=1) 47 | # bilateral tv in the image space 48 | up_mask_ = upscale(mask) 49 | 50 | bil_a = torch.mean(torch.exp(-(image[:, :, :-1, :] - image[:, :, 1:, :]).mean(dim=1) ** 2 / sigma).view(mask.shape[0], -1) 51 | * torch.abs((up_mask_[:, :, :-1, :] - up_mask_[:, :, 1:, :]).view(up_mask_.shape[0], -1)).pow(tv_beta), dim=1) 52 | bil_b = torch.mean(torch.exp(-(image[:, :, :, :-1] - image[:, :, :, 1:]) ** 2 / sigma).mean(dim=1).view(mask.shape[0], -1) 53 | * torch.abs((up_mask_[:, :, :, :-1] - up_mask_[:, :, :, 1:]).view(up_mask_.shape[0], -1)).pow(tv_beta), dim=1) 54 | return 0.5 * (a + b + bil_a + bil_b) 55 | 56 | 57 | def upscale(masks): 58 | """ 59 | Upsamples masks and expands it to the same dimensions as the image 60 | 61 | values are set using the init() function 62 | :param masks: 63 | :return: 64 | """ 65 | return upscale.up(masks).expand((-1,1,upscale.out_size,upscale.out_size)) # TODO 66 | 67 | 68 | def interval_score(model, images, baselines, labels, up_masks, num_iter, output_func, noise=True): 69 | """ 70 | Computes the score of masked images in a straight line 71 | path from baselines to masked images, with num_iter intervals. 72 | 73 | :param model: 74 | :param images: 75 | :param baselines: 76 | :param labels: 77 | :param up_masks: 78 | :param num_iter: 79 | :param noise: 80 | :param output_func: 81 | :return: 82 | """ 83 | # The intervals to approximate the integral over 84 | intervals = torch.linspace(1/num_iter, 1, num_iter, requires_grad=False).cuda().view(-1, 1, 1, 1) 85 | interval_masks = up_masks.unsqueeze(1) * intervals 86 | local_images = phi(images.unsqueeze(1), baselines.unsqueeze(1), interval_masks) 87 | 88 | if noise: 89 | local_images = local_images + torch.randn_like(local_images) * .2 90 | 91 | # Shape of image tensor when viewed in batch form 92 | new_shape = torch.tensor(images.shape) * torch.tensor(intervals.shape) 93 | losses = output_func(local_images.view(*new_shape), model).view(images.shape[0], num_iter, -1) 94 | losses = torch.gather(losses, 2, labels.view(-1, 1).expand(-1, num_iter).view(-1, num_iter, 1)) 95 | return losses / num_iter 96 | 97 | 98 | def integrated_gradient(model, images, baselines, labels, up_masks, num_iter, output_func=None, noise=True): 99 | """ 100 | Calculates and backprops the integrated gradient. 101 | Does not have the original mask, so does not return the gradient 102 | 103 | :param model: 104 | :param images: 105 | :param baselines: 106 | :param labels: 107 | :param up_masks: 108 | :param num_iter: 109 | :param noise: 110 | :param output_func: 111 | :return: 112 | """ 113 | for i in range(images.shape[0]): 114 | loss = interval_score( 115 | model, 116 | images[i].unsqueeze(0), 117 | baselines[i].unsqueeze(0), 118 | labels[i].unsqueeze(0), 119 | up_masks[i].unsqueeze(0), 120 | num_iter, 121 | output_func, 122 | noise, 123 | ) 124 | loss.sum().backward(retain_graph=True) 125 | 126 | 127 | def line_search(masks, total_grads, loss_func, alpha=8, beta=0.0001, decay=0.2,): 128 | """ 129 | Computes a line search in batch. 
Works by starting far in the direction of total_grads and works 130 | backward until all meet the target condition or their corresponding alpha value is below some value. 131 | Uses loss_func for the target condition. 132 | 133 | :param masks: 134 | :param total_grads: 135 | :param loss_func: 136 | :param alpha: 137 | :param beta: 138 | :param decay: 139 | :return: 140 | """ 141 | # Speed up computations, reduce memory usage, and ensure no autograd 142 | # graphs are created 143 | with torch.no_grad(): 144 | i = 0 145 | mod = len(masks.shape) - 3 146 | num_inputs = masks.shape[0] 147 | # The indices of masks that still need their alphas updated 148 | indices = torch.ones(num_inputs, dtype=torch.bool).cuda() 149 | # Create initial alpha values for each mask 150 | alphas = torch.ones(num_inputs).cuda() * alpha 151 | 152 | up_masks = upscale(masks.view(-1,*masks.shape[mod:])).view(-1, *masks.shape[1:mod], 1, upscale.out_size, upscale.out_size) 153 | 154 | # Compute the base loss used in the condition 155 | base_losses = loss_func(up_masks, masks, indices).view(-1) 156 | t = -beta * (total_grads ** 2).view(num_inputs, -1).sum(dim=1).view(num_inputs) 157 | 158 | while True: 159 | # Create a new mask with the updated alpha value to 160 | # see if it meets condition 161 | new_masks = torch.clamp(masks[indices] - alphas[indices].view(-1,*(1,) * mod,1,1) * total_grads[indices], 0, 1) 162 | up_masks = upscale(new_masks.view(-1,*masks.shape[mod:])).view(-1,*masks.shape[1:mod], 1, upscale.out_size, upscale.out_size) 163 | # Calculate new losses 164 | losses = loss_func(up_masks, new_masks, indices).view(-1) 165 | # Get indices for each alpha that meets the condition for 166 | # their corresponding mask 167 | indices[indices.clone()] = losses > base_losses[indices] + alphas[indices] * t[indices] 168 | # Same for this, but for if the alpha values are too low (\alpha_l) 169 | indices[indices.clone()] = indices[indices] * (alphas[indices] >= 0.00001) 170 | # Break out of the loop if all alpha values satisfy the condition 171 | # or are too low 172 | if not indices.sum(): 173 | break 174 | # Otherwise update alphas 175 | alphas[indices] *= decay 176 | i += 1 177 | return alphas.view(-1,1,1,1) 178 | 179 | 180 | def phi(img, baseline, mask): 181 | """ 182 | Composes an image from img and baseline according to the mask values. 183 | 184 | :param img: 185 | :param baseline: 186 | :param mask: 187 | :return: 188 | """ 189 | return img.mul(mask) + baseline.mul(1-mask) 190 | 191 | 192 | def softmax_output(inputs, model): 193 | """ 194 | Applies softamx over the output of the model. 195 | 196 | :param inputs: 197 | :param model: 198 | :return: 199 | """ 200 | return torch.nn.Softmax(dim=1)(model(inputs)) 201 | 202 | 203 | def logit_output(inputs, model): 204 | """ 205 | Simply returns the output of the model, given an input. 206 | 207 | :param inputs: 208 | :param model: 209 | :return: 210 | """ 211 | return model(inputs) 212 | 213 | 214 | def metric(image, baseline, mask, model, labels, step=5, size=28,): 215 | """ 216 | Calculates the deletion/insertion scores/curves given the image and generated masks. 
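Pixels are ranked by the mask values and swapped in blocks of `step` between the image and the baseline (typically a blurred copy): the deletion curve records the softmax score of the target class as image content is replaced by the baseline, the insertion curve as it is revealed over the baseline, and the accumulated scores are returned together with both curves.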
217 | 218 | :param image: 219 | :param baseline: 220 | :param mask: 221 | :param model: 222 | :param labels: 223 | :param step: 224 | :param size: 225 | :return: 226 | """ 227 | with torch.no_grad(): 228 | # The dimensions for the image 229 | out_size = image.shape[-1] 230 | # Compute the total number of pixels in a mask 231 | num_pixels = torch.prod(torch.tensor(mask.shape[1:])).item() 232 | # Used for indexing with batch sizes 233 | l = torch.arange(image.shape[0]) 234 | # The unmasked score 235 | og_scores = torch.nn.Softmax(dim=1)(model(image))[l,labels] 236 | # The baseline score 237 | blur_scores = torch.nn.Softmax(dim=1)(model(baseline))[l,labels] 238 | # Initial values for the curves 239 | del_curve = [og_scores] 240 | ins_curve = [blur_scores] 241 | index = [0.] 242 | 243 | up = torch.nn.UpsamplingBilinear2d(size=(out_size,out_size)).cuda() 244 | # True_mask is used to hold 1 or 0. Either show that pixel or blur it. 245 | true_mask = torch.ones((mask.shape[0], num_pixels)).cuda() 246 | del_scores = torch.zeros(mask.shape[0]).cuda() 247 | ins_scores = torch.zeros(mask.shape[0]).cuda() 248 | # Sort each mask by values and store the indices. 249 | elements = torch.argsort(mask.view(mask.shape[0], -1), dim=1) 250 | 251 | for pixels in range(0, num_pixels, step): 252 | 253 | # Get the indices used in this iteration 254 | indices = elements[l,pixels:pixels+step].squeeze().view(image.shape[0], -1) 255 | # Set those indices to 0 256 | true_mask[l, indices.permute(1,0)] = 0 257 | up_mask = upscale(true_mask.view(-1, 1, size,size)) 258 | # Mask the images for deletion 259 | del_image = phi(image, baseline, up_mask) 260 | # Calculate new scores 261 | outputs = torch.nn.Softmax(dim=1)(model(del_image))[l,labels] 262 | del_curve.append(outputs) 263 | index.append((pixels+step)/num_pixels) 264 | del_scores += outputs * step if pixels + step < num_pixels else\ 265 | num_pixels - pixels 266 | 267 | # Mask the images for insertion 268 | ins_image = phi(baseline, image, up_mask) 269 | 270 | # Calculate the new scores 271 | outputs = torch.nn.Softmax(dim=1)(model(ins_image))[l,labels] 272 | 273 | ins_curve.append(outputs) 274 | ins_scores += outputs * step if pixels + step < num_pixels else\ 275 | num_pixels - pixels 276 | 277 | # Force scores between 0 and 1. 278 | del_scores /= size*size 279 | ins_scores /= size*size 280 | 281 | del_curve = list(map(lambda x: [y.item() for y in x], zip(*del_curve))) 282 | ins_curve = list(map(lambda x: [y.item() for y in x], zip(*ins_curve))) 283 | 284 | return del_scores, ins_scores, del_curve, ins_curve, index 285 | -------------------------------------------------------------------------------- /src/evaluation/eval_utils/assessment_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.metrics import roc_auc_score 4 | 5 | #### PLAIN METRICS TO BENCHMARK 6 | def prob_sparsity(pvec): 7 | """Sparsity measure. 8 | 9 | For pvec of the masked image, we want this to be low. 10 | For pvec of the inverse masked image, we want this to be high. 11 | """ 12 | return np.sum(pvec**2) 13 | 14 | 15 | def prob_entropy(pvec): 16 | """Sparsity measure. 17 | 18 | For pvec of the masked image, we want this to be low. 19 | For pvec of the inverse masked image, we want this to be high. 20 | """ 21 | return -np.sum(pvec * np.log(np.maximum(pvec, 1e-15))) 22 | 23 | def saliency(pvec, c, mask): 24 | """ 25 | Continuous saliency measure. 
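Computed as log(a) - log(p_class), where a is the mean mask value (clipped below at 0.05) and p_class is the probability of the target class (summed over classes if several are given) for the masked image; e.g. a mask covering 10% of the image that keeps the class probability at 0.8 scores log(0.1) - log(0.8) ≈ -2.08 (illustrative numbers, lower is better).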
26 | 27 | Adaptation from "Real Time Image Saliency for Black Box Classifiers", 28 | Piotr Dabkowski and Yarin Gal. 29 | 30 | For pvec of the masked image, the lower the better for the masked image. 31 | 32 | This measure does not make sense for the inverse masked image. 33 | """ 34 | a = np.maximum(np.mean(mask), 0.05) 35 | if isinstance(c, int): 36 | pclass = pvec[c] 37 | else: 38 | pclass = 0 39 | for e in c: 40 | pclass += pvec[e] 41 | return np.log(a) - np.log(pclass) 42 | 43 | def continuous_IOU(mask, seg): 44 | ### this is no longer the IoU but 1 - the Soergel distance (which is 1 - this ratio below) 45 | #intersection = np.sum(mask * seg) 46 | #union = np.sum(mask + seg)/2 47 | #union = np.sum(mask + seg) - intersection 48 | intersection = np.sum(np.minimum(mask, seg)) 49 | union = np.sum(np.maximum(mask, seg)) 50 | IOU = intersection/(union + 1e-15) 51 | 52 | return IOU 53 | 54 | def discrete_IOU(mask, seg, thresh=0.5): 55 | 56 | """ Binarize the mask at 'thresh'. Then compute intersection (AND) and union (OR) and calculate intersection/union. """ 57 | 58 | mask = mask>thresh 59 | seg = seg>thresh 60 | intersection = np.sum(np.logical_and(mask, seg)*1.) 61 | union = np.sum(np.logical_or(mask, seg)*1.) 62 | IOU = intersection/(union + 1e-15) 63 | return IOU 64 | 65 | def IOUs(mask, seg): 66 | return [discrete_IOU(mask, seg, thresh) for thresh in np.arange(0.1, 1, 0.1)] 67 | 68 | def soft_accuracy(mask, seg): 69 | """ Calculate continuous accuracy (without binarizing before). """ 70 | 71 | tp = np.sum(mask*seg) 72 | tn = np.sum((1-mask)*(1-seg)) 73 | fp = np.sum(mask*(1-seg)) 74 | fn = np.sum((1-mask) * seg) 75 | 76 | accuracy = (tp+tn) / (tp+tn+fp+fn) 77 | return accuracy 78 | 79 | def soft_f1(mask, seg): 80 | 81 | """ Calculate continuous f1-score (without binarizing before). 
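Soft counts are used directly: tp = sum(mask*seg), fp = sum(mask*(1-seg)), fn = sum((1-mask)*seg), with 1e-10 smoothing terms to avoid division by zero.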
""" 82 | 83 | # F1 = 2 * (precision * recall) / (precision + recall) 84 | tp = np.sum(mask*seg) 85 | fp = np.sum(mask*(1-seg)) 86 | fn = np.sum((1-mask) * seg) 87 | precision = tp/(tp+fp + 1e-10) 88 | recall = tp/(tp+fn + 1e-10) 89 | f1 = 2 * (precision * recall) / (precision + recall + 1e-10) 90 | return f1 91 | 92 | def f1(actual, predicted, label): 93 | 94 | """ A helper function to calculate f1-score for the given `label` """ 95 | 96 | # F1 = 2 * (precision * recall) / (precision + recall) 97 | tp = np.sum((actual==label) & (predicted==label)) 98 | fp = np.sum((actual!=label) & (predicted==label)) 99 | fn = np.sum((predicted!=label) & (actual==label)) 100 | precision = tp/(tp+fp + 1e-10) 101 | recall = tp/(tp+fn + 1e-10) 102 | f1 = 2 * (precision * recall) / (precision + recall + 1e-10) 103 | return f1 104 | 105 | def f1_macro(actual, predicted): 106 | # `macro` f1- unweighted mean of f1 per label 107 | # return np.mean([f1(actual, predicted, label) for label in [True, False]]) 108 | return np.mean([f1(actual, predicted, label) for label in [True]]) 109 | 110 | def discrete_f1(mask, seg, thresh=0.5): 111 | 112 | """ Binarize the mask at 'tresh', then calculate f1-score for both labels (True, False) and take mean.""" 113 | 114 | mask = mask>thresh 115 | seg = seg>thresh 116 | return f1_macro(seg, mask) 117 | 118 | def f1s(mask, seg): 119 | 120 | """ Calculate the discrete f1-score for each treshold from 0.1 to 1 in 0.1 steps.""" 121 | 122 | return [discrete_f1(mask, seg, thresh) for thresh in np.arange(0.1, 1, 0.1)] 123 | 124 | def auc(mask, seg): 125 | 126 | "Compute area under ROC " 127 | 128 | try: 129 | return roc_auc_score((seg.flatten()>0)*1., mask.flatten()) 130 | except: 131 | return None 132 | 133 | def sim_ratio(mask, seg): 134 | scp = np.sum(mask*seg) 135 | sr = scp / (scp + np.sum((mask - seg)**2) + 1e-10) 136 | 137 | return sr 138 | 139 | 140 | 141 | def mask_coverage(mask, seg_mask): 142 | 143 | """Compute the true positive rate (proportion of correct classified foreground compared to foreground).""" 144 | 145 | seg_area = np.average(seg_mask) 146 | 147 | if seg_area > 0.0: 148 | seg_diff = np.clip(seg_mask - mask, a_min=0.0, a_max=None) 149 | return 1 - (np.average(seg_diff) * (1 / seg_area)) 150 | else: 151 | return 0 152 | 153 | def background_coverage(mask, seg_mask): 154 | 155 | """Compute the false positive rate (proportion of false classified foreground compared to background).""" 156 | 157 | seg_area = np.average(seg_mask) 158 | 159 | if seg_area < 1.0: 160 | non_seg_diff = np.clip(mask - seg_mask, a_min=0.0, a_max=None) 161 | return (np.average(non_seg_diff) * (1 / (1-seg_area))) 162 | else: 163 | return 0 164 | 165 | def overlap(mask, seg_mask): 166 | 167 | "Return proportion where mask and groundtruth match. 
" 168 | 169 | return 1 - np.average(np.absolute(seg_mask - mask)) 170 | -------------------------------------------------------------------------------- /src/evaluation/eval_utils/compute_masks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | import torch 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path 11 | from torchray.utils import get_device 12 | from tqdm import tqdm 13 | from timeit import default_timer 14 | 15 | from torchray.attribution.guided_backprop import guided_backprop 16 | from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward, Perturbation 17 | 18 | from utils.helper import get_target_dictionary 19 | 20 | def merge_mask(masks): 21 | if len(masks) > 1: 22 | return torch.max(masks, dim = 0, keepdims=True)[0] 23 | return masks 24 | 25 | def rescale_mask(mask): 26 | min_val = mask.min() 27 | max_val = mask.max() 28 | 29 | mask = (mask - min_val) / (max_val - min_val) 30 | mask = mask.clamp(0, 1) 31 | return mask 32 | 33 | def compute_mask(x, model, category_id, method="extremal_perturbations"): 34 | masks = [] 35 | for c in category_id: 36 | if method=="guided_backprop": 37 | m = guided_backprop(model, x,c) 38 | 39 | elif method=="extremal_perturbations": 40 | m = vedaldi2019(model, x, c) 41 | else: 42 | raise ValueError("Unknown method") 43 | masks.append(m) 44 | mask = merge_mask(torch.cat(masks, dim=0)) 45 | mask = rescale_mask(mask) 46 | return mask 47 | 48 | def vedaldi2019(model, x, c): 49 | areas = [0.05, 0.1, 0.2, 0.3] 50 | num_levels = 8 51 | mask, energy = extremal_perturbation( 52 | model, x, c, 53 | areas=areas, 54 | debug=False, 55 | jitter=True, 56 | num_levels=num_levels, 57 | step=7, 58 | sigma=7 * 3, 59 | max_iter=800, 60 | smooth=0.09, 61 | reward_func=contrastive_reward, 62 | perturbation='blur' 63 | ) 64 | 65 | perturbation = Perturbation(x, num_levels=num_levels, type='blur') 66 | x_perturbed = perturbation.apply(mask=mask) 67 | 68 | # saliency = mask.mean(dim=0, keepdim=True) 69 | 70 | logit_perturbed = model.forward(x_perturbed)[:, c] 71 | logit = model.forward(x)[:, c] 72 | vec = logit_perturbed < logit 73 | i = 0 74 | while vec[i] and i<(len(vec)-1): 75 | i += 1 76 | saliency = mask[i:i+1] 77 | return saliency 78 | 79 | def save_mask(mask, p): 80 | path_file = str(p)[:-4] 81 | img = mask.detach().cpu().numpy().squeeze() 82 | plt.imsave(path_file + ".png", img, cmap='gray',format="png") 83 | np.savez_compressed(path_file+ ".npz", img) 84 | 85 | def compute_and_save_masks(model, data_module, path_segmentation, path_masks, method="guided_backprop"): 86 | # Run on GPU if available. 
87 | device = get_device() 88 | model.to(device) 89 | 90 | path_masks.mkdir(parents=True, exist_ok=True) 91 | count = 0 92 | total_time = 0.0 93 | for s in tqdm(data_module.test_dataloader()): 94 | img, meta = s 95 | x = img 96 | x = x.to(device) 97 | assert(len(x)==1) 98 | try: 99 | category_id = meta[0]["targets"] 100 | filename = meta[0]["filename"] 101 | filename = Path(os.path.splitext(filename)[0] + '.png') 102 | except: 103 | filename = Path(meta[0]['annotation']["filename"][:-4]+".png") 104 | target_dict = get_target_dictionary(include_background_class=False) 105 | objects = meta[0]['annotation']['object'] 106 | category_id = [target_dict[e["name"]] for e in objects] 107 | segmentation_filename = path_segmentation / filename 108 | if not os.path.exists(segmentation_filename): 109 | continue 110 | 111 | count += 1 112 | 113 | start_time = default_timer() 114 | mask = compute_mask(x, model, category_id, method=method) 115 | end_time = default_timer() 116 | total_time += end_time - start_time 117 | 118 | save_mask(mask, path_masks / filename) 119 | 120 | path_split = str(path_masks).split('/')[1].split('_') 121 | print("Total time for masking process of {} with dataset {} and classifier {}: {} seconds".format(method, path_split[0], path_split[1], total_time)) 122 | return count -------------------------------------------------------------------------------- /src/evaluation/eval_utils/compute_scores.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from pathlib import Path 5 | 6 | from evaluation.eval_utils.assessment_metrics import saliency, continuous_IOU, discrete_IOU, IOUs, discrete_f1, soft_accuracy, soft_f1 7 | from evaluation.eval_utils.assessment_metrics import mask_coverage, background_coverage, overlap, sim_ratio, f1s, auc 8 | from data.dataloader import VOCDataModule, COCODataModule 9 | from models.classifier import VGG16ClassifierModel, Resnet50ClassifierModel 10 | from utils.helper import get_target_dictionary 11 | 12 | from torchray.utils import get_device 13 | import torch 14 | 15 | from PIL import Image 16 | import torchvision.transforms as transforms 17 | 18 | from tqdm import tqdm 19 | 20 | def get_model_and_data(data_path, dataset_name, model_name, model_path): 21 | if dataset_name == "VOC": 22 | data_module = VOCDataModule(data_path, test_batch_size=1) 23 | if model_name == "vgg16": 24 | model = VGG16ClassifierModel.load_from_checkpoint(model_path, num_classes=20, dataset=dataset_name) 25 | elif model_name == "resnet50": 26 | model = Resnet50ClassifierModel.load_from_checkpoint(model_path, num_classes=20, dataset=dataset_name) 27 | elif dataset_name == "COCO": 28 | data_module = COCODataModule(data_path, test_batch_size=1) 29 | if model_name == "vgg16": 30 | model = VGG16ClassifierModel.load_from_checkpoint(model_path, num_classes=91, dataset=dataset_name) 31 | elif model_name == "resnet50": 32 | model = Resnet50ClassifierModel.load_from_checkpoint(model_path, num_classes=91, dataset=dataset_name) 33 | 34 | data_module.setup(stage="test") 35 | 36 | return model, data_module 37 | 38 | def segmented_generator(data_module, segmentations_path): 39 | """Generator that return all the segmented images""" 40 | for s in tqdm(data_module.test_dataloader()): 41 | img, meta = s 42 | x = img 43 | assert(len(x)==1) 44 | try: 45 | category_id = meta[0]["targets"] 46 | filename = Path(meta[0]["filename"][:-4]+".png") 47 | except: 48 | filename = Path(meta[0]['annotation']["filename"][:-4]+".png") 49 | 
target_dict = get_target_dictionary(include_background_class=False) 50 | objects = meta[0]['annotation']['object'] 51 | category_id = [target_dict[e["name"]] for e in objects] 52 | segmentation_filename = segmentations_path / filename 53 | if not os.path.exists(segmentation_filename): 54 | continue 55 | else: 56 | yield x, category_id, filename 57 | 58 | def open_segmentation_mask(segmentation_filename, dataset_name): 59 | transformer = transforms.Compose([transforms.Resize((224, 224))]) 60 | mask = Image.open(segmentation_filename).convert('L') 61 | mask = transformer(mask) 62 | mask = np.array(mask) / 255.0 63 | if dataset_name == "VOC": 64 | mask[mask > 0] = 1 65 | return mask 66 | 67 | def get_path_mask(masks_path, dataset_name, model_name, method): 68 | return masks_path / Path('{}_{}_{}/'.format(dataset_name, model_name, method)) 69 | 70 | def gen_evaluation(data_path, masks_path, segmentations_path, dataset_name, model_name, model_path, method, compute_p=True): 71 | # Load the model and data 72 | model, data_module = get_model_and_data(data_path, dataset_name, model_name, model_path) 73 | # Path of the masks 74 | if method in ["0.5", "0", "1", "perfect"]: 75 | masks_path_method = None 76 | else: 77 | masks_path_method = get_path_mask(masks_path, dataset_name, model_name, method) 78 | 79 | if compute_p: 80 | device = get_device() 81 | model = model.to(device) 82 | 83 | for x, category_id, filename in segmented_generator(data_module, segmentations_path): 84 | seg_mask = open_segmentation_mask(segmentations_path / filename, dataset_name) 85 | if method in ["0.5", "0", "1", "perfect"]: 86 | if method=="0": 87 | mask = np.zeros(seg_mask.shape, dtype=np.float32) 88 | elif method=="0.5": 89 | mask = 0.5*np.ones(seg_mask.shape, dtype=np.float32) 90 | elif method=="1": 91 | mask = np.ones(seg_mask.shape, dtype=np.float32) 92 | elif method=="perfect": 93 | mask = seg_mask.copy().astype(np.float32) 94 | else: 95 | raise ValueError("Something went wrong!") 96 | 97 | else: 98 | try: 99 | npz_name = Path(str(filename)[:-4] + ".npz") 100 | mask = np.load(masks_path_method / npz_name, dataset_name)["arr_0"] 101 | except: 102 | continue 103 | if np.sum(np.isnan(mask)): 104 | mask = np.zeros(shape=mask.shape, dtype=np.float32) 105 | if compute_p: 106 | x = x.to(device) 107 | logits = model.forward(x) 108 | p = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy().squeeze() 109 | x_masked = torch.tensor(np.reshape(mask, [1,1, *mask.shape])).to(device) * x 110 | logits_mask = model.forward(x_masked) 111 | p_mask = torch.nn.functional.softmax(logits_mask, dim=1).detach().cpu().numpy().squeeze() 112 | x_background = torch.tensor(np.reshape(1-mask, [1,1, *mask.shape])).to(device) * x 113 | logits_background = model.forward(x_background) 114 | p_background = torch.nn.functional.softmax(logits_background, dim=1).detach().cpu().numpy().squeeze() 115 | else: 116 | p = None 117 | p_mask = None 118 | p_background = None 119 | yield mask, seg_mask, p, p_mask, p_background, category_id, x.detach().cpu().numpy().squeeze() 120 | 121 | 122 | def compute_numbers(data_path, masks_path, segmentations_path, dataset_name, model_name, model_path, method, compute_p=True): 123 | # sparsity = [] 124 | # sparsity_masked = [] 125 | # sparsity_background = [] 126 | 127 | # entropy = [] 128 | # entropy_masked = [] 129 | # entropy_background = [] 130 | 131 | d_f1_25 = [] 132 | d_f1_50 = [] 133 | d_f1_75 = [] 134 | c_f1 = [] 135 | a_f1s = [] 136 | acc = [] 137 | aucs = [] 138 | 139 | d_IOUs = [] 140 | c_IOU = [] 141 | 
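# saliency, overlap, background/mask coverage and similarity-ratio accumulators; all per-image lists are returned by compute_numbers below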
142 | sal = [] 143 | over = [] 144 | background_c = [] 145 | mask_c = [] 146 | sr = [] 147 | 148 | 149 | 150 | 151 | for mask, seg_mask, p, p_mask, p_background, category_id, x in gen_evaluation(data_path, masks_path, 152 | segmentations_path, dataset_name, 153 | model_name, model_path, 154 | method, compute_p=compute_p): 155 | 156 | # sparsity.append(prob_sparsity(p)) 157 | # sparsity_masked.append(prob_sparsity(p_mask)) 158 | # sparsity_background.append(prob_sparsity(p_background)) 159 | 160 | # entropy.append(prob_entropy(p)) 161 | # entropy_masked.append(prob_entropy(p_mask)) 162 | # entropy_background.append(prob_entropy(p_background)) 163 | d_f1_25.append(discrete_f1(mask, seg_mask, 0.25)) 164 | d_f1_50.append(discrete_f1(mask, seg_mask, 0.50)) 165 | d_f1_75.append(discrete_f1(mask, seg_mask, 0.75)) 166 | c_f1.append(soft_f1(mask, seg_mask)) 167 | a_f1s.append(f1s(mask, seg_mask)) 168 | acc.append(soft_accuracy(mask, seg_mask)) 169 | aucs.append(auc(mask, seg_mask)) 170 | 171 | d_IOUs.append(IOUs(mask, seg_mask)) 172 | c_IOU.append(continuous_IOU(mask, seg_mask)) 173 | 174 | sal.append(saliency(p_mask, category_id, mask)) 175 | 176 | over.append(overlap(mask, seg_mask)) 177 | background_c.append(background_coverage(mask, seg_mask)) 178 | mask_c.append(mask_coverage(mask, seg_mask)) 179 | 180 | sr.append(sim_ratio(mask, seg_mask)) 181 | 182 | return d_f1_25,d_f1_50,d_f1_75,c_f1,a_f1s, acc, aucs, d_IOUs, c_IOU, sal, over, background_c, mask_c, sr 183 | -------------------------------------------------------------------------------- /src/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | import numpy as np 7 | from pathlib import Path 8 | 9 | from eval_utils.compute_scores import compute_numbers 10 | 11 | ############################################## Change to your settings ########################################################## 12 | masks_path = Path("./masks/") 13 | data_base_path = Path("../../datasets/") 14 | VOC_segmentations_path = Path("../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/") 15 | COCO_segmentations_path = Path("./coco_segmentations/") 16 | 17 | datasets = ["VOC", "COCO"] 18 | classifiers = ["vgg16", "resnet50"] 19 | vgg16_voc_checkpoint = "../checkpoints/pretrained_classifiers/vgg16_voc.ckpt" 20 | vgg16_coco_checkpoint = "../checkpoints/pretrained_classifiers/vgg16_coco.ckpt" 21 | resnet50_voc_checkpoint = "../checkpoints/pretrained_classifiers/resnet50_voc.ckpt" 22 | resnet50_coco_checkpoint = "../checkpoints/pretrained_classifiers/resnet50_coco.ckpt" 23 | 24 | methods = ["explainer", "extremal_perturbations", "grad_cam", "rise", 25 | "rt_saliency", "igos_pp", "0.5", "0", "1", "perfect"] 26 | ################################################################################################################################# 27 | 28 | try: 29 | results = np.load("results.npz", allow_pickle=True)["results"].item() 30 | except: 31 | results = {} 32 | 33 | for dataset in datasets: 34 | if not(dataset in results): 35 | results[dataset] = {} 36 | for classifier in classifiers: 37 | if not(classifier in results[dataset]): 38 | results[dataset][classifier] = {} 39 | for method in methods: 40 | if not(method in results[dataset][classifier]): 41 | results[dataset][classifier][method] = {} 42 | try: 43 | if dataset == "VOC": 44 | data_path = data_base_path / "VOC2007" 45 | segmentations_path = VOC_segmentations_path 46 | if 
classifier == "vgg16": 47 | model_path = vgg16_voc_checkpoint 48 | elif classifier == "resnet50": 49 | model_path = resnet50_voc_checkpoint 50 | elif dataset == "COCO": 51 | data_path = data_base_path / "COCO2014" 52 | segmentations_path = COCO_segmentations_path 53 | if classifier == "vgg16": 54 | model_path = vgg16_coco_checkpoint 55 | elif classifier == "resnet50": 56 | model_path = resnet50_coco_checkpoint 57 | 58 | d_f1_25,d_f1_50,d_f1_75,c_f1,a_f1s, acc, aucs, d_IOUs, c_IOU, sal, over, background_c, mask_c, sr = compute_numbers(data_path=data_path, 59 | masks_path=masks_path, 60 | segmentations_path=segmentations_path, 61 | dataset_name=dataset, 62 | model_name=classifier, 63 | model_path=model_path, 64 | method=method) 65 | 66 | 67 | d = {} 68 | d["d_f1_25"] = d_f1_25 69 | d["d_f1_50"] = d_f1_50 70 | d["d_f1_75"] = d_f1_75 71 | d["d_f1"] = ((np.array(d_f1_25) + np.array(d_f1_50) + np.array(d_f1_75)) /3).tolist() 72 | d["c_f1"] = c_f1 73 | d["a_f1s"] = a_f1s 74 | d["acc"] = acc 75 | d["aucs"] = aucs 76 | d["d_IOUs"] = d_IOUs 77 | d["c_IOU"] = c_IOU 78 | d["sal"] = sal 79 | d["over"] = over 80 | d["background_c"] = background_c 81 | d["mask_c"] = mask_c 82 | d["sr"] = sr 83 | results[dataset][classifier][method] = d 84 | print("Scores computed for: {} - {} - {}".format(dataset, classifier, method)) 85 | except: 86 | print("Cannot compute scores for: {} - {} - {}!".format(dataset, classifier, method)) 87 | 88 | np.savez("results.npz", results=results) -------------------------------------------------------------------------------- /src/evaluation/evaluate_segmentations.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchvision.transforms as T 3 | import numpy as np 4 | import os 5 | 6 | ################################# Change to your settings ################################# 7 | dataset = 'VOC' 8 | mask_dir = './masks/VOC_vgg16_explainer/' 9 | VOC_segmentations_directory = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 10 | COCO_segmentations_directory = './coco_segmentations/' 11 | ########################################################################################### 12 | 13 | 14 | transformer = T.Compose([T.Resize(size=(224,224))]) 15 | 16 | n_files = len(os.listdir(mask_dir)) 17 | 18 | total_mask_coverage = 0.0 19 | total_non_seg_mask_coverage = 0.0 20 | total_overlap = 0.0 21 | for filename in os.listdir(mask_dir): 22 | if os.path.splitext(filename)[1] == '.npz': 23 | n_files -= 1 24 | continue 25 | 26 | if dataset == "VOC": 27 | segmentation_filename = VOC_segmentations_directory + os.path.splitext(filename)[0] + '.png' 28 | elif dataset == "COCO": 29 | segmentation_filename = COCO_segmentations_directory + filename 30 | else: 31 | raise Exception("Illegal dataset: " + dataset) 32 | 33 | if not os.path.isfile(segmentation_filename): 34 | n_files -= 1 35 | continue 36 | 37 | mask = np.array(Image.open(mask_dir + filename).convert('L')) / 255.0 38 | 39 | seg_mask = Image.open(segmentation_filename).convert('L') 40 | seg_mask = transformer(seg_mask) 41 | seg_mask = np.array(seg_mask) / 255.0 42 | if dataset == "VOC": 43 | seg_mask[seg_mask > 0] = 1 44 | 45 | seg_area = np.average(seg_mask) 46 | 47 | if seg_area > 0.0: 48 | seg_diff = np.clip(seg_mask - mask, a_min=0.0, a_max=None) 49 | mask_coverage = 1 - (np.average(seg_diff) * (1 / seg_area)) 50 | total_mask_coverage += mask_coverage 51 | 52 | if seg_area < 1.0: 53 | non_seg_diff = np.clip(mask - seg_mask, a_min=0.0, a_max=None) 54 | 
non_seg_mask_coverage = (np.average(non_seg_diff) * (1 / (1-seg_area))) 55 | total_non_seg_mask_coverage += non_seg_mask_coverage 56 | 57 | overlap = 1 - np.average(np.absolute(seg_mask - mask)) 58 | total_overlap += overlap 59 | 60 | mean_mask_coverage = total_mask_coverage / n_files 61 | mean_non_seg_mask_coverage = total_non_seg_mask_coverage / n_files 62 | mean_overlap = total_overlap / n_files 63 | 64 | print("Number of files: ", n_files) 65 | print("Mean mask coverage: ", mean_mask_coverage) 66 | print("Mean mask coverage for non-segmented parts: ", mean_non_seg_mask_coverage) 67 | print("Mean overlap between segmentation and mask: ", mean_overlap) 68 | -------------------------------------------------------------------------------- /src/evaluation/explainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | from torchray.utils import get_device 9 | from timeit import default_timer 10 | 11 | from models.explainer_classifier import ExplainerClassifierModel 12 | from data.dataloader import * 13 | from utils.helper import * 14 | from utils.image_utils import * 15 | 16 | ############################## Change to your settings ############################## 17 | dataset = 'VOC' # one of: ['VOC', 'COCO'] 18 | data_base_path = '../../datasets/' 19 | classifier_type = 'vgg16' # one of: ['vgg16', 'resnet50'] 20 | explainer_classifier_checkpoint = '../checkpoints/explainer_vgg16_voc.ckpt' 21 | 22 | mode = 'seg' # one of: ['seg', 'classes'] 23 | if mode == 'seg': 24 | VOC_segmentations_directory = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 25 | COCO_segmentations_directory = './coco_segmentations/' 26 | elif mode == 'classes': 27 | masks_for_classes = [4, 6, 7, 11, 14] 28 | ##################################################################################### 29 | 30 | # Set up data module 31 | if dataset == "VOC": 32 | num_classes = 20 33 | data_path = Path(data_base_path) / "VOC2007" 34 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 35 | elif dataset == "COCO": 36 | num_classes = 91 37 | data_path = Path(data_base_path) / "COCO2014" 38 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 39 | else: 40 | raise Exception("Unknown dataset " + dataset) 41 | 42 | data_module.setup(stage = "test") 43 | 44 | model = ExplainerClassifierModel.load_from_checkpoint(explainer_classifier_checkpoint, num_classes=num_classes, dataset=dataset, classifier_type=classifier_type) 45 | device = get_device() 46 | model.to(device) 47 | model.eval() 48 | 49 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "explainer")) 50 | if not os.path.isdir(save_path): 51 | os.makedirs(save_path) 52 | 53 | total_time = 0.0 54 | for batch in tqdm(data_module.test_dataloader()): 55 | image, annotations = batch 56 | image = image.to(device) 57 | 58 | filename = get_filename_from_annotations(annotations, dataset=dataset) 59 | targets = get_targets_from_annotations(annotations, dataset=dataset) 60 | 61 | if mode == 'seg': 62 | if dataset == "VOC": 63 | segmentation_filename = VOC_segmentations_directory + os.path.splitext(filename)[0] + '.png' 64 | elif dataset == "COCO": 65 | segmentation_filename = COCO_segmentations_directory + os.path.splitext(filename)[0] + '.png' 66 | 67 | if not os.path.exists(segmentation_filename): 68 | continue 69 | 70 | start_time = default_timer() 71 | _, 
_, mask, _, _ = model(image, targets) 72 | end_time = default_timer() 73 | total_time += end_time - start_time 74 | 75 | save_mask(mask, save_path / filename) 76 | save_masked_image(image, mask, save_path / "images" / filename) 77 | 78 | elif mode == 'classes': 79 | target_classes = [index for index, value in enumerate(targets[0]) if value == 1.0] 80 | intersection = set(target_classes) & set(masks_for_classes) 81 | if intersection: 82 | segmentations = model.explainer(image) 83 | for target_class in intersection: 84 | for mask_class in masks_for_classes: 85 | mask = segmentations[0][mask_class].sigmoid() 86 | 87 | save_mask(mask, save_path / "class_masks" 88 | / "target_class_{}".format(target_class) 89 | / "masks_for_class_{}".format(mask_class) 90 | / filename) 91 | 92 | if mode == 'seg': 93 | print("Total time for masking process of the Explainer with dataset {} and classifier {}: {} seconds".format(dataset, classifier_type, total_time)) 94 | 95 | -------------------------------------------------------------------------------- /src/evaluation/extremal_perturbations.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | from pathlib import Path 7 | 8 | from eval_utils.compute_masks import compute_and_save_masks 9 | from data.dataloader import VOCDataModule, COCODataModule 10 | from models.classifier import VGG16ClassifierModel, Resnet50ClassifierModel 11 | 12 | ##################### Change to your settings ################################ 13 | dataset = "VOC" # one of: ["VOC", "COCO"] 14 | data_base_path = "../../datasets/" 15 | classifier_type = "vgg16" # one of: ["vgg16", "resnet50"] 16 | classifier_checkpoint = "../checkpoints/pretrained_classifiers/vgg16_voc.ckpt" 17 | VOC_segmentations_path = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 18 | COCO_segmentations_path = './coco_segmentations/' 19 | ############################################################################## 20 | 21 | if dataset == "VOC": 22 | num_classes = 20 23 | data_path = Path(data_base_path) / 'VOC2007' 24 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 25 | segmentations_path = Path(VOC_segmentations_path) 26 | elif dataset == "COCO": 27 | num_classes = 91 28 | data_path = Path(data_base_path) / 'COCO2014' 29 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 30 | segmentations_path = Path(COCO_segmentations_path) 31 | else: 32 | raise Exception("Unknown dataset: " + dataset) 33 | 34 | if classifier_type == "vgg16": 35 | classifier = VGG16ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 36 | elif classifier_type == "resnet50": 37 | classifier = Resnet50ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 38 | else: 39 | raise Exception("Unknown classifier type: " + classifier_type) 40 | 41 | for param in classifier.parameters(): 42 | param.requires_grad_(False) 43 | 44 | data_module.setup(stage = "test") 45 | 46 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "extremal_perturbations")) 47 | 48 | compute_and_save_masks(model=classifier, data_module=data_module, path_segmentation=segmentations_path, path_masks=save_path, method="extremal_perturbations") -------------------------------------------------------------------------------- /src/evaluation/generate_coco_segmentations.py: 
-------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pathlib import Path 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import random 7 | import os 8 | import sys 9 | 10 | ############################## Change to your settings ############################## 11 | test_annotations_path = '../../datasets/COCO2014/annotations/instances_val2014.json' 12 | num_segmentations = 1000 13 | seg_dir = './coco_segmentations/' 14 | ##################################################################################### 15 | 16 | if not os.path.isdir(seg_dir): 17 | os.mkdir(seg_dir) 18 | 19 | annotation = Path(test_annotations_path) 20 | coco = COCO(annotation) 21 | keys = coco.imgs.keys() 22 | ids = list(random.sample(keys, len(keys))) 23 | for index in range(num_segmentations): 24 | img_id = ids[index] 25 | ann_ids = coco.getAnnIds(imgIds=img_id) 26 | coco_annotation = coco.loadAnns(ann_ids) 27 | img = coco.loadImgs(img_id)[0] 28 | path = img['file_name'] 29 | 30 | num_objects = len(coco_annotation) 31 | mask = np.zeros((img["height"], img["width"])) 32 | for i in range(num_objects): 33 | mask = np.maximum(coco.annToMask(coco_annotation[i]), mask) 34 | 35 | path_file = seg_dir + os.path.splitext(path)[0] 36 | plt.imsave(path_file + ".png", mask, cmap='gray',format="png") 37 | np.savez_compressed(path_file + ".npz", mask) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /src/evaluation/grad_cam.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import os 4 | import sys 5 | 6 | sys.path.insert(0, os.path.abspath("..")) 7 | 8 | from pathlib import Path 9 | from timeit import default_timer 10 | 11 | from data.dataloader import * 12 | from utils.helper import * 13 | from utils.image_utils import * 14 | from models.classifier import * 15 | 16 | from pytorch_grad_cam import GradCAM 17 | from pytorch_grad_cam.utils.image import * 18 | 19 | ############################## Change to your settings ############################## 20 | dataset = 'VOC' # one of: ['VOC', 'COCO'] 21 | data_base_path = '../../datasets/' 22 | classifier_type = 'vgg16' # one of: ['vgg16', 'resnet50'] 23 | classifier_checkpoint = '../checkpoints/pretrained_classifiers/vgg16_voc.ckpt' 24 | 25 | mode = 'seg' # one of: ['seg', 'classes'] 26 | if mode == 'seg': 27 | VOC_segmentations_directory = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 28 | COCO_segmentations_directory = './coco_segmentations/' 29 | elif mode == 'classes': 30 | masks_for_classes = [4, 6, 7, 11, 14] 31 | ##################################################################################### 32 | 33 | # Set up data module 34 | if dataset == "VOC": 35 | num_classes = 20 36 | data_path = Path(data_base_path) / "VOC2007" 37 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 38 | elif dataset == "COCO": 39 | num_classes = 91 40 | data_path = Path(data_base_path) / "COCO2014" 41 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 42 | else: 43 | raise Exception("Unknown dataset " + dataset) 44 | 45 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "grad_cam")) 46 | if not os.path.isdir(save_path): 47 | os.makedirs(save_path) 48 | 49 | total_time = 0.0 50 | 51 | class GradCAMModel(pl.LightningModule): 52 | def __init__(self, num_classes): 53 | 
super().__init__() 54 | 55 | self.use_cuda = (torch.cuda.device_count() > 0) 56 | # Set up model 57 | if classifier_type == "vgg16": 58 | self.model = VGG16ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset, fix_classifier_backbone=False) 59 | self.target_layer = self.model.feature_extractor[-1] 60 | elif classifier_type == "resnet50": 61 | self.model = Resnet50ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset, fix_classifier_backbone=False) 62 | self.target_layer = self.model.feature_extractor[-2][-1] 63 | else: 64 | raise Exception("Unknown classifier type " + classifier_type) 65 | 66 | self.cam = GradCAM(model=self.model, target_layer=self.target_layer, use_cuda=self.use_cuda) 67 | 68 | def forward(self, image, target): 69 | saliency = self.cam(input_tensor=image, target_category=target) 70 | 71 | return saliency 72 | 73 | def test_step(self, batch, batch_idx): 74 | torch.set_grad_enabled(True) 75 | image, annotations = batch 76 | targets = get_targets_from_annotations(annotations, dataset=dataset) 77 | filename = get_filename_from_annotations(annotations, dataset=dataset) 78 | 79 | assert(targets.size()[0] == 1) 80 | 81 | global total_time 82 | if mode == 'seg': 83 | if dataset == "VOC": 84 | segmentation_filename = VOC_segmentations_directory + os.path.splitext(filename)[0] + '.png' 85 | elif dataset == "COCO": 86 | segmentation_filename = COCO_segmentations_directory + os.path.splitext(filename)[0] + '.png' 87 | else: 88 | raise Exception("Illegal dataset: " + dataset) 89 | 90 | if not os.path.exists(segmentation_filename): 91 | return 92 | 93 | saliencies = torch.zeros(num_classes, 224, 224) 94 | start_time = default_timer() 95 | for class_index in range(num_classes): 96 | if targets[0][class_index] == 1.0: 97 | saliencies[class_index] = torch.tensor(self(image, class_index)[0, :]) 98 | 99 | saliency_map = saliencies.amax(dim=0) 100 | end_time = default_timer() 101 | total_time += end_time - start_time 102 | 103 | saliency_map = saliency_map.nan_to_num(nan=0.0) # nan_to_num is not in-place; keep the cleaned tensor 104 | 105 | save_mask(saliency_map, save_path / filename) 106 | 107 | elif mode == 'classes': 108 | target_classes = [index for index, value in enumerate(targets[0]) if value == 1.0] 109 | intersection = set(target_classes) & set(masks_for_classes) 110 | if intersection: 111 | for target_class in intersection: 112 | for mask_class in masks_for_classes: 113 | saliency = torch.tensor(self(image, mask_class)[0, :]) 114 | saliency = saliency.nan_to_num(nan=0.0) 115 | 116 | save_mask(saliency, save_path / "class_masks" 117 | / "target_class_{}".format(target_class) 118 | / "masks_for_class_{}".format(mask_class) 119 | / filename) 120 | 121 | model = GradCAMModel(num_classes=num_classes) 122 | trainer = pl.Trainer(gpus=[0] if torch.cuda.is_available() else 0) 123 | trainer.test(model=model, datamodule=data_module) 124 | 125 | if mode == 'seg': 126 | print("Total time for masking process of GradCAM with dataset {} and classifier {}: {} seconds".format(dataset, classifier_type, total_time)) 127 | -------------------------------------------------------------------------------- /src/evaluation/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | from pathlib import Path 7 | 8 | from eval_utils.compute_masks import compute_and_save_masks 9 | from data.dataloader import VOCDataModule, COCODataModule 10 | from models.classifier
import VGG16ClassifierModel, Resnet50ClassifierModel 11 | 12 | ##################### Change to your settings ################################ 13 | dataset = "VOC" # one of: ["VOC", "COCO"] 14 | data_base_path = "../../datasets/" 15 | classifier_type = "vgg16" # one of: ["vgg16", "resnet50"] 16 | classifier_checkpoint = "../checkpoints/pretrained_classifiers/vgg16_voc.ckpt" 17 | VOC_segmentations_path = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 18 | COCO_segmentations_path = './coco_segmentations/' 19 | ############################################################################## 20 | 21 | if dataset == "VOC": 22 | num_classes = 20 23 | data_path = Path(data_base_path) / 'VOC2007' 24 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 25 | segmentations_path = Path(VOC_segmentations_path) 26 | elif dataset == "COCO": 27 | num_classes = 91 28 | data_path = Path(data_base_path) / 'COCO2014' 29 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 30 | segmentations_path = Path(COCO_segmentations_path) 31 | else: 32 | raise Exception("Unknown dataset: " + dataset) 33 | 34 | if classifier_type == "vgg16": 35 | classifier = VGG16ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 36 | elif classifier_type == "resnet50": 37 | classifier = Resnet50ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 38 | else: 39 | raise Exception("Unknown classifier type: " + classifier_type) 40 | 41 | for param in classifier.parameters(): 42 | param.requires_grad_(False) 43 | 44 | data_module.setup(stage = "test") 45 | 46 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "guided_backprop")) 47 | 48 | compute_and_save_masks(model=classifier, data_module=data_module, path_segmentation=segmentations_path, path_masks=save_path, method="guided_backprop") 49 | -------------------------------------------------------------------------------- /src/evaluation/print_mean_scores.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from statistics import mean 3 | 4 | with np.load("results.npz", allow_pickle=True) as file: 5 | results = file["results"].item() 6 | for dataset in results: 7 | print("\n" + dataset + "\n") 8 | for classifier in results[dataset]: 9 | print("\t" + classifier + "\n") 10 | for method in results[dataset][classifier]: 11 | print("\t\t" + method + "\n") 12 | for metric in results[dataset][classifier][method]: 13 | values = results[dataset][classifier][method][metric] 14 | try: 15 | print("\t\t\t{}: {}\n".format(metric, mean(values))) 16 | except: 17 | values = [mean(value) for value in values] 18 | print("\t\t\t{}: {}\n".format(metric, mean(values))) 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/evaluation/rise.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import os 4 | import sys 5 | 6 | sys.path.insert(0, os.path.abspath("..")) 7 | 8 | from pathlib import Path 9 | from timeit import default_timer 10 | 11 | from data.dataloader import * 12 | from utils.helper import * 13 | from utils.image_utils import * 14 | from models.classifier import * 15 | 16 | from torchray.attribution.rise import rise 17 | 18 | ############################## Change to your settings ############################## 19 | dataset = 'VOC' # one 
of: ['VOC', 'COCO'] 20 | data_base_path = '../../datasets/' 21 | classifier_type = 'vgg16' # one of: ['vgg16', 'resnet50'] 22 | classifier_checkpoint = '../checkpoints/pretrained_classifiers/vgg16_voc.ckpt' 23 | 24 | mode = 'seg' # one of: ['seg', 'classes'] 25 | if mode == 'seg': 26 | VOC_segmentations_directory = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 27 | COCO_segmentations_directory = './coco_segmentations/' 28 | elif mode == 'classes': 29 | masks_for_classes = [4, 6, 7, 11, 14] 30 | ##################################################################################### 31 | 32 | # Set up data module 33 | if dataset == "VOC": 34 | num_classes = 20 35 | data_path = Path(data_base_path) / "VOC2007" 36 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 37 | elif dataset == "COCO": 38 | num_classes = 91 39 | data_path = Path(data_base_path) / "COCO2014" 40 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 41 | else: 42 | raise Exception("Unknown dataset " + dataset) 43 | 44 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "rise")) 45 | if not os.path.isdir(save_path): 46 | os.makedirs(save_path) 47 | 48 | total_time = 0.0 49 | 50 | class RISEModel(pl.LightningModule): 51 | def __init__(self, num_classes): 52 | super().__init__() 53 | # Set up model 54 | if classifier_type == "vgg16": 55 | self.model = VGG16ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 56 | elif classifier_type == "resnet50": 57 | self.model = Resnet50ClassifierModel.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes, dataset=dataset) 58 | else: 59 | raise Exception("Unknown classifier type " + classifier_type) 60 | self.model.eval() 61 | for param in self.model.parameters(): 62 | param.requires_grad_(False) 63 | 64 | def forward(self, image): 65 | saliency = rise(self.model, image) 66 | 67 | return saliency 68 | 69 | def test_step(self, batch, batch_idx): 70 | image, annotations = batch 71 | targets = get_targets_from_annotations(annotations, dataset=dataset) 72 | filename = get_filename_from_annotations(annotations, dataset=dataset) 73 | 74 | global total_time 75 | if mode == 'seg': 76 | if dataset == "VOC": 77 | segmentation_filename = VOC_segmentations_directory + os.path.splitext(filename)[0] + '.png' 78 | elif dataset == "COCO": 79 | segmentation_filename = COCO_segmentations_directory + os.path.splitext(filename)[0] + '.png' 80 | else: 81 | raise Exception("Illegal dataset: " + dataset) 82 | 83 | if not os.path.exists(segmentation_filename): 84 | return 85 | assert(targets.size()[0] == 1) 86 | 87 | saliencies = torch.zeros(num_classes, 224, 224) 88 | start_time = default_timer() 89 | saliency = self(image) 90 | for class_index in range(num_classes): 91 | if targets[0][class_index] == 1.0: 92 | class_sal = saliency[:, class_index].squeeze() 93 | min_val = class_sal.min() 94 | max_val = class_sal.max() 95 | class_sal = class_sal - min_val 96 | class_sal = torch.mul(class_sal, 1 / (max_val - min_val)) 97 | class_sal = class_sal.clamp(0, 1) 98 | saliencies[class_index] = class_sal 99 | 100 | saliency_map = saliencies.amax(dim=0) 101 | end_time = default_timer() 102 | total_time += end_time - start_time 103 | 104 | save_mask(saliency_map, save_path / filename) 105 | 106 | elif mode == 'classes': 107 | target_classes = [index for index, value in enumerate(targets[0]) if value == 1.0] 108 | intersection = set(target_classes) & set(masks_for_classes) 109 | if intersection: 
110 | saliency = self(image) 111 | for target_class in intersection: 112 | for mask_class in masks_for_classes: 113 | class_sal = saliency[:, mask_class].squeeze() 114 | min_val = class_sal.min() 115 | max_val = class_sal.max() 116 | class_sal = class_sal - min_val 117 | class_sal = torch.mul(class_sal, 1 / (max_val - min_val)) 118 | class_sal = class_sal.clamp(0, 1) 119 | 120 | save_mask(class_sal, save_path / "class_masks" 121 | / "target_class_{}".format(target_class) 122 | / "masks_for_class_{}".format(mask_class) 123 | / filename) 124 | 125 | model = RISEModel(num_classes=num_classes) 126 | trainer = pl.Trainer(gpus=[0] if torch.cuda.is_available() else 0) 127 | trainer.test(model=model, datamodule=data_module) 128 | 129 | if mode == 'seg': 130 | print("Total time for masking process of RISE with dataset {} and classifier {}: {} seconds".format(dataset, classifier_type, total_time)) 131 | -------------------------------------------------------------------------------- /src/evaluation/rt_saliency.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | from torchray.utils import get_device 9 | from timeit import default_timer 10 | 11 | from models.explainer_classifier_rtsal import RTSalExplainerClassifierModel 12 | from data.dataloader import * 13 | from utils.helper import * 14 | from utils.image_utils import * 15 | 16 | ############################## Change to your settings ############################## 17 | dataset = 'VOC' # one of: ['VOC', 'COCO'] 18 | data_base_path = '../../datasets/' 19 | classifier_type = 'vgg16' # one of: ['vgg16', 'resnet50'] 20 | explainer_classifier_checkpoint = '../checkpoints/rt_saliency/rtsal_explainer_vgg16_voc.ckpt' 21 | 22 | mode = 'seg' # one of: ['seg', 'classes'] 23 | if mode == 'seg': 24 | VOC_segmentations_directory = '../../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/' 25 | COCO_segmentations_directory = './coco_segmentations/' 26 | elif mode == 'classes': 27 | masks_for_classes = [4, 6, 7, 11, 14] 28 | ##################################################################################### 29 | 30 | # Set up data module 31 | if dataset == "VOC": 32 | num_classes = 20 33 | data_path = Path(data_base_path) / "VOC2007" 34 | data_module = VOCDataModule(data_path=data_path, test_batch_size=1) 35 | elif dataset == "COCO": 36 | num_classes = 91 37 | data_path = Path(data_base_path) / "COCO2014" 38 | data_module = COCODataModule(data_path=data_path, test_batch_size=1) 39 | else: 40 | raise Exception("Unknown dataset " + dataset) 41 | 42 | data_module.setup(stage = "test") 43 | 44 | model = RTSalExplainerClassifierModel.load_from_checkpoint(explainer_classifier_checkpoint, num_classes=num_classes, dataset=dataset, classifier_type=classifier_type) 45 | device = get_device() 46 | model.to(device) 47 | model.eval() 48 | 49 | save_path = Path('masks/{}_{}_{}/'.format(dataset, classifier_type, "rt_saliency")) 50 | if not os.path.isdir(save_path): 51 | os.makedirs(save_path) 52 | 53 | total_time = 0.0 54 | for batch in tqdm(data_module.test_dataloader()): 55 | image, annotations = batch 56 | image = image.to(device) 57 | 58 | filename = get_filename_from_annotations(annotations, dataset=dataset) 59 | targets = get_targets_from_annotations(annotations, dataset=dataset) 60 | 61 | if mode == 'seg': 62 | if dataset == "VOC": 63 | segmentation_filename = VOC_segmentations_directory + 
os.path.splitext(filename)[0] + '.png' 64 | elif dataset == "COCO": 65 | segmentation_filename = COCO_segmentations_directory + os.path.splitext(filename)[0] + '.png' 66 | 67 | if not os.path.exists(segmentation_filename): 68 | continue 69 | 70 | start_time = default_timer() 71 | _, _, mask = model(image, targets) 72 | end_time = default_timer() 73 | total_time += end_time - start_time 74 | 75 | save_mask(mask, save_path / filename) 76 | save_masked_image(image, mask, save_path / "images" / filename) 77 | 78 | elif mode == 'classes': 79 | target_classes = [index for index, value in enumerate(targets[0]) if value == 1.0] 80 | intersection = set(target_classes) & set(masks_for_classes) 81 | if intersection: 82 | segmentations = model.explainer(image) 83 | for target_class in intersection: 84 | for mask_class in masks_for_classes: 85 | mask = segmentations[0][mask_class].sigmoid() 86 | 87 | save_mask(mask, save_path / "class_masks" 88 | / "target_class_{}".format(target_class) 89 | / "masks_for_class_{}".format(mask_class) 90 | / filename) 91 | 92 | if mode == 'seg': 93 | print("Total time for masking process of the Real-Time Saliency Explainer with dataset {} and classifier {}: {} seconds".format(dataset, classifier_type, total_time)) 94 | 95 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import datetime 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 6 | from pathlib import Path 7 | 8 | from data.dataloader import VOCDataModule, COCODataModule, CUB200DataModule 9 | from utils.argparser import get_parser, write_config_file 10 | from models.classifier import VGG16ClassifierModel, Resnet50ClassifierModel 11 | from models.explainer_classifier import ExplainerClassifierModel 12 | from models.interpretable_fcnn import InterpretableFCNN 13 | from models.explainer_classifier_rtsal import RTSalExplainerClassifierModel 14 | 15 | main_dir = Path(os.path.dirname(os.path.abspath(__file__))) 16 | 17 | parser = get_parser() 18 | args = parser.parse_args() 19 | if args.arg_log: 20 | write_config_file(args) 21 | 22 | pl.seed_everything(args.seed) 23 | 24 | # Set up Logging 25 | if args.use_tensorboard_logger: 26 | log_dir = "tb_logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 27 | logger = pl.loggers.TensorBoardLogger(log_dir, name="NN Explainer") 28 | else: 29 | logger = False 30 | 31 | # Set up data module 32 | if args.dataset == "VOC": 33 | data_path = main_dir / args.data_base_path / 'VOC2007' 34 | data_module = VOCDataModule( 35 | data_path=data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, 36 | test_batch_size=args.test_batch_size, use_data_augmentation=args.use_data_augmentation 37 | ) 38 | num_classes = 20 39 | elif args.dataset == "COCO": 40 | data_path = main_dir / args.data_base_path / 'COCO2014' 41 | data_module = COCODataModule( 42 | data_path=data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, 43 | test_batch_size=args.test_batch_size, use_data_augmentation=args.use_data_augmentation 44 | ) 45 | num_classes = 91 46 | elif args.dataset == "CUB": 47 | data_path = main_dir / args.data_base_path / 'CUB200' 48 | data_module = CUB200DataModule( 49 | data_path=data_path, train_batch_size=args.train_batch_size, val_batch_size=args.val_batch_size, 50 | test_batch_size=args.test_batch_size, 
use_data_augmentation=args.use_data_augmentation 51 | ) 52 | num_classes = 200 53 | else: 54 | raise Exception("Unknown dataset " + args.dataset) 55 | 56 | # Set up model 57 | if args.model_to_train == "explainer": 58 | model = ExplainerClassifierModel( 59 | num_classes=num_classes, dataset=args.dataset, classifier_type=args.classifier_type, classifier_checkpoint=args.classifier_checkpoint, fix_classifier=args.fix_classifier, learning_rate=args.learning_rate, 60 | class_mask_min_area=args.class_mask_min_area, class_mask_max_area=args.class_mask_max_area, entropy_regularizer=args.entropy_regularizer, use_mask_variation_loss=args.use_mask_variation_loss, 61 | mask_variation_regularizer=args.mask_variation_regularizer, use_mask_area_loss=args.use_mask_area_loss, mask_area_constraint_regularizer=args.mask_area_constraint_regularizer, 62 | mask_total_area_regularizer=args.mask_total_area_regularizer, ncmask_total_area_regularizer=args.ncmask_total_area_regularizer, metrics_threshold=args.metrics_threshold, 63 | save_masked_images=args.save_masked_images, save_masks=args.save_masks, 64 | save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 65 | ) 66 | 67 | if args.explainer_classifier_checkpoint is not None: 68 | model = model.load_from_checkpoint( 69 | args.explainer_classifier_checkpoint, 70 | num_classes=num_classes, dataset=args.dataset, classifier_type=args.classifier_type, classifier_checkpoint=args.classifier_checkpoint, fix_classifier=args.fix_classifier, learning_rate=args.learning_rate, 71 | class_mask_min_area=args.class_mask_min_area, class_mask_max_area=args.class_mask_max_area, entropy_regularizer=args.entropy_regularizer, use_mask_variation_loss=args.use_mask_variation_loss, 72 | mask_variation_regularizer=args.mask_variation_regularizer, use_mask_area_loss=args.use_mask_area_loss, mask_area_constraint_regularizer=args.mask_area_constraint_regularizer, 73 | mask_total_area_regularizer=args.mask_total_area_regularizer, ncmask_total_area_regularizer=args.ncmask_total_area_regularizer, metrics_threshold=args.metrics_threshold, 74 | save_masked_images=args.save_masked_images, save_masks=args.save_masks, save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 75 | ) 76 | elif args.model_to_train == "classifier": 77 | if args.classifier_type == "vgg16": 78 | model = VGG16ClassifierModel( 79 | num_classes=num_classes, dataset=args.dataset, learning_rate=args.learning_rate, use_imagenet_pretraining=args.use_imagenet_pretraining, 80 | fix_classifier_backbone=args.fix_classifier_backbone, metrics_threshold=args.metrics_threshold 81 | ) 82 | elif args.classifier_type == "resnet50": 83 | model = Resnet50ClassifierModel( 84 | num_classes=num_classes, dataset=args.dataset, learning_rate=args.learning_rate, use_imagenet_pretraining=args.use_imagenet_pretraining, 85 | fix_classifier_backbone=args.fix_classifier_backbone, metrics_threshold=args.metrics_threshold 86 | ) 87 | else: 88 | raise Exception("Unknown classifier type " + args.classifier_type) 89 | 90 | if args.classifier_checkpoint is not None: 91 | model = model.load_from_checkpoint( 92 | args.classifier_checkpoint, 93 | num_classes=num_classes, dataset=args.dataset, learning_rate=args.learning_rate, use_imagenet_pretraining=args.use_imagenet_pretraining, 94 | fix_classifier_backbone=args.fix_classifier_backbone, metrics_threshold=args.metrics_threshold 95 | ) 96 | elif args.model_to_train == "fcnn": 97 | model = InterpretableFCNN( 98 | num_classes=num_classes, dataset=args.dataset, 
learning_rate=args.learning_rate, class_mask_min_area=args.class_mask_min_area, class_mask_max_area=args.class_mask_max_area, 99 | use_mask_coherency_loss=args.use_mask_coherency_loss, use_mask_variation_loss=args.use_mask_variation_loss, mask_variation_regularizer=args.mask_variation_regularizer, 100 | use_mask_area_loss=args.use_mask_area_loss, mask_area_constraint_regularizer=args.mask_area_constraint_regularizer, mask_total_area_regularizer=args.mask_total_area_regularizer, 101 | ncmask_total_area_regularizer=args.ncmask_total_area_regularizer, metrics_threshold=args.metrics_threshold, 102 | save_masked_images=args.save_masked_images, save_masks=args.save_masks, save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 103 | ) 104 | if args.fcnn_checkpoint is not None: 105 | model = model.load_from_checkpoint( 106 | args.fcnn_checkpoint, 107 | num_classes=num_classes, dataset=args.dataset, learning_rate=args.learning_rate, class_mask_min_area=args.class_mask_min_area, class_mask_max_area=args.class_mask_max_area, 108 | use_mask_coherency_loss=args.use_mask_coherency_loss, use_mask_variation_loss=args.use_mask_variation_loss, mask_variation_regularizer=args.mask_variation_regularizer, 109 | use_mask_area_loss=args.use_mask_area_loss, mask_area_constraint_regularizer=args.mask_area_constraint_regularizer, mask_total_area_regularizer=args.mask_total_area_regularizer, 110 | ncmask_total_area_regularizer=args.ncmask_total_area_regularizer, metrics_threshold=args.metrics_threshold, save_masked_images=args.save_masked_images, 111 | save_masks=args.save_masks, save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 112 | ) 113 | elif args.model_to_train == "rtsal_explainer": 114 | model = RTSalExplainerClassifierModel( 115 | num_classes=num_classes, dataset=args.dataset, classifier_type=args.classifier_type, 116 | classifier_checkpoint=args.classifier_checkpoint, fix_classifier=args.fix_classifier, 117 | learning_rate=args.learning_rate, metrics_threshold=args.metrics_threshold, 118 | save_masked_images=args.save_masked_images, save_masks=args.save_masks, 119 | save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 120 | ) 121 | if args.explainer_classifier_checkpoint is not None: 122 | model = model.load_from_checkpoint( 123 | args.explainer_classifier_checkpoint, 124 | num_classes=num_classes, dataset=args.dataset, classifier_type=args.classifier_type, 125 | classifier_checkpoint=args.classifier_checkpoint, fix_classifier=args.fix_classifier, 126 | learning_rate=args.learning_rate, metrics_threshold=args.metrics_threshold, 127 | save_masked_images=args.save_masked_images, save_masks=args.save_masks, 128 | save_all_class_masks=args.save_all_class_masks, save_path=args.save_path 129 | ) 130 | else: 131 | raise Exception("Unknown model type " + args.model_to_train) 132 | 133 | # Define Early Stopping condition 134 | early_stop_callback = EarlyStopping( 135 | monitor="val_loss", 136 | min_delta=args.early_stop_min_delta, 137 | patience=args.early_stop_patience, 138 | verbose=False, 139 | mode="min", 140 | ) 141 | 142 | trainer = pl.Trainer( 143 | logger = logger, 144 | callbacks = [early_stop_callback], 145 | gpus = [0] if torch.cuda.is_available() else 0, 146 | terminate_on_nan = True, 147 | checkpoint_callback = args.checkpoint_callback, 148 | ) 149 | 150 | if args.train_model: 151 | trainer.fit(model=model, datamodule=data_module) 152 | trainer.test() 153 | else: 154 | trainer.test(model=model, datamodule=data_module) 155 | 
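# --- Illustrative usage sketch (not part of the repository) -------------------
# A minimal, hedged example of how the training entry point above resolves its
# settings: main.py builds its CLI with get_parser() from utils/argparser.py and
# also accepts a config file through the -c/--config option (configargparse).
# Assumptions: run from the src/ directory with the project's dependencies
# installed; the flag values below are examples for illustration only.
from utils.argparser import get_parser

parser = get_parser()
# Roughly equivalent to: python main.py --model_to_train explainer --dataset VOC --classifier_type vgg16
args = parser.parse_args([
    "--model_to_train", "explainer",
    "--dataset", "VOC",
    "--classifier_type", "vgg16",
    "--train_batch_size", "16",
])
print(args.dataset, args.classifier_type, args.learning_rate)  # resolved values, e.g. VOC vgg16 1e-05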
-------------------------------------------------------------------------------- /src/models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from torch import nn 5 | from torch.optim import Adam 6 | from torchvision import models 7 | 8 | from utils.helper import get_targets_from_annotations 9 | from utils.metrics import MultiLabelMetrics 10 | 11 | class VGG16ClassifierModel(pl.LightningModule): 12 | def __init__(self, num_classes=20, dataset="VOC", learning_rate=1e-5, use_imagenet_pretraining=True, fix_classifier_backbone=True, metrics_threshold=0.0): 13 | super().__init__() 14 | 15 | self.setup_model(num_classes=num_classes, use_imagenet_pretraining=use_imagenet_pretraining, fix_classifier_backbone=fix_classifier_backbone) 16 | 17 | self.setup_losses(dataset=dataset) 18 | self.setup_metrics(num_classes=num_classes, metrics_threshold=metrics_threshold) 19 | 20 | self.dataset = dataset 21 | self.learning_rate = learning_rate 22 | 23 | def setup_model(self, num_classes, use_imagenet_pretraining, fix_classifier_backbone): 24 | backbone = models.vgg16(pretrained=use_imagenet_pretraining) 25 | 26 | layers = list(backbone.children())[:-1] 27 | 28 | self.feature_extractor = nn.Sequential(*layers[0]) 29 | self.avgpool = layers[1] 30 | 31 | if fix_classifier_backbone: 32 | self.feature_extractor.eval() 33 | 34 | for param in self.feature_extractor.parameters(): 35 | param.requires_grad = False 36 | 37 | self.avgpool.eval() 38 | for param in self.avgpool.parameters(): 39 | param.requires_grad = False 40 | 41 | self.classifier = nn.Sequential( 42 | nn.Linear(in_features=25088, out_features=4096, bias=True), 43 | nn.ReLU(inplace=True), 44 | nn.Dropout(p=0.5, inplace=False), 45 | nn.Linear(in_features=4096, out_features=4096, bias=True), 46 | nn.ReLU(inplace=True), 47 | nn.Dropout(p=0.5, inplace=False), 48 | nn.Linear(in_features=4096, out_features=num_classes, bias=True) 49 | ) 50 | 51 | def setup_losses(self, dataset): 52 | if dataset == "CUB": 53 | self.classification_loss_fn = nn.CrossEntropyLoss() 54 | else: 55 | self.classification_loss_fn = nn.BCEWithLogitsLoss() 56 | 57 | def setup_metrics(self, num_classes, metrics_threshold): 58 | self.train_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 59 | self.valid_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 60 | self.test_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 61 | 62 | def forward(self, x): 63 | x = self.feature_extractor(x) 64 | x = self.avgpool(x) 65 | x = torch.flatten(x, 1) 66 | 67 | x = self.classifier(x) 68 | 69 | return x 70 | 71 | def training_step(self, batch, batch_idx): 72 | x, y = batch 73 | logits = self(x) 74 | targets = get_targets_from_annotations(y, dataset=self.dataset) 75 | 76 | if self.dataset == "CUB": 77 | labels = targets.argmax(dim=1) 78 | loss = self.classification_loss_fn(logits, labels) 79 | else: 80 | loss = self.classification_loss_fn(logits, targets) 81 | 82 | self.log('train_loss', loss) 83 | self.train_metrics(logits, targets) 84 | 85 | return loss 86 | 87 | def training_epoch_end(self, outs): 88 | self.log('train_metrics', self.train_metrics.compute()) 89 | self.train_metrics.reset() 90 | 91 | def validation_step(self, batch, batch_idx): 92 | x, y = batch 93 | logits = self(x) 94 | targets = get_targets_from_annotations(y, dataset=self.dataset) 95 | 96 | if self.dataset == "CUB": 97 | labels = 
targets.argmax(dim=1) 98 | loss = self.classification_loss_fn(logits, labels) 99 | else: 100 | loss = self.classification_loss_fn(logits, targets) 101 | 102 | self.log('val_loss', loss) 103 | self.valid_metrics(logits, targets) 104 | 105 | def validation_epoch_end(self, outs): 106 | self.log('val_metrics', self.valid_metrics.compute(), prog_bar=True) 107 | self.valid_metrics.reset() 108 | 109 | def test_step(self, batch, batch_idx): 110 | x, y = batch 111 | logits = self(x) 112 | targets = get_targets_from_annotations(y, dataset=self.dataset) 113 | 114 | if self.dataset == "CUB": 115 | labels = targets.argmax(dim=1) 116 | loss = self.classification_loss_fn(logits, labels) 117 | else: 118 | loss = self.classification_loss_fn(logits, targets) 119 | 120 | self.log('test_loss', loss) 121 | self.test_metrics(logits, targets) 122 | 123 | def test_epoch_end(self, outs): 124 | self.log('test_metrics', self.test_metrics.compute(), prog_bar=True) 125 | self.test_metrics.save(model="classifier", classifier_type="vgg16", dataset=self.dataset) 126 | self.test_metrics.reset() 127 | 128 | def configure_optimizers(self): 129 | return Adam(self.parameters(), lr=self.learning_rate) 130 | 131 | class Resnet50ClassifierModel(pl.LightningModule): 132 | 133 | def __init__(self, num_classes=20, dataset="VOC", learning_rate=1e-5, use_imagenet_pretraining=True, fix_classifier_backbone=True, metrics_threshold=0.0): 134 | super().__init__() 135 | 136 | self.setup_model(num_classes=num_classes, use_imagenet_pretraining=use_imagenet_pretraining, fix_classifier_backbone=fix_classifier_backbone) 137 | 138 | self.setup_losses(dataset=dataset) 139 | self.setup_metrics(num_classes=num_classes, metrics_threshold=metrics_threshold) 140 | 141 | self.dataset = dataset 142 | self.learning_rate = learning_rate 143 | 144 | def setup_model(self, num_classes, use_imagenet_pretraining, fix_classifier_backbone): 145 | backbone = models.resnet50(pretrained=use_imagenet_pretraining) 146 | 147 | num_filters = backbone.fc.in_features 148 | layers = list(backbone.children())[:-1] 149 | self.feature_extractor = nn.Sequential(*layers) 150 | 151 | if fix_classifier_backbone: 152 | self.feature_extractor.eval() 153 | for param in self.feature_extractor.parameters(): 154 | param.requires_grad = False 155 | 156 | self.classifier = nn.Linear(in_features=num_filters, out_features=num_classes, bias=True) 157 | 158 | def setup_losses(self, dataset): 159 | if dataset == "CUB": 160 | self.classification_loss_fn = nn.CrossEntropyLoss() 161 | else: 162 | self.classification_loss_fn = nn.BCEWithLogitsLoss() 163 | 164 | def setup_metrics(self, num_classes, metrics_threshold): 165 | self.train_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 166 | self.valid_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 167 | self.test_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 168 | 169 | def forward(self, x): 170 | representations = self.feature_extractor(x).flatten(1) 171 | x = self.classifier(representations) 172 | 173 | return x 174 | 175 | def training_step(self, batch, batch_idx): 176 | x, y = batch 177 | logits = self(x) 178 | targets = get_targets_from_annotations(y, dataset=self.dataset) 179 | 180 | if self.dataset == "CUB": 181 | labels = targets.argmax(dim=1) 182 | loss = self.classification_loss_fn(logits, labels) 183 | else: 184 | loss = self.classification_loss_fn(logits, targets) 185 | 186 | self.log('train_loss', loss) 187 | 
self.train_metrics(logits, targets) 188 | 189 | return loss 190 | 191 | def training_epoch_end(self, outs): 192 | self.log('train_metrics', self.train_metrics.compute()) 193 | self.train_metrics.reset() 194 | 195 | def validation_step(self, batch, batch_idx): 196 | x, y = batch 197 | logits = self(x) 198 | targets = get_targets_from_annotations(y, dataset=self.dataset) 199 | 200 | if self.dataset == "CUB": 201 | labels = targets.argmax(dim=1) 202 | loss = self.classification_loss_fn(logits, labels) 203 | else: 204 | loss = self.classification_loss_fn(logits, targets) 205 | 206 | self.log('val_loss', loss) 207 | self.valid_metrics(logits, targets) 208 | 209 | def validation_epoch_end(self, outs): 210 | self.log('val_metrics', self.valid_metrics.compute(), prog_bar=True) 211 | self.valid_metrics.reset() 212 | 213 | def test_step(self, batch, batch_idx): 214 | x, y = batch 215 | logits = self(x) 216 | targets = get_targets_from_annotations(y, dataset=self.dataset) 217 | 218 | if self.dataset == "CUB": 219 | labels = targets.argmax(dim=1) 220 | loss = self.classification_loss_fn(logits, labels) 221 | else: 222 | loss = self.classification_loss_fn(logits, targets) 223 | 224 | self.log('test_loss', loss) 225 | self.test_metrics(logits, targets) 226 | 227 | def test_epoch_end(self, outs): 228 | self.log('test_metrics', self.test_metrics.compute(), prog_bar=True) 229 | self.test_metrics.save(model="classifier", classifier_type="resnet50", dataset=self.dataset) 230 | self.test_metrics.reset() 231 | 232 | def configure_optimizers(self): 233 | return Adam(self.parameters(), lr=self.learning_rate) -------------------------------------------------------------------------------- /src/models/explainer.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | from torchvision import models 4 | 5 | class Deeplabv3Resnet50ExplainerModel(pl.LightningModule): 6 | def __init__(self, num_classes=20): 7 | super().__init__() 8 | self.explainer = models.segmentation.deeplabv3_resnet50(pretrained=False, num_classes=num_classes) 9 | 10 | def forward(self, x): 11 | x = self.explainer(x)['out'] 12 | return x 13 | 14 | 15 | -------------------------------------------------------------------------------- /src/models/explainer_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from torch import nn 5 | from torch.optim import Adam 6 | from pathlib import Path 7 | 8 | from models.explainer import Deeplabv3Resnet50ExplainerModel 9 | from models.classifier import VGG16ClassifierModel, Resnet50ClassifierModel 10 | from utils.helper import get_targets_from_annotations, get_filename_from_annotations, extract_masks 11 | from utils.image_utils import save_mask, save_masked_image, save_all_class_masks 12 | from utils.loss import TotalVariationConv, ClassMaskAreaLoss, entropy_loss 13 | from utils.metrics import MultiLabelMetrics 14 | 15 | class ExplainerClassifierModel(pl.LightningModule): 16 | def __init__(self, num_classes=20, dataset="VOC", classifier_type="vgg16", classifier_checkpoint=None, fix_classifier=True, learning_rate=1e-5, class_mask_min_area=0.05, 17 | class_mask_max_area=0.3, entropy_regularizer=1.0, use_mask_variation_loss=True, mask_variation_regularizer=1.0, use_mask_area_loss=True, 18 | mask_area_constraint_regularizer=1.0, mask_total_area_regularizer=0.1, ncmask_total_area_regularizer=0.3, metrics_threshold=-1.0, 19 | 
save_masked_images=False, save_masks=False, save_all_class_masks=False, save_path="./results/"): 20 | 21 | super().__init__() 22 | 23 | self.setup_explainer(num_classes=num_classes) 24 | self.setup_classifier(classifier_type=classifier_type, classifier_checkpoint=classifier_checkpoint, fix_classifier=fix_classifier, num_classes=num_classes) 25 | 26 | self.setup_losses(dataset=dataset, class_mask_min_area=class_mask_min_area, class_mask_max_area=class_mask_max_area) 27 | self.setup_metrics(num_classes=num_classes, metrics_threshold=metrics_threshold) 28 | 29 | self.dataset = dataset 30 | self.classifier_type = classifier_type 31 | 32 | # Hyperparameters 33 | self.learning_rate = learning_rate 34 | self.entropy_regularizer = entropy_regularizer 35 | self.use_mask_variation_loss = use_mask_variation_loss 36 | self.mask_variation_regularizer = mask_variation_regularizer 37 | self.use_mask_area_loss = use_mask_area_loss 38 | self.mask_area_constraint_regularizer = mask_area_constraint_regularizer 39 | self.mask_total_area_regularizer = mask_total_area_regularizer 40 | self.ncmask_total_area_regularizer = ncmask_total_area_regularizer 41 | 42 | # Image display/save settings 43 | self.save_masked_images = save_masked_images 44 | self.save_masks = save_masks 45 | self.save_all_class_masks = save_all_class_masks 46 | self.save_path = save_path 47 | 48 | def setup_explainer(self, num_classes): 49 | self.explainer = Deeplabv3Resnet50ExplainerModel(num_classes=num_classes) 50 | 51 | def setup_classifier(self, classifier_type, classifier_checkpoint, fix_classifier, num_classes): 52 | if classifier_type == "vgg16": 53 | self.classifier = VGG16ClassifierModel(num_classes=num_classes) 54 | elif classifier_type == "resnet50": 55 | self.classifier = Resnet50ClassifierModel(num_classes=num_classes) 56 | else: 57 | raise Exception("Unknown classifier type " + classifier_type) 58 | 59 | if classifier_checkpoint is not None: 60 | self.classifier = self.classifier.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes) 61 | if fix_classifier: 62 | self.classifier.freeze() 63 | 64 | def setup_losses(self, dataset, class_mask_min_area, class_mask_max_area): 65 | self.total_variation_conv = TotalVariationConv() 66 | 67 | if dataset == "CUB": 68 | self.classification_loss_fn = nn.CrossEntropyLoss() 69 | else: 70 | self.classification_loss_fn = nn.BCEWithLogitsLoss() 71 | 72 | self.class_mask_area_loss_fn = ClassMaskAreaLoss(min_area=class_mask_min_area, max_area=class_mask_max_area) 73 | 74 | def setup_metrics(self, num_classes, metrics_threshold): 75 | self.train_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 76 | self.valid_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 77 | self.test_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 78 | 79 | def forward(self, image, targets): 80 | segmentations = self.explainer(image) 81 | target_mask, non_target_mask = extract_masks(segmentations, targets) 82 | inversed_target_mask = torch.ones_like(target_mask) - target_mask 83 | 84 | masked_image = target_mask.unsqueeze(1) * image 85 | inversed_masked_image = inversed_target_mask.unsqueeze(1) * image 86 | 87 | logits_mask = self.classifier(masked_image) 88 | logits_inversed_mask = self.classifier(inversed_masked_image) 89 | 90 | return logits_mask, logits_inversed_mask, target_mask, non_target_mask, segmentations 91 | 92 | def training_step(self, batch, batch_idx): 93 | image, annotations = batch 94 | targets = 
get_targets_from_annotations(annotations, dataset=self.dataset) 95 | logits_mask, logits_inversed_mask, target_mask, non_target_mask, segmentations = self(image, targets) 96 | 97 | if self.dataset == "CUB": 98 | labels = targets.argmax(dim=1) 99 | classification_loss_mask = self.classification_loss_fn(logits_mask, labels) 100 | else: 101 | classification_loss_mask = self.classification_loss_fn(logits_mask, targets) 102 | 103 | classification_loss_inversed_mask = self.entropy_regularizer * entropy_loss(logits_inversed_mask) 104 | loss = classification_loss_mask + classification_loss_inversed_mask 105 | 106 | if self.use_mask_variation_loss: 107 | mask_variation_loss = self.mask_variation_regularizer * (self.total_variation_conv(target_mask) + self.total_variation_conv(non_target_mask)) 108 | loss += mask_variation_loss 109 | 110 | if self.use_mask_area_loss: 111 | mask_area_loss = self.mask_area_constraint_regularizer * self.class_mask_area_loss_fn(segmentations, targets) 112 | mask_area_loss += self.mask_total_area_regularizer * target_mask.mean() 113 | mask_area_loss += self.ncmask_total_area_regularizer * non_target_mask.mean() 114 | loss += mask_area_loss 115 | 116 | self.log('train_loss', loss) 117 | self.train_metrics(logits_mask, targets) 118 | 119 | return loss 120 | 121 | def training_epoch_end(self, outs): 122 | self.log('train_metrics', self.train_metrics.compute()) 123 | self.train_metrics.reset() 124 | 125 | def validation_step(self, batch, batch_idx): 126 | image, annotations = batch 127 | targets = get_targets_from_annotations(annotations, dataset=self.dataset) 128 | logits_mask, logits_inversed_mask, target_mask, non_target_mask, segmentations = self(image, targets) 129 | 130 | if self.dataset == "CUB": 131 | labels = targets.argmax(dim=1) 132 | classification_loss_mask = self.classification_loss_fn(logits_mask, labels) 133 | else: 134 | classification_loss_mask = self.classification_loss_fn(logits_mask, targets) 135 | 136 | classification_loss_inversed_mask = self.entropy_regularizer * entropy_loss(logits_inversed_mask) 137 | loss = classification_loss_mask + classification_loss_inversed_mask 138 | 139 | if self.use_mask_variation_loss: 140 | mask_variation_loss = self.mask_variation_regularizer * (self.total_variation_conv(target_mask) + self.total_variation_conv(non_target_mask)) 141 | loss += mask_variation_loss 142 | 143 | if self.use_mask_area_loss: 144 | mask_area_loss = self.mask_area_constraint_regularizer * self.class_mask_area_loss_fn(segmentations, targets) 145 | mask_area_loss += self.mask_total_area_regularizer * target_mask.mean() 146 | mask_area_loss += self.ncmask_total_area_regularizer * non_target_mask.mean() 147 | loss += mask_area_loss 148 | 149 | self.log('val_loss', loss) 150 | self.valid_metrics(logits_mask, targets) 151 | 152 | def validation_epoch_end(self, outs): 153 | self.log('val_metrics', self.valid_metrics.compute(), prog_bar=True) 154 | self.valid_metrics.reset() 155 | 156 | def test_step(self, batch, batch_idx): 157 | image, annotations = batch 158 | targets = get_targets_from_annotations(annotations, dataset=self.dataset) 159 | logits_mask, logits_inversed_mask, target_mask, non_target_mask, segmentations = self(image, targets) 160 | 161 | if self.save_masked_images and image.size()[0] == 1: 162 | filename = Path(self.save_path) / "masked_images" / get_filename_from_annotations(annotations, dataset=self.dataset) 163 | save_masked_image(image, target_mask, filename) 164 | 165 | if self.save_masks and image.size()[0] == 1: 166 | filename = 
get_filename_from_annotations(annotations, dataset=self.dataset) 167 | save_mask(target_mask, Path(self.save_path) / "masks" / filename) 168 | 169 | if self.save_all_class_masks and image.size()[0] == 1 and self.dataset == "VOC": 170 | filename = Path(self.save_path) / "all_class_masks" / get_filename_from_annotations(annotations, dataset=self.dataset) 171 | save_all_class_masks(image, segmentations, filename) 172 | 173 | if self.dataset == "CUB": 174 | labels = targets.argmax(dim=1) 175 | classification_loss_mask = self.classification_loss_fn(logits_mask, labels) 176 | else: 177 | classification_loss_mask = self.classification_loss_fn(logits_mask, targets) 178 | 179 | classification_loss_inversed_mask = self.entropy_regularizer * entropy_loss(logits_inversed_mask) 180 | loss = classification_loss_mask + classification_loss_inversed_mask 181 | 182 | if self.use_mask_variation_loss: 183 | mask_variation_loss = self.mask_variation_regularizer * (self.total_variation_conv(target_mask) + self.total_variation_conv(non_target_mask)) 184 | loss += mask_variation_loss 185 | 186 | if self.use_mask_area_loss: 187 | mask_area_loss = self.mask_area_constraint_regularizer * self.class_mask_area_loss_fn(segmentations, targets) 188 | mask_area_loss += self.mask_total_area_regularizer * target_mask.mean() 189 | mask_area_loss += self.ncmask_total_area_regularizer * non_target_mask.mean() 190 | loss += mask_area_loss 191 | 192 | self.log('test_loss', loss) 193 | self.test_metrics(logits_mask, targets) 194 | 195 | def test_epoch_end(self, outs): 196 | self.log('test_metrics', self.test_metrics.compute(), prog_bar=True) 197 | self.test_metrics.save(model="explainer", classifier_type=self.classifier_type, dataset=self.dataset) 198 | 199 | self.test_metrics.reset() 200 | 201 | def configure_optimizers(self): 202 | return Adam(self.parameters(), lr=self.learning_rate) 203 | -------------------------------------------------------------------------------- /src/models/explainer_classifier_rtsal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from torch import nn 6 | from torch.optim import Adam 7 | from pathlib import Path 8 | 9 | from models.explainer import * 10 | from models.classifier import * 11 | from utils.helper import * 12 | from utils.rtsal_helper import generate_alternative_image 13 | from utils.image_utils import * 14 | from utils.rtsal_loss import RTSalTotalVariationConv, preserver_loss, destroyer_loss 15 | from utils.metrics import * 16 | 17 | class RTSalExplainerClassifierModel(pl.LightningModule): 18 | def __init__(self, num_classes=20, dataset="VOC", classifier_type="vgg16", classifier_checkpoint=None, 19 | fix_classifier=True, learning_rate=1e-5, metrics_threshold=-1.0, save_masked_images=False, 20 | save_masks=False, save_all_class_masks=False, save_path="./results/"): 21 | 22 | super().__init__() 23 | 24 | self.setup_explainer(num_classes=num_classes) 25 | self.setup_classifier(classifier_type=classifier_type, classifier_checkpoint=classifier_checkpoint, fix_classifier=fix_classifier, num_classes=num_classes) 26 | 27 | self.setup_losses() 28 | self.setup_metrics(num_classes=num_classes, metrics_threshold=metrics_threshold) 29 | 30 | self.dataset = dataset 31 | self.classifier_type = classifier_type 32 | 33 | # Hyperparameters 34 | self.learning_rate = learning_rate 35 | 36 | # Image display/save settings 37 | self.save_masked_images = save_masked_images 38 | 
self.save_masks = save_masks 39 | self.save_all_class_masks = save_all_class_masks 40 | self.save_path = save_path 41 | 42 | def setup_explainer(self, num_classes): 43 | self.explainer = Deeplabv3Resnet50ExplainerModel(num_classes=num_classes) 44 | 45 | def setup_classifier(self, classifier_type, classifier_checkpoint, fix_classifier, num_classes): 46 | if classifier_type == "vgg16": 47 | self.classifier = VGG16ClassifierModel(num_classes=num_classes) 48 | elif classifier_type == "resnet50": 49 | self.classifier = Resnet50ClassifierModel(num_classes=num_classes) 50 | else: 51 | raise Exception("Unknown classifier type " + classifier_type) 52 | 53 | if classifier_checkpoint is not None: 54 | self.classifier = self.classifier.load_from_checkpoint(classifier_checkpoint, num_classes=num_classes) 55 | if fix_classifier: 56 | self.classifier.freeze() 57 | 58 | def setup_losses(self): 59 | self.total_variation_conv = RTSalTotalVariationConv() 60 | 61 | def setup_metrics(self, num_classes, metrics_threshold): 62 | self.train_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 63 | self.valid_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 64 | self.test_metrics = MultiLabelMetrics(num_classes=num_classes, threshold=metrics_threshold) 65 | 66 | def forward(self, image, targets): 67 | segmentations = self.explainer(image) 68 | target_mask, _ = extract_masks(segmentations, targets) 69 | inversed_target_mask = torch.ones_like(target_mask) - target_mask 70 | 71 | alternative_image = generate_alternative_image(image) 72 | masked_image = target_mask.unsqueeze(1) * image + inversed_target_mask.unsqueeze(1) * alternative_image 73 | inversed_masked_image = inversed_target_mask.unsqueeze(1) * image + target_mask.unsqueeze(1) * alternative_image 74 | 75 | logits_mask = self.classifier(masked_image) 76 | logits_inversed_mask = self.classifier(inversed_masked_image) 77 | 78 | return logits_mask, logits_inversed_mask, target_mask 79 | 80 | def training_step(self, batch, batch_idx): 81 | image, annotations = batch 82 | targets = get_targets_from_annotations(annotations, dataset=self.dataset) 83 | logits_mask, logits_inversed_mask, target_mask = self(image, targets) 84 | 85 | tv_loss = self.total_variation_conv(target_mask) 86 | av_loss = target_mask.mean() 87 | pres_loss = preserver_loss(logits_mask, targets) 88 | destr_loss = destroyer_loss(logits_inversed_mask, targets) 89 | 90 | loss = 10*tv_loss + 1.0*av_loss + pres_loss + 5*torch.pow(destr_loss, 0.3) # lambda regularisers taken from Dabkowski et. al. 2017, second regulariser has been tuned from 0.001 to 1.0 91 | 92 | self.log('loss', loss) 93 | self.train_metrics(logits_mask, targets) 94 | 95 | return loss 96 | 97 | def training_epoch_end(self, outs): 98 | self.log('train_metrics', self.train_metrics.compute()) 99 | self.train_metrics.reset() 100 | 101 | def validation_step(self, batch, batch_idx): 102 | image, annotations = batch 103 | targets = get_targets_from_annotations(annotations, dataset=self.dataset) 104 | logits_mask, logits_inversed_mask, target_mask = self(image, targets) 105 | 106 | tv_loss = self.total_variation_conv(target_mask) 107 | av_loss = target_mask.mean() 108 | pres_loss = preserver_loss(logits_mask, targets) 109 | destr_loss = destroyer_loss(logits_inversed_mask, targets) 110 | 111 | loss = 10*tv_loss + 1.0*av_loss + pres_loss + 5*torch.pow(destr_loss, 0.3) # lambda regularisers taken from Dabkowski et. al. 
2017, second regularizer has been tuned from 0.001 to 1.0 112 | 113 | self.log('val_loss', loss) 114 | self.valid_metrics(logits_mask, targets) 115 | 116 | def validation_epoch_end(self, outs): 117 | self.log('val_metrics', self.valid_metrics.compute(), prog_bar=True) 118 | self.valid_metrics.reset() 119 | 120 | def test_step(self, batch, batch_idx): 121 | image, annotations = batch 122 | targets = get_targets_from_annotations(annotations, dataset=self.dataset) 123 | logits_mask, logits_inversed_mask, target_mask = self(image, targets) 124 | 125 | if self.save_masked_images and image.size()[0] == 1: 126 | filename = Path(self.save_path) / "masked_images" / get_filename_from_annotations(annotations, dataset=self.dataset) 127 | save_masked_image(image, target_mask, filename) 128 | 129 | if self.save_masks and image.size()[0] == 1: 130 | filename = get_filename_from_annotations(annotations, dataset=self.dataset) 131 | 132 | ### Commented code was used to produce exactly those masks that correspond to segmentation groundtruths ### 133 | #if self.dataset == "VOC": 134 | # segmentation_filename = '/scratch/snx3000/sstalder/VOCData/VOCdevkit/VOC2007/SegmentationClass/' + os.path.splitext(filename)[0] + '.png' 135 | #elif self.dataset == "COCO": 136 | # segmentation_filename = './benchmark/coco_segmentations/' + filename 137 | 138 | #if not os.path.exists(segmentation_filename): 139 | # return 140 | 141 | save_mask(target_mask, Path(self.save_path) / "masks" / filename) 142 | 143 | tv_loss = self.total_variation_conv(target_mask) 144 | av_loss = target_mask.mean() 145 | pres_loss = preserver_loss(logits_mask, targets) 146 | destr_loss = destroyer_loss(logits_inversed_mask, targets) 147 | 148 | loss = 10*tv_loss + 1.0*av_loss + pres_loss + 5*torch.pow(destr_loss, 0.3) # lambda regularisers taken from Dabkowski et. al. 
2017, second regularizer has been tuned from 0.001 to 1.0 149 | 150 | self.log('test_loss', loss) 151 | self.test_metrics(logits_mask, targets) 152 | 153 | def test_epoch_end(self, outs): 154 | self.log('test_metrics', self.test_metrics.compute(), prog_bar=True) 155 | self.test_metrics.save(model="rtsal_explainer", classifier_type=self.classifier_type, dataset=self.dataset) 156 | 157 | self.test_metrics.reset() 158 | 159 | def configure_optimizers(self): 160 | return Adam(self.parameters(), lr=self.learning_rate) 161 | -------------------------------------------------------------------------------- /src/utils/argparser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from configargparse import ArgumentParser 3 | 4 | # This file contains the declaration of our argument parser 5 | 6 | # Needed to parse booleans from command line properly 7 | def str2bool(v): 8 | if isinstance(v, bool): 9 | return v 10 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 11 | return True 12 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 13 | return False 14 | else: 15 | raise argparse.ArgumentTypeError('Boolean value expected.') 16 | 17 | def get_parser(): 18 | parser = ArgumentParser(description='NN Explainer', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add('-c', '--config', is_config_file=True, help='config file path') 20 | parser.add('--arg_log', default=False, type=str2bool, help='save arguments to config file') 21 | 22 | # Dataset parameters 23 | parser.add_argument('--dataset', choices=['VOC', 'COCO', 'CUB'], default='VOC', type=str, help='which dataset to use') 24 | parser.add_argument('--data_base_path', default='../datasets/', type=str, help='Base path of the datasets. Should contain subdirectories with the different datasets.') 25 | 26 | # Data processing parameters 27 | parser.add_argument('--train_batch_size', default=16, type=int, help='batch size used for training') 28 | parser.add_argument('--val_batch_size', default=16, type=int, help='batch size used for validation') 29 | parser.add_argument('--test_batch_size', default=16, type=int, help='batch size used for testing') 30 | parser.add_argument('--use_data_augmentation', default=False, type=str2bool, help='set to true to enable data augmentation on training images') 31 | 32 | # Trainer Parameters 33 | parser.add_argument('--seed', default=42, type=int, help='seed for all random number generators in pytorch, numpy, and python.random') 34 | parser.add_argument('--use_tensorboard_logger', default=False, type=str2bool, help='whether to use tensorboard') 35 | parser.add_argument('--checkpoint_callback', default=True, type=str2bool, help='if true, trained model will be automatically saved') 36 | 37 | # Early stopping Parameters 38 | parser.add_argument('--early_stop_min_delta', default=0.001, type=float, help='threshold for early stopping condition') 39 | parser.add_argument('--early_stop_patience', default=5, type=int, help='patience for early stopping to trigger') 40 | 41 | # General Model Parameters 42 | parser.add_argument('--train_model', default=True, type=str2bool, help='If True, specified model will be trained.
If False, model will be tested.') 43 | parser.add_argument('--use_imagenet_pretraining', default=True, type=str2bool, help='If True, classifiers use a pretrained backbone from ImageNet pretraining') 44 | parser.add_argument('--fix_classifier_backbone', default=True, type=str2bool, help='Whether to fix the weights of the classifier backbone') 45 | parser.add_argument('--fix_classifier', default=True, type=str2bool, help='If True, classifier is frozen. Strongly recommended for Explainer training.') 46 | parser.add_argument('--model_to_train', choices=['explainer', 'classifier', 'fcnn', 'rtsal_explainer'], default='explainer', type=str, help='which model architecture should be used for training or testing') 47 | parser.add_argument('--classifier_type', choices=['vgg16', 'resnet50'], default='vgg16', type=str, help='type of classifier architecture to use') 48 | parser.add_argument('--explainer_classifier_checkpoint', default=None, type=str, help='Path to the .ckpt file that contains the weights of a pretrained explainer. Also contains the weights for the associated classifier.') 49 | parser.add_argument('--classifier_checkpoint', default=None, type=str, help='Path to the .ckpt file that contains the weights of a pretrained classifier.') 50 | parser.add_argument('--fcnn_checkpoint', default=None, type=str, help='Path to the .ckpt file that contains the weights of a pretrained self-explainer.') 51 | 52 | # Model-specific parameters 53 | parser.add_argument('--learning_rate', default=1e-5, type=float, help='learning rate used by the Adam optimizer') 54 | parser.add_argument('--use_mask_variation_loss', default=True, type=str2bool, help='whether to use variation loss on the mask.') 55 | parser.add_argument('--use_mask_area_loss', default=True, type=str2bool, help='whether to use area loss on the mask.') 56 | parser.add_argument('--use_mask_coherency_loss', default=True, type=str2bool, help='whether to use mask coherency loss (only for self-explainer architecture)') 57 | parser.add_argument('--entropy_regularizer', default=1.0, type=float, help='loss weighting term for entropy loss') 58 | parser.add_argument('--mask_variation_regularizer', default=1.0, type=float, help='loss weighting term for mask variation loss') 59 | parser.add_argument('--mask_area_constraint_regularizer', default=1.0, type=float, help='loss weighting term for overall mask area constraint (currently not used!)') 60 | parser.add_argument('--mask_total_area_regularizer', default=0.1, type=float, help='loss weighting term for the total area loss') 61 | parser.add_argument('--ncmask_total_area_regularizer', default=0.3, type=float, help='loss weighting term for the area constraints for the individual class segmentation masks') 62 | 63 | parser.add_argument('--target_mask_min_area', default=0.05, type=float, help='minimum area for the overall mask area constraint (currently not used!)') 64 | parser.add_argument('--target_mask_max_area', default=0.5, type=float, help='maximum area for the overall mask area constraint (currently not used!)') 65 | parser.add_argument('--class_mask_min_area', default=0.05, type=float, help='minimum area for the area constraints for the individual class segmentation masks') 66 | parser.add_argument('--class_mask_max_area', default=0.3, type=float, help='maximum area for the area constraints for the individual class segmentation masks') 67 | 68 | # Image display parameters 69 | parser.add_argument('--show_images', default=False, type=str2bool, help='If true, displays images and corresponding masked images
during testing. Requires testing batch size to be 1.') 70 | parser.add_argument('--show_all_class_masks', default=False, type=str2bool, help='If true, displays individual class masks during testing. Requires VOC dataset. Requires testing batch size to be 1.') 71 | parser.add_argument('--show_max_activation_for_class_id', default=None, type=int, help='If true, highlights point of maximum activation for given class id. Requires testing batch size to be 1.') 72 | parser.add_argument('--save_masks', default=False, type=str2bool, help='If true, masks are saved to location specified by save_path (see below)') 73 | parser.add_argument('--save_masked_images', default=False, type=str2bool, help='If true, masked images are saved to location specified by save_path (see below)') 74 | parser.add_argument('--save_all_class_masks', default=False, type=str2bool, help='Unused.') 75 | parser.add_argument('--save_path', default='./results/', type=str, help='Path to where masks and/or masked images are saved if corresponding options are set to true.') 76 | 77 | # Metrics parameters 78 | parser.add_argument('--metrics_threshold', default=-1.0, type=float, help='Threshold for logit to count as positive vs. negative prediction. Use -1.0 for Explainer and 0.0 for classifier.') 79 | 80 | return parser 81 | 82 | def write_config_file(args, path='config.cfg'): 83 | with open(path, 'w') as f: 84 | for k in sorted(args.__dict__): 85 | if args.__dict__[k] is not None: 86 | print(k, '=', args.__dict__[k], file=f) 87 | -------------------------------------------------------------------------------- /src/utils/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_targets_from_annotations(annotations, dataset, include_background_class=False): 4 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 5 | 6 | if dataset == "VOC": 7 | target_dict = get_target_dictionary(include_background_class) 8 | objects = [item['annotation']['object'] for item in annotations] 9 | 10 | batch_size = len(objects) 11 | target_vectors = torch.full((batch_size, 20), fill_value=0.0, device=device) 12 | for i in range(batch_size): 13 | object_names = [item['name'] for item in objects[i]] 14 | 15 | for name in object_names: 16 | index = target_dict[name] 17 | target_vectors[i][index] = 1.0 18 | 19 | elif dataset == "COCO": 20 | batch_size = len(annotations) 21 | target_vectors = torch.full((batch_size, 91), fill_value=0.0, device=device) 22 | for i in range(batch_size): 23 | targets = annotations[i]['targets'] 24 | for target in targets: 25 | target_vectors[i][target] = 1.0 26 | 27 | elif dataset == "CUB": 28 | batch_size = len(annotations) 29 | target_vectors = torch.full((batch_size, 200), fill_value=0.0, device=device) 30 | for i in range(batch_size): 31 | target = annotations[i]['target'] 32 | target_vectors[i][target] = 1.0 33 | 34 | return target_vectors 35 | 36 | # Only returns 1 filename, not an array of filenames 37 | # Only used with batch size 1 38 | def get_filename_from_annotations(annotations, dataset): 39 | if dataset == "VOC": 40 | filename = annotations[0]['annotation']['filename'] 41 | 42 | elif dataset == "COCO": 43 | filename = annotations[0]['filename'] 44 | 45 | elif dataset == "CUB": 46 | filename = annotations[0]['filename'] 47 | 48 | else: 49 | raise Exception("Unknown dataset: " + dataset) 50 | 51 | return filename 52 | 53 | def get_target_dictionary(include_background_class): 54 | if include_background_class: 55 | target_dict =
{'background' : 0, 'aeroplane' : 1, 'bicycle' : 2, 'bird' : 3, 'boat' : 4, 'bottle' : 5, 'bus' : 6, 'car' : 7, 56 | 'cat' : 8, 'chair' : 9, 'cow' : 10, 'diningtable' : 11, 'dog' : 12, 'horse' : 13, 'motorbike' : 14, 'person' : 15, 57 | 'pottedplant' : 16, 'sheep' : 17, 'sofa' : 18, 'train' : 19, 'tvmonitor' : 20} 58 | else: 59 | target_dict = {'aeroplane' : 0, 'bicycle' : 1, 'bird' : 2, 'boat' : 3, 'bottle' : 4, 'bus' : 5, 'car' : 6, 60 | 'cat' : 7, 'chair' : 8, 'cow' : 9, 'diningtable' : 10, 'dog' : 11, 'horse' : 12, 'motorbike' : 13, 'person' : 14, 61 | 'pottedplant' : 15, 'sheep' : 16, 'sofa' : 17, 'train' : 18, 'tvmonitor' : 19} 62 | 63 | return target_dict 64 | 65 | def extract_masks(segmentations, target_vectors): 66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 67 | 68 | batch_size, num_classes, h, w = segmentations.size() 69 | 70 | target_masks = torch.empty(batch_size, h, w, device=device) 71 | non_target_masks = torch.empty(batch_size, h, w, device=device) 72 | for i in range(batch_size): 73 | class_indices = target_vectors[i].eq(1.0) 74 | non_class_indices = target_vectors[i].eq(0.0) 75 | 76 | target_masks[i] = (segmentations[i][class_indices]).amax(dim=0) 77 | 78 | non_target_masks[i] = (segmentations[i][non_class_indices]).amax(dim=0) 79 | 80 | return target_masks.sigmoid(), non_target_masks.sigmoid() 81 | 82 | -------------------------------------------------------------------------------- /src/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import torchvision.transforms as T 5 | 6 | import matplotlib.pyplot as plt 7 | import matplotlib as mpl 8 | 9 | from PIL import Image 10 | 11 | def show_max_activation(image, segmentations, class_id): 12 | nat_image = get_unnormalized_image(image) 13 | 14 | mask = segmentations[0][class_id].numpy() 15 | max_pixel_coords = np.unravel_index(mask.argmax(), mask.shape) 16 | 17 | circle = plt.Circle(max_pixel_coords[::-1], 10, fill=False, color='red') 18 | fig, ax = plt.subplots(1) 19 | ax.imshow(np.stack(nat_image.squeeze(), axis=2)) 20 | ax.add_patch(circle) 21 | 22 | plt.show() 23 | 24 | def save_mask(mask, filename): 25 | os.makedirs(os.path.dirname(filename), exist_ok=True) 26 | path_file = os.path.splitext(filename)[0] 27 | 28 | img = mask.detach().cpu().numpy().squeeze() 29 | plt.imsave(path_file + ".png", img, cmap='gray',format="png") 30 | np.savez_compressed(path_file + ".npz", img) 31 | 32 | def save_masked_image(image, mask, filename): 33 | os.makedirs(os.path.dirname(filename), exist_ok=True) 34 | path_file = os.path.splitext(filename)[0] 35 | 36 | nat_image = get_unnormalized_image(image) 37 | masked_nat_im = get_masked_image(nat_image, mask) 38 | 39 | plt.imsave(path_file + ".png", np.stack(masked_nat_im.detach().cpu().squeeze(), axis=2), format="png") 40 | 41 | def show_image_and_masked_image(image, mask): 42 | nat_image = get_unnormalized_image(image) 43 | masked_nat_im = get_masked_image(nat_image, mask) 44 | 45 | fig = get_fullscreen_figure_canvas("Image and masked image") 46 | fig.add_subplot(1, 2, 1) 47 | show_image(nat_image) 48 | 49 | fig.add_subplot(1, 2, 2) 50 | show_image(masked_nat_im) 51 | 52 | plt.show() 53 | 54 | def save_all_class_masks(image, segmentations, filename): 55 | os.makedirs(os.path.dirname(filename), exist_ok=True) 56 | filename = os.path.splitext(filename)[0] 57 | 58 | nat_image = get_unnormalized_image(image) 59 | all_class_masks = segmentations.transpose(0, 
1).sigmoid() 60 | 61 | fig = get_fullscreen_figure_canvas("All class masks") 62 | for i in range(all_class_masks.size()[0]): #loop over all classes 63 | masked_nat_im = get_masked_image(nat_image, all_class_masks[i]) 64 | add_subplot_with_class_mask(fig, i) 65 | show_image(masked_nat_im) 66 | 67 | img_buf = io.BytesIO() 68 | plt.savefig(img_buf, format='png') 69 | 70 | im = Image.open(img_buf) 71 | im.save(filename, format='png') 72 | 73 | img_buf.close() 74 | 75 | def show_target_class_masks(image, segmentations, targets): 76 | nat_image = get_unnormalized_image(image) 77 | all_class_masks = segmentations.transpose(0, 1).sigmoid() 78 | 79 | fig = get_fullscreen_figure_canvas("Target class masks") 80 | for i in range(all_class_masks.size()[0]): #loop over all classes 81 | if targets[0][i] == 1.0: 82 | masked_nat_im = get_masked_image(nat_image, all_class_masks[i]) 83 | add_subplot_with_class_mask(fig, i) 84 | show_image(masked_nat_im) 85 | 86 | def show_most_likely_class_masks(image, segmentations, logits, threshold=0.0): 87 | nat_image = get_unnormalized_image(image) 88 | all_class_masks = segmentations.transpose(0, 1).sigmoid() 89 | 90 | fig = get_fullscreen_figure_canvas("Predicted class masks") 91 | for i in range(all_class_masks.size()[0]): #loop over all classes 92 | if logits[0][i] >= threshold: 93 | masked_nat_im = get_masked_image(nat_image, all_class_masks[i]) 94 | add_subplot_with_class_mask(fig, i) 95 | show_image(masked_nat_im) 96 | 97 | def get_unnormalized_image(image): 98 | inverse_transform = T.Compose([T.Normalize(mean = [ 0., 0., 0. ], std = [ 1/0.229, 1/0.224, 1/0.225 ]), 99 | T.Normalize(mean = [ -0.485, -0.456, -0.406 ], std = [ 1., 1., 1. ])]) 100 | 101 | nat_image = inverse_transform(image) 102 | 103 | return nat_image 104 | 105 | def get_masked_image(image, mask): 106 | masked_image = mask.unsqueeze(1) * image 107 | 108 | return masked_image 109 | 110 | def get_fullscreen_figure_canvas(title): 111 | mpl.rcParams["figure.figsize"] = (40,40) 112 | fig = plt.figure() 113 | fig.suptitle(title) 114 | 115 | return fig 116 | 117 | def add_subplot_with_class_mask(fig, class_id): 118 | target_labels = get_target_labels(include_background_class=False) 119 | 120 | axis = fig.add_subplot(4, 5, class_id+1) 121 | axis.get_xaxis().set_visible(False) 122 | axis.get_yaxis().set_visible(False) 123 | axis.title.set_text(target_labels[class_id]) 124 | 125 | def show_image(image): 126 | plt.imshow(np.stack(image.squeeze(), axis=2)) 127 | 128 | def get_target_labels(include_background_class): 129 | if include_background_class: 130 | targets = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 131 | 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 132 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 133 | else: 134 | targets = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 135 | 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 136 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 137 | 138 | return targets 139 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from torch import nn 6 | 7 | class TotalVariationConv(pl.LightningModule): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | weights_right_variance = torch.tensor([[0.0, 0.0, 0.0], 12 | 
[0.0, 1.0, -1.0], 13 | [0.0, 0.0, 0.0]], device=self.device).view(1, 1, 3, 3) 14 | 15 | weights_down_variance = torch.tensor([[0.0, 0.0, 0.0], 16 | [0.0, 1.0, 0.0], 17 | [0.0, -1.0, 0.0]], device=self.device).view(1, 1, 3, 3) 18 | 19 | self.variance_right_filter = nn.Conv2d(in_channels=1, out_channels=1, 20 | kernel_size=3, padding=1, padding_mode='reflect', groups=1, bias=False) 21 | self.variance_right_filter.weight.data = weights_right_variance 22 | self.variance_right_filter.weight.requires_grad = False 23 | 24 | self.variance_down_filter = nn.Conv2d(in_channels=1, out_channels=1, 25 | kernel_size=3, padding=1, padding_mode='reflect', groups=1, bias=False) 26 | self.variance_down_filter.weight.data = weights_down_variance 27 | self.variance_down_filter.weight.requires_grad = False 28 | 29 | def forward(self, mask): 30 | variance_right = self.variance_right_filter(mask.unsqueeze(1)).abs() 31 | 32 | variance_down = self.variance_down_filter(mask.unsqueeze(1)).abs() 33 | 34 | total_variance = (variance_right + variance_down).mean() 35 | return total_variance 36 | 37 | class MaskAreaLoss(): 38 | def __init__(self, image_size=224, min_area=0.0, max_area=1.0): 39 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 40 | 41 | self.image_size = image_size 42 | self.min_area = min_area 43 | self.max_area = max_area 44 | 45 | assert(self.min_area >= 0.0 and self.min_area <= 1.0) 46 | assert(self.max_area >= 0.0 and self.max_area <= 1.0) 47 | assert(self.min_area <= self.max_area) 48 | 49 | def __call__(self, masks): 50 | batch_size = masks.size()[0] 51 | losses = torch.zeros(batch_size, device=self.device) 52 | for i in range(batch_size): 53 | mask = masks[i].flatten() 54 | sorted_mask, indices = mask.sort(descending=True) 55 | losses[i] += (self._min_mask_area_loss(sorted_mask) + self._max_mask_area_loss(sorted_mask)).mean() 56 | 57 | return losses.mean() 58 | 59 | def _min_mask_area_loss(self, sorted_mask): 60 | if (self.min_area == 0.0): 61 | return torch.tensor(0.0) 62 | 63 | ones_length = (int)(self.image_size * self.image_size * self.min_area) 64 | ones = torch.ones(ones_length, device=self.device) 65 | zeros = torch.zeros((self.image_size * self.image_size) - ones_length, device=self.device) 66 | ones_and_zeros = torch.cat((ones, zeros), dim=0) 67 | 68 | # [1, 1, 0, 0, 0] - [0.9, 0.9, 0.9, 0.5, 0.1] = [0.1, 0.1, -0.9, -0.5, -0.1] -> [0.1, 0.1, 0, 0, 0] 69 | loss = F.relu(ones_and_zeros - sorted_mask) 70 | 71 | return loss 72 | 73 | def _max_mask_area_loss(self, sorted_mask): 74 | if (self.max_area == 1.0): 75 | return torch.tensor(0.0) 76 | 77 | ones_length = (int)(self.image_size * self.image_size * self.max_area) 78 | ones = torch.ones(ones_length, device=self.device) 79 | zeros = torch.zeros((self.image_size * self.image_size) - ones_length, device=self.device) 80 | ones_and_zeros = torch.cat((ones, zeros), dim=0) 81 | 82 | # [0.9, 0.9, 0.9, 0.5, 0.1] - [1, 1, 1, 1, 0] = [-0.1, -0.1, -0.1, -0.5, 0.1] -> [0, 0, 0, 0, 0.1] 83 | loss = F.relu(sorted_mask - ones_and_zeros) 84 | 85 | return loss 86 | 87 | class ClassMaskAreaLoss(MaskAreaLoss): 88 | def __call__(self, segmentations, target_vectors): 89 | masks = segmentations.sigmoid() 90 | batch_size, num_classes, h, w = masks.size() 91 | 92 | losses = torch.zeros(batch_size, device=self.device) 93 | for i in range(batch_size): 94 | class_indices = target_vectors[i].eq(1.0) 95 | class_masks = masks[i][class_indices] 96 | for j in range(class_masks.size()[0]): 97 | mask = class_masks[j].flatten() 98 | sorted_mask, 
indices = mask.sort(descending=True) 99 | losses[i] += (self._min_mask_area_loss(sorted_mask) + self._max_mask_area_loss(sorted_mask)).mean() 100 | 101 | losses[i] = losses[i].mean() 102 | 103 | return losses.mean() 104 | 105 | def entropy_loss(logits): 106 | min_prob = 1e-16 107 | probs = F.softmax(logits, dim=-1).clamp(min=min_prob) 108 | log_probs = probs.log() 109 | entropy = (-probs * log_probs) 110 | entropy_loss = -entropy.mean() 111 | 112 | return entropy_loss 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | 4 | class SingleLabelMetrics(torchmetrics.Metric): 5 | def __init__(self, num_classes): 6 | super().__init__() 7 | 8 | self.num_classes = num_classes 9 | 10 | self.add_state("true_positives", torch.zeros(num_classes)) 11 | self.add_state("false_positives", torch.zeros(num_classes)) 12 | self.add_state("true_negatives", torch.zeros(num_classes)) 13 | self.add_state("false_negatives", torch.zeros(num_classes)) 14 | 15 | def update(self, logits, labels): 16 | with torch.no_grad(): 17 | for i, batch_sample_logits in enumerate(logits): 18 | self.true_negatives += 1.0 19 | top_class_prediction = batch_sample_logits.argmax(-1) 20 | if labels[i] == top_class_prediction: 21 | self.true_positives[labels] += 1.0 22 | self.true_negatives[labels] -= 1.0 23 | else: 24 | self.false_negatives[labels] += 1.0 25 | self.false_positives[top_class_prediction] += 1.0 26 | self.true_negatives[labels] -= 1.0 27 | self.true_negatives[top_class_prediction] -= 1.0 28 | 29 | def compute(self): 30 | accuracy = ((self.true_positives + self.true_negatives) / (self.true_positives + self.true_negatives + self.false_positives + self.false_negatives)).mean() 31 | precision = (self.true_positives / (self.true_positives + self.false_positives)).mean() 32 | recall = (self.true_positives / (self.true_positives + self.false_negatives)).mean() 33 | f_score = ((2 * self.true_positives) / (2 * self.true_positives + self.false_positives + self.false_negatives)).mean() 34 | 35 | return {'Accuracy': accuracy.item(), 'Precision': precision.item(), 'Recall': recall.item(), 'F-Score': f_score.item()} 36 | 37 | class MultiLabelMetrics(torchmetrics.Metric): 38 | def __init__(self, num_classes, threshold): 39 | super().__init__() 40 | 41 | self.num_classes = num_classes 42 | self.threshold = threshold 43 | 44 | self.add_state("true_positives", torch.tensor(0.0)) 45 | self.add_state("false_positives", torch.tensor(0.0)) 46 | self.add_state("true_negatives", torch.tensor(0.0)) 47 | self.add_state("false_negatives", torch.tensor(0.0)) 48 | 49 | def update(self, logits, labels): 50 | with torch.no_grad(): 51 | for i, batch_sample_logits in enumerate(logits): 52 | for j in range(self.num_classes): 53 | if labels[i][j] == 1.0: 54 | if batch_sample_logits[j] >= self.threshold: 55 | self.true_positives += 1.0 56 | else: 57 | self.false_negatives += 1.0 58 | else: 59 | if batch_sample_logits[j] >= self.threshold: 60 | self.false_positives += 1.0 61 | else: 62 | self.true_negatives += 1.0 63 | 64 | def compute(self): 65 | self.accuracy = ((self.true_positives + self.true_negatives) / (self.true_positives + self.true_negatives + self.false_positives + self.false_negatives)) 66 | self.precision = (self.true_positives / (self.true_positives + self.false_positives)) 67 | self.recall = (self.true_positives / (self.true_positives + 
self.false_negatives)) 68 | self.f_score = ((2 * self.true_positives) / (2 * self.true_positives + self.false_positives + self.false_negatives)) 69 | 70 | return {'Accuracy': self.accuracy.item(), 'Precision': self.precision.item(), 'Recall': self.recall.item(), 'F-Score': self.f_score.item()} 71 | 72 | def save(self, model, classifier_type, dataset): 73 | f = open(model + "_" + classifier_type + "_" + dataset + "_" + "test_metrics.txt", "w") 74 | f.write("Accuracy: " + str(self.accuracy.item()) + "\n") 75 | f.write("Precision: " + str(self.precision.item()) + "\n") 76 | f.write("Recall: " + str(self.recall.item()) + "\n") 77 | f.write("F-Score: " + str(self.f_score.item())) 78 | f.close() 79 | 80 | ### BELOW ARE JUST UTILITY FUNCTIONS, NOT THE ONES USED FOR THE RESULTS IN THE PAPER/THESIS ### 81 | 82 | class MultiLabelTopPredictionAccuracy(torchmetrics.Metric): 83 | def __init__(self): 84 | super().__init__() 85 | 86 | self.add_state("correct", torch.tensor(0.0)) 87 | self.add_state("total", torch.tensor(0.0)) 88 | 89 | def update(self, logits, targets): 90 | with torch.no_grad(): 91 | for i, batch_sample_logits in enumerate(logits): 92 | self.total += 1.0 93 | top_class_prediction = batch_sample_logits.argmax(-1) 94 | if (targets[i][top_class_prediction] == 1.0): 95 | self.correct += 1.0 96 | 97 | def compute(self): 98 | return {'Top prediction accuracy': self.correct / self.total} 99 | 100 | class MultiLabelPrecision(torchmetrics.Metric): 101 | def __init__(self, num_classes, threshold): 102 | super().__init__() 103 | 104 | self.num_classes = num_classes 105 | self.threshold = threshold 106 | 107 | self.add_state("true_positives", torch.tensor(0.0)) 108 | self.add_state("false_positives", torch.tensor(0.0)) 109 | 110 | def update(self, logits, targets): 111 | with torch.no_grad(): 112 | for i, batch_sample_logits in enumerate(logits): 113 | for j in range(self.num_classes): 114 | if (batch_sample_logits[j] >= self.threshold): 115 | if (targets[i][j] == 1.0): 116 | self.true_positives += 1.0 117 | else: 118 | self.false_positives += 1.0 119 | 120 | def compute(self): 121 | return self.true_positives / (self.true_positives + self.false_positives) 122 | 123 | class MultiLabelRecall(torchmetrics.Metric): 124 | def __init__(self, num_classes, threshold): 125 | super().__init__() 126 | 127 | self.num_classes = num_classes 128 | self.threshold = threshold 129 | 130 | self.add_state("true_positives", torch.tensor(0.0)) 131 | self.add_state("false_negatives", torch.tensor(0.0)) 132 | 133 | def update(self, logits, targets): 134 | with torch.no_grad(): 135 | for i, batch_sample_logits in enumerate(logits): 136 | for j in range(self.num_classes): 137 | if (targets[i][j] == 1.0): 138 | if (batch_sample_logits[j] >= self.threshold): 139 | self.true_positives += 1.0 140 | else: 141 | self.false_negatives += 1.0 142 | 143 | -------------------------------------------------------------------------------- /src/utils/rtsal_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | import torch.nn.functional as F 5 | 6 | ### Values taken from https://github.com/PiotrDabkowski/pytorch-saliency/blob/master/sal/utils/mask.py ### 7 | def generate_alternative_image(image): 8 | cuda = torch.cuda.is_available() 9 | 10 | options = ["blur", "color"] 11 | method = random.choice(options) 12 | 13 | if method == "blur": 14 | alternative_image = imsmooth(image, sigma=10) 15 | elif method == "color": 16 | n, c, _, _ = image.size() 17 
| color_range = 0.66 18 | noise = torch.zeros_like(image).normal_(0.11) 19 | if cuda: 20 | alternative_image = noise + torch.Tensor(n, c, 1, 1).cuda().uniform_(-color_range/2., color_range/2.) 21 | else: 22 | alternative_image = noise + torch.Tensor(n, c, 1, 1).uniform_(-color_range/2., color_range/2.) 23 | else: 24 | raise Exception("Unknown option for generating alternative image") 25 | 26 | return alternative_image 27 | 28 | ### Taken from https://github.com/facebookresearch/TorchRay/blob/6a198ee61d229360a3def590410378d2ed6f1f06/torchray/utils.py ### 29 | def imsmooth(tensor, 30 | sigma, 31 | stride=1, 32 | padding=0, 33 | padding_mode='constant', 34 | padding_value=0): 35 | r"""Apply a 2D Gaussian filter to a tensor. 36 | The 2D filter itself is implementing by separating the 2D convolution in 37 | two 1D convolutions, first along the vertical direction and then along 38 | the horizontal one. Each 1D Gaussian kernel is given by: 39 | .. math:: 40 | f_i \propto \exp\left(-\frac{1}{2} \frac{i^2}{\sigma^2} \right), 41 | ~~~ i \in \{-W,\dots,W\}, 42 | ~~~ W = \lceil 4\sigma \rceil. 43 | This kernel is normalized to sum to one exactly. Given the latter, the 44 | function calls `torch.nn.functional.conv2d` 45 | to perform the actual convolution. Various padding parameters and the 46 | stride are passed to the latter. 47 | Args: 48 | tensor (:class:`torch.Tensor`): :math:`N\times C\times H\times W` 49 | image tensor. 50 | sigma (float): standard deviation of the Gaussian kernel. 51 | stride (int, optional): subsampling factor. Default: ``1``. 52 | padding (int, optional): extra padding. Default: ``0``. 53 | padding_mode (str, optional): ``'constant'``, ``'reflect'`` or 54 | ``'replicate'``. Default: ``'constant'``. 55 | padding_value (float, optional): constant value for the `constant` 56 | padding mode. Default: ``0``. 57 | Returns: 58 | :class:`torch.Tensor`: :math:`N\times C\times H\times W` tensor with 59 | the smoothed images. 
60 | """ 61 | 62 | EPSILON_DOUBLE = torch.tensor(2.220446049250313e-16, dtype=torch.float64) 63 | EPSILON_SINGLE = torch.tensor(1.19209290E-07, dtype=torch.float32) 64 | SQRT_TWO_DOUBLE = torch.tensor(math.sqrt(2), dtype=torch.float32) 65 | SQRT_TWO_SINGLE = SQRT_TWO_DOUBLE.to(torch.float32) 66 | 67 | assert sigma >= 0 68 | width = math.ceil(4 * sigma) 69 | filt = (torch.arange(-width, 70 | width + 1, 71 | dtype=torch.float32, 72 | device=tensor.device) / 73 | (SQRT_TWO_SINGLE * sigma + EPSILON_SINGLE)) 74 | filt = torch.exp(-filt * filt) 75 | filt /= torch.sum(filt) 76 | num_channels = tensor.shape[1] 77 | width = width + padding 78 | if padding_mode == 'constant' and padding_value == 0: 79 | other_padding = width 80 | x = tensor 81 | else: 82 | # pad: (before, after) pairs starting from last dimension backward 83 | x = F.pad(tensor, 84 | (width, width, width, width), 85 | mode=padding_mode, 86 | value=padding_value) 87 | other_padding = 0 88 | padding = 0 89 | x = F.conv2d(x, 90 | filt.reshape((1, 1, -1, 1)).expand(num_channels, -1, -1, -1), 91 | padding=(other_padding, padding), 92 | stride=(stride, 1), 93 | groups=num_channels) 94 | x = F.conv2d(x, 95 | filt.reshape((1, 1, 1, -1)).expand(num_channels, -1, -1, -1), 96 | padding=(padding, other_padding), 97 | stride=(1, stride), 98 | groups=num_channels) 99 | return x 100 | -------------------------------------------------------------------------------- /src/utils/rtsal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from torch import nn 6 | 7 | class RTSalTotalVariationConv(pl.LightningModule): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | weights_right_variance = torch.tensor([[0.0, 0.0, 0.0], 12 | [0.0, 1.0, -1.0], 13 | [0.0, 0.0, 0.0]], device=self.device).view(1, 1, 3, 3) 14 | 15 | weights_down_variance = torch.tensor([[0.0, 0.0, 0.0], 16 | [0.0, 1.0, 0.0], 17 | [0.0, -1.0, 0.0]], device=self.device).view(1, 1, 3, 3) 18 | 19 | self.variance_right_filter = nn.Conv2d(in_channels=1, out_channels=1, 20 | kernel_size=3, padding=1, padding_mode='reflect', groups=1, bias=False) 21 | self.variance_right_filter.weight.data = weights_right_variance 22 | self.variance_right_filter.weight.requires_grad = False 23 | 24 | self.variance_down_filter = nn.Conv2d(in_channels=1, out_channels=1, 25 | kernel_size=3, padding=1, padding_mode='reflect', groups=1, bias=False) 26 | self.variance_down_filter.weight.data = weights_down_variance 27 | self.variance_down_filter.weight.requires_grad = False 28 | 29 | def forward(self, mask): 30 | variance_right = self.variance_right_filter(mask.unsqueeze(1)).square() 31 | 32 | variance_down = self.variance_down_filter(mask.unsqueeze(1)).square() 33 | 34 | total_variance = (variance_right + variance_down).mean() 35 | return total_variance 36 | 37 | ### Adaption from the preserver loss in Dabkowski et. al. 2017 to account for multi-class target labels ### 38 | def preserver_loss(logits, targets): 39 | probs = logits.sigmoid() 40 | 41 | batch_size, num_classes = probs.size() 42 | num_object_classes_in_batch = 0 43 | loss = 0.0 44 | for i in range(batch_size): 45 | for j in range(num_classes): 46 | if targets[i][j] == 1.0: 47 | num_object_classes_in_batch += 1 48 | loss -= torch.log(probs[i][j]) 49 | 50 | loss = loss / num_object_classes_in_batch 51 | 52 | return loss 53 | 54 | ### Adaption from the destroyer loss in Dabkowski et. al. 
2017 to account for multi-class target labels ### 55 | def destroyer_loss(logits, targets): 56 | probs = logits.sigmoid() 57 | 58 | batch_size, num_classes = probs.size() 59 | num_object_classes_in_batch = 0 60 | loss = 0.0 61 | for i in range(batch_size): 62 | for j in range(num_classes): 63 | if targets[i][j] == 1.0: 64 | num_object_classes_in_batch += 1 65 | loss += probs[i][j] 66 | 67 | loss = loss / num_object_classes_in_batch 68 | 69 | return loss --------------------------------------------------------------------------------
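
A minimal usage sketch of the RT-Saliency loss pieces defined in src/utils/rtsal_loss.py (RTSalTotalVariationConv, preserver_loss, destroyer_loss), showing how they combine into the objective computed in the rtsal explainer's training_step. This sketch is not part of the repository: the dummy tensor shapes, the dummy targets, and the import path (it assumes the working directory is src/ so that utils.rtsal_loss resolves) are illustrative assumptions; only the weighting 10*tv + 1.0*area + preserver + 5*destroyer**0.3 mirrors the training code shown above.

import torch
# Assumed import path: run from src/ so that utils.rtsal_loss is importable.
from utils.rtsal_loss import RTSalTotalVariationConv, preserver_loss, destroyer_loss

# Illustrative stand-ins for the explainer/classifier outputs (batch of 4 images, 20 VOC classes).
target_mask = torch.rand(4, 224, 224)        # per-pixel mask in [0, 1], as produced by extract_masks
logits_mask = torch.randn(4, 20)             # classifier logits on the masked image
logits_inversed_mask = torch.randn(4, 20)    # classifier logits on the inversely masked image
targets = torch.zeros(4, 20)
targets[:, 3] = 1.0                          # assume one positive class per sample

total_variation_conv = RTSalTotalVariationConv()
tv_loss = total_variation_conv(target_mask)                  # smoothness of the mask
av_loss = target_mask.mean()                                 # average mask area
pres_loss = preserver_loss(logits_mask, targets)             # keep evidence for the target classes
destr_loss = destroyer_loss(logits_inversed_mask, targets)   # suppress evidence once the object is masked out

# Same weighting as in training_step of the rtsal explainer model.
loss = 10 * tv_loss + 1.0 * av_loss + pres_loss + 5 * torch.pow(destr_loss, 0.3)
print(loss.item())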