├── .gitignore ├── LICENSE ├── LiNeS_example.ipynb ├── README.md ├── config ├── config.yaml └── method │ ├── average.yaml │ ├── consensus.yaml │ ├── mag_masking.yaml │ ├── single_task.yaml │ ├── sum.yaml │ ├── tall_mask.yaml │ ├── ties.yaml │ └── zeroshot.yaml ├── download_checkpoints.py ├── environment.yml ├── eval_single_task.py ├── finetune.py ├── main.py ├── results ├── editing_single_task │ └── .gitkeep ├── merging_multi_task │ └── .gitkeep ├── single_task │ ├── ViT-B-16 │ │ └── nonlinear_ft_accuracies.json │ ├── ViT-B-32 │ │ └── nonlinear_ft_accuracies.json │ └── ViT-L-14 │ │ └── nonlinear_ft_accuracies.json └── zero_shot │ ├── ViT-B-16_20tasks_zeroshot.json │ ├── ViT-B-32_20tasks_zeroshot.json │ └── ViT-L-14_20tasks_zeroshot.json └── src ├── datasets ├── __init__.py ├── cars.py ├── cifar10.py ├── cifar100.py ├── common.py ├── country211.py ├── dtd.py ├── emnist.py ├── eurosat.py ├── fashionmnist.py ├── fer2013.py ├── flowers102.py ├── food101.py ├── gtsrb.py ├── imagenet.py ├── kmnist.py ├── mnist.py ├── oxfordpets.py ├── pcam.py ├── registry.py ├── resisc45.py ├── sst2.py ├── stl10.py ├── sun397.py ├── svhn.py └── templates.py ├── eval ├── __init__.py ├── aggregation.py ├── eval.py └── eval_utils.py ├── models ├── __init__.py ├── heads.py ├── modeling.py └── task_vectors.py └── utils ├── __init__.py ├── args.py ├── distributed.py ├── logging.py ├── tallmask_utils.py ├── ties_utils.py ├── utils.py └── variables_and_paths.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints*/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 
97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | wandb 164 | logs 165 | checkpoints 166 | tall_masks 167 | *.pt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 wangke97, nik-dim, gortizji 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LiNeS 2 | 3 | This is the source code to reproduce the experiments for "[LiNeS: Post-training Layer Scaling Prevents Forgetting and Enhances Model Merging](https://arxiv.org/abs/2410.17146)" by Ke Wang*, Nikolaos Dimitriadis*, Alessandro Favero, Guillermo Ortiz-Jimenez, Francois Fleuret, and Pascal Frossard. 4 | 5 | Our paper proposes a post-training model editing method that mitigates catastrophic forgetting, with applications to improving model merging methods. 6 | This repo contains the following experiments: 7 | 1) applying LiNeS to the fine-tuned residual, improving the fine-tuned model's performance on control tasks while preserving performance on the fine-tuned (target) task. 8 | 2) applying LiNeS to enhance multi-task merging, improving the performance of multiple baseline merging methods. 9 | 10 | This repo is heavily based on the repo for [TALL-Masks](https://github.com/nik-dim/tall_masks). 11 | 12 | 13 | 14 | ## Dependencies 15 | 16 | To run the code, please install all of its dependencies: 17 | ```sh 18 | conda env create 19 | conda activate lines 20 | ``` 21 | 22 | ## Checkpoints 23 | 24 | The checkpoints can be downloaded from the HuggingFace repo [`nik-dim/tall_masks`](https://huggingface.co/nik-dim/tall_masks). See the [`snapshot_download` documentation](https://huggingface.co/docs/huggingface_hub/v0.26.0/en/package_reference/file_download#huggingface_hub.snapshot_download) for more details. 25 | 26 | ```python 27 | from huggingface_hub import snapshot_download 28 | 29 | # download the ViT-B-32 checkpoints, including backbone, classification heads and TALL masks 30 | snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*32*") 31 | 32 | # download the ViT-B-16 checkpoints, including backbone, classification heads and TALL masks 33 | snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*16*") 34 | 35 | # download the ViT-L-14 checkpoints, including backbone, classification heads and TALL masks 36 | snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*14*") 37 | 38 | # download everything 39 | snapshot_download(repo_id="nik-dim/tall_masks") 40 | ``` 41 | 42 | ## Datasets 43 | Most of the datasets are downloaded automatically via torchvision or Hugging Face. For the datasets that require manual preparation, please follow the instructions in [this issue](https://github.com/mlfoundations/task_vectors/issues/1). Depending on the torchvision version, some issues may arise when downloading specific datasets, as reported [here](https://github.com/basveeling/pcam/issues/4) and [here](https://github.com/pytorch/vision/issues/5662); switching to a different torchvision version usually resolves them. 44 | 45 | ## Evaluation 46 | Evaluation is configured with Hydra; please set `model_location` and `data_location` in `config/config.yaml` before running any evaluation. You can evaluate on a different number of tasks by setting `num_tasks`: the first `num_tasks` tasks are then selected from the list defined in `src/utils/variables_and_paths.py`, which you can also modify.
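These options (as well as `model`, `num_tasks` and the method-specific settings) can also be overridden from the command line through Hydra, using the same `key=value` syntax as the commands below. A minimal sketch, where `/path/to/data` and `/path/to/checkpoints` are placeholders for your own locations:

```bash
# hypothetical paths; point these to your dataset and checkpoint directories
python main.py model=ViT-B-32 num_tasks=8 method="zeroshot" data_location=/path/to/data model_location=/path/to/checkpoints
```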
47 | 48 | ### 1) Editing the fine-tuned checkpoint 49 | 50 | We provide in [LiNeS_example.ipynb](https://github.com/wang-kee/LiNeS/blob/master/LiNeS_example.ipynb) an example of applying LiNeS to edit a fine-tuned checkpoint, reducing the performance loss on the control datasets while preserving accuracy on the fine-tuned (target) task. 51 | 52 | Alternatively, you can run the following scripts: 53 | 54 | ```bash 55 | # Evaluate the zero-shot performance of the pre-trained model on the first 8 tasks 56 | python main.py model=ViT-B-32 num_tasks=8 method="zeroshot" 57 | 58 | # Evaluate the performance of the fine-tuned model on the first 8 tasks; 59 | # task_index=0 indicates the model fine-tuned on the 0-th task 60 | python main.py model=ViT-B-32 num_tasks=8 method="single_task" method.task_index=0 61 | 62 | # Evaluate the performance of the fine-tuned model (edited with LiNeS) on the first 8 tasks; 63 | # task_index=0 indicates the model fine-tuned on the 0-th task 64 | python main.py model=ViT-B-32 num_tasks=8 method="single_task" method.apply_lines=True method.task_index=0 65 | ``` 66 | 67 | Target and control task accuracies are reported separately when evaluating the fine-tuned checkpoints. Note that you can set different values for `method.tradeoff_target_weight` (2 by default) to assign varying importance to target accuracy (i.e., to control the trade-off between target and control task accuracy) when selecting the best LiNeS hyper-parameter for evaluation on the test set. 68 | 69 | ### 2) Improving multi-task merging baselines 70 | 71 | We apply LiNeS to enhance baseline multi-task merging methods by scaling the multi-task vector. 72 | 73 | The following scripts demonstrate the usage: 74 | ```bash 75 | # Evaluate with the Task Arithmetic baseline 76 | python main.py model=ViT-B-32 num_tasks=8 method="sum" 77 | 78 | # Evaluate with the Task Arithmetic baseline, enhanced with LiNeS 79 | python main.py model=ViT-B-32 num_tasks=8 method="sum" method.apply_lines=True 80 | 81 | # Evaluate with the Ties-merging baseline 82 | python main.py model=ViT-B-32 num_tasks=8 method="ties" method.k=20 83 | 84 | # Evaluate with the Ties-merging baseline, enhanced with LiNeS 85 | python main.py model=ViT-B-32 num_tasks=8 method="ties" method.k=20 method.apply_lines=True 86 | 87 | # Evaluate with the Consensus merging baseline 88 | python main.py model=ViT-B-32 num_tasks=8 method="consensus" method.prun_thre_k=2 89 | 90 | # Evaluate with the Consensus merging baseline, enhanced with LiNeS 91 | python main.py model=ViT-B-32 num_tasks=8 method="consensus" method.prun_thre_k=2 method.apply_lines=True 92 | ``` 93 | 94 | Notes: 95 | * Enhancing a baseline with LiNeS keeps the same hyper-parameter tuning cost as the baseline method. 96 | * You can choose `model` from [ViT-B-32, ViT-L-14] and `num_tasks` from [8, 14, 20] to test the different settings in the paper (see the example below). 97 | * For Consensus merging, you need to construct TALL masks in advance; see [this link](https://github.com/nik-dim/tall_masks) for details.
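For example, a minimal sketch of the largest setting in the paper (ViT-L-14 backbone, all 20 tasks), assuming the corresponding ViT-L-14 checkpoints have already been downloaded:

```bash
# Task Arithmetic enhanced with LiNeS, ViT-L-14 backbone, 20 tasks
python main.py model=ViT-L-14 num_tasks=20 method="sum" method.apply_lines=True
```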
98 | 99 | 100 | ## Other usage 101 | 102 | ```sh 103 | # Finetune on 2 GPUs 104 | python finetune.py --model=ViT-B-32 --world-size=2 105 | 106 | # Evaluate the pre-trained model on a single task 107 | python eval_single_task.py --model=ViT-B-32 --finetuning-mode=none 108 | 109 | # Evaluate a fine-tuned model on a single task 110 | python eval_single_task.py --model=ViT-B-32 --finetuning-mode=standard 111 | 112 | ``` 113 | 114 | ## Reference 115 | If you find this code useful, please cite the following paper: 116 | ```bibtex 117 | @article{wang2024lines, 118 | author = { 119 | Ke Wang and 120 | Nikolaos Dimitriadis and 121 | Alessandro Favero and 122 | Guillermo Ortiz-Jimenez and 123 | Fran\c{c}ois Fleuret and 124 | Pascal Frossard}, 125 | journal = {arXiv}, 126 | title = {{LiNeS: Post-training Layer Scaling Prevents Forgetting and Enhances Model Merging}}, 127 | year = {2024} 128 | } 129 | 130 | ``` 131 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - method: sum 5 | 6 | model: ViT-B-32 7 | num_tasks: 8 8 | specify_lambda: None 9 | 10 | DATASETS: '' 11 | DATASETS_VAL: '' 12 | 13 | # utilities 14 | cache_dir: None 15 | world_size: 1 16 | port: 12355 17 | n_eval_points: 21 18 | device: "cuda" 19 | batch_size: 128 20 | data_location: "/mnt/lts4/scratch/data" 21 | model_location: "/mnt/lts4/scratch/checkpoints/tall_mask_checkpoints/mount/model_checkpoints" 22 | 23 | wandb: 24 | project: task-vectors 25 | mode: online 26 | group: ${model} 27 | 28 | hydra: 29 | run: 30 | dir: ./logs/${model}/${method.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 31 | sweep: 32 | dir: ./logs/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 33 | subdir: ${hydra.job.num} 34 | -------------------------------------------------------------------------------- /config/method/average.yaml: -------------------------------------------------------------------------------- 1 | name: average 2 | k: 100 3 | 4 | full_name: ${method.name}_k=${method.k} 5 | -------------------------------------------------------------------------------- /config/method/consensus.yaml: -------------------------------------------------------------------------------- 1 | name: consensus 2 | k: 100 3 | prun_thre_k: 2 4 | use_ties: false 5 | ties_agg: "sum" 6 | apply_lines: false 7 | 8 | full_name: ${method.name}_k=${method.k} 9 | -------------------------------------------------------------------------------- /config/method/mag_masking.yaml: -------------------------------------------------------------------------------- 1 | name: mag_masking 2 | k: 10 3 | 4 | full_name: ${method.name}_k=${method.k} -------------------------------------------------------------------------------- /config/method/single_task.yaml: -------------------------------------------------------------------------------- 1 | name: single_task 2 | k: 100 3 | apply_lines: false 4 | tradeoff_target_weight: 2.0 5 | 6 | full_name: ${method.name}_k=${method.k} 7 | task_index: 0 -------------------------------------------------------------------------------- /config/method/sum.yaml: -------------------------------------------------------------------------------- 1 | name: sum 2 | k: 100 3 | apply_lines: false 4 | 5 | full_name: ${method.name}_k=${method.k} 6 | -------------------------------------------------------------------------------- /config/method/tall_mask.yaml: 
-------------------------------------------------------------------------------- 1 | name: tall_mask 2 | k: 100 3 | use_ties: false 4 | ties_agg: "sum" 5 | load_mask: false 6 | 7 | full_name: ${method.name}_k=${method.k} -------------------------------------------------------------------------------- /config/method/ties.yaml: -------------------------------------------------------------------------------- 1 | name: ties 2 | k: 20 3 | apply_lines: false 4 | agg: mean 5 | 6 | full_name: ${method.name}_k=${method.k} 7 | -------------------------------------------------------------------------------- /config/method/zeroshot.yaml: -------------------------------------------------------------------------------- 1 | name: zeroshot 2 | k: 100 3 | apply_lines: false 4 | 5 | full_name: ${method.name}_k=${method.k} -------------------------------------------------------------------------------- /download_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import gdown 5 | 6 | # note: this downloading procedure will be removed, please use the downloading procedure (from huggingface) provided in readme instead. 7 | 8 | URLS = { 9 | "checkpoints": { 10 | "ViT-B-32": "16xmoeMqNf1JlICqjMUEgHiH6mHn0G1f0", 11 | "ViT-B-16": "1dmLqIwCYPqO0qsNEuflPhC5-_QodQHDv", 12 | "ViT-L-14": "1QFBcz79RqXUAkEVlYOItt6VJIcf1Ppq6", 13 | }, 14 | "tall_masks": { 15 | "ViT-B-32": "1jpqsurrAdD5bn9i7pRBsTq4E3ayPb0YK", 16 | "ViT-B-16": "1jYNsdeFz6vlwIl5s4T48zeTm7FdwTSL9", 17 | "ViT-L-14": "16GVDMhpScmM3zeyfubjmNRFLiFKkBE7p", 18 | }, 19 | } 20 | 21 | parser = argparse.ArgumentParser("Download checkpoints for Vision Transformer models") 22 | parser.add_argument( 23 | "--model", 24 | type=str, 25 | required=True, 26 | help="Model type to download", 27 | choices=["ViT-B-32", "ViT-B-16", "ViT-L-14"], 28 | ) 29 | 30 | parser.add_argument( 31 | "--kind", 32 | type=str, 33 | required=True, 34 | help="Kind of download: checkpoints refer to single-task fine-tuned models and tall_masks for the per-task binary" 35 | + " masks generated by the 'TALL-masks' method.", 36 | choices=["checkpoints", "tall_masks"], 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = parser.parse_args() 42 | url = URLS[args.kind][args.model] 43 | url = "https://drive.google.com/drive/u/1/folders/" + url 44 | gdown.download_folder(url, output=os.path.join(args.kind, args.model), quiet=False) 45 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lines 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=5.1 9 | - blas=1.0 10 | - brotlipy=0.7.0 11 | - bzip2=1.0.8 12 | - ca-certificates=2023.05.30 13 | - certifi=2023.5.7 14 | - cffi=1.15.1 15 | - charset-normalizer=2.0.4 16 | - cryptography=39.0.1 17 | - cuda=11.6.1 18 | - cuda-cccl=11.6.55 19 | - cuda-command-line-tools=11.6.2 20 | - cuda-compiler=11.6.2 21 | - cuda-cudart=11.6.55 22 | - cuda-cudart-dev=11.6.55 23 | - cuda-cuobjdump=11.6.124 24 | - cuda-cupti=11.6.124 25 | - cuda-cuxxfilt=11.6.124 26 | - cuda-driver-dev=11.6.55 27 | - cuda-gdb=12.1.105 28 | - cuda-libraries=11.6.1 29 | - cuda-libraries-dev=11.6.1 30 | - cuda-memcheck=11.8.86 31 | - cuda-nsight=12.1.105 32 | - cuda-nsight-compute=12.1.1 33 | - cuda-nvcc=11.6.124 34 | - cuda-nvdisasm=12.1.105 35 | - cuda-nvml-dev=11.6.55 36 | - cuda-nvprof=12.1.105 37 | - 
cuda-nvprune=11.6.124 38 | - cuda-nvrtc=11.6.124 39 | - cuda-nvrtc-dev=11.6.124 40 | - cuda-nvtx=11.6.124 41 | - cuda-nvvp=12.1.105 42 | - cuda-runtime=11.6.1 43 | - cuda-samples=11.6.101 44 | - cuda-sanitizer-api=12.1.105 45 | - cuda-toolkit=11.6.1 46 | - cuda-tools=11.6.1 47 | - cuda-visual-tools=11.6.1 48 | - ffmpeg=4.3 49 | - freetype=2.12.1 50 | - gds-tools=1.6.1.9 51 | - giflib=5.2.1 52 | - gmp=6.2.1 53 | - gnutls=3.6.15 54 | - idna=3.4 55 | - intel-openmp=2023.1.0 56 | - jpeg=9e 57 | - lame=3.100 58 | - lcms2=2.12 59 | - ld_impl_linux-64=2.38 60 | - lerc=3.0 61 | - libcublas=11.9.2.110 62 | - libcublas-dev=11.9.2.110 63 | - libcufft=10.7.1.112 64 | - libcufft-dev=10.7.1.112 65 | - libcufile=1.6.1.9 66 | - libcufile-dev=1.6.1.9 67 | - libcurand=10.3.2.106 68 | - libcurand-dev=10.3.2.106 69 | - libcusolver=11.3.4.124 70 | - libcusparse=11.7.2.124 71 | - libcusparse-dev=11.7.2.124 72 | - libdeflate=1.17 73 | - libffi=3.4.4 74 | - libgcc-ng=11.2.0 75 | - libgomp=11.2.0 76 | - libiconv=1.16 77 | - libidn2=2.3.4 78 | - libnpp=11.6.3.124 79 | - libnpp-dev=11.6.3.124 80 | - libnvjpeg=11.6.2.124 81 | - libnvjpeg-dev=11.6.2.124 82 | - libpng=1.6.39 83 | - libstdcxx-ng=11.2.0 84 | - libtasn1=4.19.0 85 | - libtiff=4.5.0 86 | - libunistring=0.9.10 87 | - libuuid=1.41.5 88 | - libwebp=1.2.4 89 | - libwebp-base=1.2.4 90 | - lz4-c=1.9.4 91 | - mkl=2023.1.0 92 | - mkl-service=2.4.0 93 | - mkl_fft=1.3.6 94 | - mkl_random=1.2.2 95 | - ncurses=6.4 96 | - nettle=3.7.3 97 | - nsight-compute=2023.1.1.4 98 | - numpy=1.24.3 99 | - numpy-base=1.24.3 100 | - openh264=2.1.1 101 | - openssl=1.1.1t 102 | - pillow=9.4.0 103 | - pip=23.0.1 104 | - pycparser=2.21 105 | - pyopenssl=23.0.0 106 | - pysocks=1.7.1 107 | - python=3.10.11 108 | - pytorch=1.13.1 109 | - pytorch-cuda=11.6 110 | - pytorch-mutex=1.0 111 | - readline=8.2 112 | - requests=2.29.0 113 | - setuptools=67.8.0 114 | - sqlite=3.41.2 115 | - tbb=2021.8.0 116 | - tk=8.6.12 117 | - torchaudio=0.13.1 118 | - torchvision=0.14.1 119 | - typing_extensions=4.5.0 120 | - tzdata=2023c 121 | - urllib3=1.26.16 122 | - wheel=0.38.4 123 | - xz=5.4.2 124 | - zlib=1.2.13 125 | - zstd=1.5.5 126 | - pip: 127 | - filelock==3.12.0 128 | - fsspec==2023.5.0 129 | - ftfy==6.1.1 130 | - huggingface-hub==0.21.2 131 | - open-clip-torch==2.10.1 132 | - packaging==23.1 133 | - protobuf==3.20.3 134 | - pyyaml==6.0 135 | - regex==2023.6.3 136 | - safetensors==0.3.1 137 | - scipy==1.10.1 138 | - sentencepiece==0.1.99 139 | - timm==0.9.2 140 | - wcwidth==0.2.6 141 | - datasets==2.19.0 142 | - wandb==0.16.6 143 | - hydra-core==1.3.2 144 | - h5py==3.11.0 145 | - gdown==5.1.0 146 | -------------------------------------------------------------------------------- /eval_single_task.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pprint import pprint 4 | 5 | import numpy as np 6 | 7 | from src.eval.eval import eval_single_dataset 8 | from src.models.task_vectors import NonLinearTaskVector 9 | from src.utils.args import parse_arguments 10 | from src.utils.variables_and_paths import get_finetuned_path, get_zeroshot_path 11 | 12 | args = parse_arguments() 13 | args.save_dir = os.path.join(args.model_location, args.model) 14 | if not os.path.exists(args.save_dir): 15 | os.makedirs(args.save_dir) 16 | print("*" * 100) 17 | 18 | pprint(args.__dict__, width=1) 19 | accuracies = {} 20 | 21 | # load pretrained checkpoint 22 | pretrained_checkpoint = get_zeroshot_path(args.model_location, "MNIST", args.model) 23 | 24 | # 
evaluate each task sequentially 25 | for dataset in [ 26 | "MNIST", 27 | "Cars", 28 | "DTD", 29 | "EuroSAT", 30 | "GTSRB", 31 | "RESISC45", 32 | "SUN397", 33 | "SVHN", 34 | # "PCAM", 35 | # "CIFAR100", 36 | # "STL10", 37 | # "OxfordIIITPet", 38 | # "Flowers102", 39 | # "FER2013", 40 | # "CIFAR10", 41 | # "Food101", 42 | # "RenderedSST2", 43 | # "EMNIST", 44 | # "FashionMNIST", 45 | # "KMNIST", 46 | ]: 47 | 48 | print("\n" * 3) 49 | print("*" * 100) 50 | print(f"Evaluating on {dataset}") 51 | 52 | # load finetuned checkpoint 53 | finetuned_checkpoint = get_finetuned_path(args.model_location, dataset, args.model) 54 | task_vector = NonLinearTaskVector(args.model, pretrained_checkpoint, finetuned_checkpoint) 55 | 56 | if args.finetuning_mode == "none": 57 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.0) 58 | save_file = f"single_task_zeroshot_accuracies.json" 59 | elif args.finetuning_mode == "standard": 60 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 61 | save_file = f"single_task_nonlinear_ft_accuracies.json" 62 | 63 | for split in ["val", "test"]: 64 | print("=" * 100) 65 | print(f"Evaluating on {split} split.") 66 | eval_dataset = dataset if split == "test" else f"{dataset}Val" 67 | 68 | if eval_dataset not in accuracies: 69 | accuracies[eval_dataset] = {} 70 | accuracies[eval_dataset] = eval_single_dataset(image_encoder, eval_dataset, args)["top1"] 71 | print() 72 | 73 | directory = f"results/single_task/{args.model}" 74 | if not os.path.exists(directory): 75 | os.makedirs(directory) 76 | save_path = os.path.join(directory, save_file) 77 | with open(save_path, "a+") as f: 78 | f.write(json.dumps(accuracies, sort_keys=False, indent=4) + "\n") 79 | 80 | pprint(accuracies, width=1) 81 | print("File saved at: ", save_path) 82 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import open_dict 4 | 5 | import torch 6 | import wandb 7 | 8 | from src.datasets import get_dataloader, get_dataset, maybe_dictionarize 9 | from src.eval.eval import eval_single_dataset 10 | from src.models import ImageClassifier, ImageEncoder, get_classification_head 11 | from src.utils import initialize_wandb, parse_arguments 12 | from src.utils.distributed import ( 13 | cleanup_ddp, 14 | distribute_loader, 15 | is_main_process, 16 | setup_ddp, 17 | ) 18 | from src.utils.utils import LabelSmoothing, cosine_lr 19 | from src.utils.variables_and_paths import get_finetuned_path, get_zeroshot_path 20 | 21 | 22 | def finetune(rank, args): 23 | setup_ddp(rank, args.world_size, port=args.port) 24 | 25 | if is_main_process(): 26 | initialize_wandb(args) 27 | 28 | train_dataset = args.train_dataset 29 | 30 | ft_path = get_finetuned_path(args.model_location, train_dataset, args.model) 31 | zs_path = get_zeroshot_path(args.model_location, train_dataset, args.model) 32 | 33 | if os.path.exists(zs_path) and os.path.exists(ft_path): 34 | if is_main_process(): 35 | print(f"Skipping fine-tuning because {ft_path} exists.") 36 | return zs_path, ft_path 37 | 38 | image_encoder = ImageEncoder(args.model) 39 | 40 | classification_head = get_classification_head(args, train_dataset) 41 | model = ImageClassifier(image_encoder, classification_head) 42 | 43 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 44 | print(f"The toal number of trainable parameters is 
{num_params/1e6:.2f}M") 45 | 46 | model.freeze_head() 47 | model = model.cuda() 48 | 49 | preprocess_fn = model.train_preprocess 50 | print_every = 100 51 | 52 | dataset = get_dataset( 53 | train_dataset, 54 | preprocess_fn, 55 | location=args.data_location, 56 | batch_size=args.batch_size, 57 | ) 58 | data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None) 59 | num_batches = len(dataset.train_loader) 60 | 61 | # Distribute the data and model across the GPUs. 62 | ddp_loader = distribute_loader(data_loader) 63 | ddp_model = torch.nn.parallel.DistributedDataParallel( 64 | model, device_ids=[rank], find_unused_parameters=True, output_device=rank 65 | ) 66 | 67 | print("hello from process", rank) 68 | 69 | if args.ls > 0: 70 | loss_fn = LabelSmoothing(args.ls) 71 | else: 72 | loss_fn = torch.nn.CrossEntropyLoss() 73 | 74 | params = [p for p in ddp_model.parameters() if p.requires_grad] 75 | optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) 76 | scheduler = cosine_lr( 77 | optimizer, 78 | args.lr, 79 | args.warmup_length, 80 | args.epochs * num_batches // args.num_grad_accumulation, 81 | ) 82 | 83 | # Saving zero-shot model 84 | if is_main_process(): 85 | ckpdir = os.path.join(args.save_dir, train_dataset) 86 | os.makedirs(ckpdir, exist_ok=True) 87 | model_path = get_zeroshot_path(args.model_location, train_dataset, args.model) 88 | ddp_model.module.image_encoder.save(model_path) 89 | 90 | for epoch in range(args.epochs): 91 | ddp_model.train() 92 | 93 | for i, batch in enumerate(ddp_loader): 94 | start_time = time.time() 95 | 96 | step = i // args.num_grad_accumulation + epoch * num_batches // args.num_grad_accumulation 97 | 98 | batch = maybe_dictionarize(batch) 99 | inputs = batch["images"].cuda() 100 | labels = batch["labels"].cuda() 101 | data_time = time.time() - start_time 102 | 103 | logits = ddp_model(inputs) 104 | loss = loss_fn(logits, labels) 105 | loss.backward() 106 | 107 | if (i + 1) % args.num_grad_accumulation == 0: 108 | scheduler(step) 109 | 110 | torch.nn.utils.clip_grad_norm_(params, 1.0) 111 | optimizer.step() 112 | optimizer.zero_grad() 113 | 114 | batch_time = time.time() - start_time 115 | 116 | if args.checkpoint_every > 0 and step % args.checkpoint_every == 0 and is_main_process(): 117 | print("Saving checkpoint.") 118 | model_path = get_finetuned_path(args.model_location, train_dataset, args.model).replace( 119 | ".pt", f"_{step}.pt" 120 | ) 121 | ddp_model.module.image_encoder.save(model_path) 122 | 123 | if step % print_every == 0 and ((i + 1) % args.num_grad_accumulation == 0) and is_main_process(): 124 | percent_complete = 100 * i / len(ddp_loader) 125 | print( 126 | f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t" # noqa: E501 127 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}\t", # noqa: E501 128 | flush=True, 129 | ) 130 | wandb.log( 131 | { 132 | f"{train_dataset}/train/loss": loss.item(), 133 | "train/data_time": data_time, 134 | "train/batch_time": batch_time, 135 | } 136 | ) 137 | 138 | if is_main_process(): 139 | # We only need to evaluate the model on the first GPU. 
140 | image_encoder = ddp_model.module.image_encoder 141 | test_accuracy = eval_single_dataset(image_encoder, train_dataset, args) 142 | 143 | if is_main_process(): 144 | ft_path = get_finetuned_path(args.model_location, train_dataset, args.model) 145 | zs_path = get_zeroshot_path(args.model_location, train_dataset, args.model) 146 | 147 | image_encoder.save(ft_path) 148 | return zs_path, ft_path 149 | 150 | cleanup_ddp() 151 | 152 | 153 | if __name__ == "__main__": 154 | 155 | # uncomment all the datasets for fine-tuning 156 | train_datasets = [ 157 | # "MNIST", 158 | # "Cars", 159 | "DTD", 160 | # "EuroSAT", 161 | # "GTSRB", 162 | # "RESISC45", 163 | # "SUN397", 164 | # "SVHN", 165 | # "CIFAR100", 166 | # "STL10", 167 | # "Flowers102", 168 | # "OxfordIIITPet", 169 | # "FER2013", 170 | # "PCAM", 171 | # "FashionMNIST", 172 | # "CIFAR10", 173 | # "Food101", 174 | # "RenderedSST2", 175 | # "KMNIST", 176 | # "EMNIST", 177 | ] 178 | epochs = { 179 | "Cars": 35, 180 | "DTD": 76, 181 | "EuroSAT": 12, 182 | "GTSRB": 11, 183 | "MNIST": 5, 184 | "RESISC45": 15, 185 | "SUN397": 14, 186 | "SVHN": 4, 187 | "CIFAR10": 6, 188 | "CIFAR100": 6, 189 | "STL10": 60, 190 | "Food101": 4, 191 | "Flowers102": 147, 192 | "FER2013": 10, 193 | "PCAM": 1, 194 | "OxfordIIITPet": 82, 195 | "RenderedSST2": 39, 196 | "EMNIST": 2, 197 | "FashionMNIST": 5, 198 | "KMNIST": 5, 199 | } 200 | 201 | for dataset in train_datasets: 202 | args = parse_arguments() 203 | args.lr = 1e-5 204 | args.epochs = epochs[dataset] 205 | args.train_dataset = dataset + "Val" 206 | 207 | args.save_di = "DTD_new.pt" 208 | 209 | args.save_dir = os.path.join(args.model_location, args.model) 210 | 211 | # We use gradient accumulation to simulate larger batch sizes if the model does not fit in memory. 212 | args.batch_size = 64 if args.model == "ViT-L-14" else 128 213 | args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1 214 | 215 | print("=" * 100) 216 | print(f"Finetuning {args.model} on {dataset}") 217 | print("=" * 100) 218 | torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size) 219 | # finetune(0, args) 220 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pprint import pprint 3 | 4 | import hydra 5 | import wandb 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from src.eval.aggregation import create_task_vector 9 | from src.eval.eval_utils import perform_eval_with_merged_vector 10 | from src.utils.variables_and_paths import ALL_DATASETS 11 | 12 | 13 | @hydra.main(config_path="config", config_name="config", version_base="1.3") 14 | def my_app(cfg: DictConfig) -> None: 15 | 16 | if cfg.DATASETS == "": 17 | cfg.DATASETS = ALL_DATASETS[: cfg.num_tasks] 18 | else: 19 | cfg.num_tasks = len(cfg.DATASETS) 20 | cfg.DATASETS_VAL = [dataset + "Val" for dataset in cfg.DATASETS] 21 | cfg.data_location = os.path.expanduser(cfg.data_location) 22 | OmegaConf.set_struct(cfg, True) 23 | 24 | # set up experiment for WandB 25 | print(cfg.method.full_name) 26 | print() 27 | wandb.init( 28 | config=OmegaConf.to_container(cfg), 29 | mode=cfg.wandb.mode, 30 | project=cfg.wandb.project, 31 | group=cfg.wandb.group, 32 | dir="logs/", 33 | ) 34 | wandb.config.update({"method.full_name1": cfg.method.full_name}) 35 | wandb.config.update({"method.keep": cfg.method.k}) 36 | print(OmegaConf.to_yaml(cfg)) 37 | OmegaConf.set_struct(cfg, True) 38 | 39 | # create final task vector 40 | 
task_vector_dict, eval_masks = create_task_vector(cfg) 41 | print("*" * 100) 42 | print("*" * 37, "Created task vector dict", "*" * 37) 43 | print("*" * 100) 44 | print("\n" * 3) 45 | 46 | # perform evaluation and log results 47 | print("*" * 100) 48 | print("*" * 39, "Starting Evaluation.", "*" * 39) 49 | print("*" * 100) 50 | eval_results = perform_eval_with_merged_vector(cfg, task_vector_dict, eval_masks) 51 | pprint(eval_results, width=1) 52 | wandb.log(eval_results) 53 | wandb.finish(quiet=True) 54 | 55 | 56 | if __name__ == "__main__": 57 | my_app() 58 | -------------------------------------------------------------------------------- /results/editing_single_task/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-kee/LiNeS/43a7bdd564a16195b164abc2c64ec1856c59c211/results/editing_single_task/.gitkeep -------------------------------------------------------------------------------- /results/merging_multi_task/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-kee/LiNeS/43a7bdd564a16195b164abc2c64ec1856c59c211/results/merging_multi_task/.gitkeep -------------------------------------------------------------------------------- /results/single_task/ViT-B-16/nonlinear_ft_accuracies.json: -------------------------------------------------------------------------------- 1 | { 2 | "MNISTVal": 0.995, 3 | "MNIST": 0.9975, 4 | "CarsVal": 0.8648648648648649, 5 | "Cars": 0.8761348091033454, 6 | "DTDVal": 0.8191489361702128, 7 | "DTD": 0.8175531914893617, 8 | "EuroSATVal": 0.9974304068522484, 9 | "EuroSAT": 0.9995810640971932, 10 | "GTSRBVal": 0.9992492492492493, 11 | "GTSRB": 0.9900237529691212, 12 | "RESISC45Val": 0.9703703703703703, 13 | "RESISC45": 0.9693650793650793, 14 | "SUN397Val": 0.7758186397984886, 15 | "SUN397": 0.788110831234257, 16 | "SVHNVal": 0.9676, 17 | "SVHN": 0.9782191149354641, 18 | "CIFAR10Val": 0.9842, 19 | "CIFAR10": 0.9853, 20 | "CIFAR100Val": 0.9108, 21 | "CIFAR100": 0.9071, 22 | "Food101Val": 0.8956, 23 | "Food101": 0.9287920792079208, 24 | "STL10Val": 0.996, 25 | "STL10": 0.986, 26 | "OxfordIIITPetVal": 0.9483695652173914, 27 | "OxfordIIITPet": 0.9449441264649768, 28 | "RenderedSST2Val": 0.773121387283237, 29 | "RenderedSST2": 0.7918725974739155, 30 | "EMNISTVal": 0.9978, 31 | "EMNIST": 0.997925, 32 | "FashionMNISTVal": 0.9528, 33 | "FashionMNIST": 0.9565, 34 | "KMNISTVal": 0.9974, 35 | "KMNIST": 0.9879, 36 | "Flowers102Val": 0.9705882352941176, 37 | "Flowers102": 0.9375508212717515, 38 | "PCAMVal": 0.9786, 39 | "PCAM": 0.891815185546875, 40 | "FER2013": 0.7475619949846753, 41 | "FER2013Val": 0.7526132404181185 42 | } 43 | -------------------------------------------------------------------------------- /results/single_task/ViT-B-32/nonlinear_ft_accuracies.json: -------------------------------------------------------------------------------- 1 | { 2 | "CIFAR10": 0.9791, 3 | "CIFAR100": 0.8927, 4 | "CIFAR100Val": 0.8952, 5 | "CIFAR10Val": 0.9788, 6 | "Cars": 0.7818679268747668, 7 | "CarsVal": 0.7800982800982801, 8 | "DTD": 0.7920212765957447, 9 | "DTDVal": 0.7819148936170213, 10 | "EMNIST": 0.997825, 11 | "EMNISTVal": 0.9982, 12 | "EuroSAT": 0.9987431922915794, 13 | "EuroSATVal": 0.9987152034261242, 14 | "FashionMNIST": 0.953, 15 | "FashionMNISTVal": 0.9552, 16 | "Flowers102": 0.9050252073507887, 17 | "Flowers102Val": 0.9509803921568627, 18 | "Food101": 0.890970297029703, 19 | "Food101Val": 0.8462, 20 | "FER2013": 
0.7294511005851212, 21 | "FER2013Val": 0.735191637630662, 22 | "GTSRB": 0.9914489311163895, 23 | "GTSRBVal": 0.9992492492492493, 24 | "KMNIST": 0.9863, 25 | "KMNISTVal": 0.995, 26 | "MNIST": 0.9969, 27 | "MNISTVal": 0.9956, 28 | "OxfordIIITPet": 0.9111474516216953, 29 | "OxfordIIITPetVal": 0.9239130434782609, 30 | "PCAM": 0.878753662109375, 31 | "PCAMVal": 0.9736, 32 | "RESISC45": 0.957936507936508, 33 | "RESISC45Val": 0.9624338624338624, 34 | "RenderedSST2": 0.7435475013728721, 35 | "RenderedSST2Val": 0.7138728323699421, 36 | "STL10": 0.97975, 37 | "STL10Val": 0.984, 38 | "SUN397": 0.7534508816120907, 39 | "SUN397Val": 0.7541561712846347, 40 | "SVHN": 0.9756837738168408, 41 | "SVHNVal": 0.9638 42 | } 43 | -------------------------------------------------------------------------------- /results/single_task/ViT-L-14/nonlinear_ft_accuracies.json: -------------------------------------------------------------------------------- 1 | { 2 | "Cars": 0.9226464370103221, 3 | "CarsVal": 0.9287469287469288, 4 | "DTD": 0.8425531914893617, 5 | "DTDVal": 0.8404255319148937, 6 | "EuroSATVal": 0.9991434689507495, 7 | "EuroSAT": 0.9991621281943862, 8 | "GTSRB": 0.9915281076801267, 9 | "GTSRBVal": 0.9996246246246246, 10 | "MNIST": 0.9979, 11 | "MNISTVal": 0.997, 12 | "RESISC45": 0.9731746031746031, 13 | "RESISC45Val": 0.973015873015873, 14 | "SUN397": 0.8208564231738035, 15 | "SUN397Val": 0.8241813602015113, 16 | "SVHN": 0.9812922556853104, 17 | "SVHNVal": 0.9724, 18 | "CIFAR100": 0.9333, 19 | "CIFAR100Val": 0.9368, 20 | "FER2013": 0.770130955697966, 21 | "FER2013Val": 0.7735191637630662, 22 | "Flowers102": 0.9786957228817694, 23 | "Flowers102Val": 0.9901960784313726, 24 | "OxfordIIITPet": 0.9553011719814664, 25 | "OxfordIIITPetVal": 0.9592391304347826, 26 | "PCAM": 0.902679443359375, 27 | "PCAMVal": 0.9804, 28 | "STL10": 0.995125, 29 | "STL10Val": 0.994, 30 | "CIFAR10": 0.992, 31 | "CIFAR10Val": 0.9906, 32 | "EMNIST": 0.9981, 33 | "EMNISTVal": 0.9978, 34 | "FashionMNIST": 0.9581, 35 | "FashionMNISTVal": 0.9566, 36 | "Food101": 0.9546138613861386, 37 | "Food101Val": 0.9306, 38 | "KMNIST": 0.9875, 39 | "KMNISTVal": 0.9974, 40 | "RenderedSST2": 0.8544755628775398, 41 | "RenderedSST2Val": 0.8251445086705202 42 | } -------------------------------------------------------------------------------- /results/zero_shot/ViT-B-16_20tasks_zeroshot.json: -------------------------------------------------------------------------------- 1 | { 2 | "test": { 3 | "Cars:top1": 0.6464370103220992, 4 | "DTD:top1": 0.451063829787234, 5 | "EuroSAT:top1": 0.5550900712191035, 6 | "GTSRB:top1": 0.433729216152019, 7 | "MNIST:top1": 0.5173, 8 | "RESISC45:top1": 0.6576190476190477, 9 | "SVHN:top1": 0.5198985863552551, 10 | "SUN397:top1": 0.6551133501259446, 11 | "STL10:top1": 0.9825, 12 | "OxfordIIITPet:top1": 0.8912510220768601, 13 | "Flowers102:top1": 0.7142624817043421, 14 | "CIFAR100:top1": 0.6695, 15 | "PCAM:top1": 0.54022216796875, 16 | "FER2013:top1": 0.35553078852047926, 17 | "CIFAR10:top1": 0.9077, 18 | "Food101:top1": 0.8790891089108911, 19 | "FashionMNIST:top1": 0.673, 20 | "RenderedSST2:top1": 0.6057111477210324, 21 | "EMNIST:top1": 0.198775, 22 | "KMNIST:top1": 0.1018, 23 | "Cars:normalized_top1": 0.7378282469836763, 24 | "DTD:normalized_top1": 0.5517241379310345, 25 | "EuroSAT:normalized_top1": 0.5613270383114529, 26 | "GTSRB:normalized_top1": 0.4380998080614203, 27 | "MNIST:normalized_top1": 0.5185964912280702, 28 | "RESISC45:normalized_top1": 0.6784018339610284, 29 | "SVHN:normalized_top1": 0.5314745729432554, 30 | 
"SUN397:normalized_top1": 0.8312452058297111, 31 | "STL10:normalized_top1": 0.996450304259635, 32 | "OxfordIIITPet:normalized_top1": 0.9431785405249494, 33 | "Flowers102:normalized_top1": 0.7618386816999132, 34 | "CIFAR100:normalized_top1": 0.7380663653400947, 35 | "PCAM:normalized_top1": 0.6057557403415118, 36 | "FER2013:normalized_top1": 0.4755870294446516, 37 | "CIFAR10:normalized_top1": 0.9212422612402313, 38 | "Food101:normalized_top1": 0.9464864403888794, 39 | "FashionMNIST:normalized_top1": 0.7036069001568218, 40 | "RenderedSST2:normalized_top1": 0.7649098474341193, 41 | "EMNIST:normalized_top1": 0.19918831575519205, 42 | "KMNIST:normalized_top1": 0.10304686709181092, 43 | "avg_normalized_top1": 0.650402731446373, 44 | "avg_top1": 0.597779641424153 45 | }, 46 | "val": { 47 | "0.0": { 48 | "CarsVal:top1": 0.64004914004914, 49 | "DTDVal:top1": 0.449468085106383, 50 | "EuroSATVal:top1": 0.5481798715203426, 51 | "GTSRBVal:top1": 0.4283033033033033, 52 | "MNISTVal:top1": 0.5052, 53 | "RESISC45Val:top1": 0.6756613756613756, 54 | "SVHNVal:top1": 0.4456, 55 | "SUN397Val:top1": 0.654911838790932, 56 | "STL10Val:top1": 0.978, 57 | "OxfordIIITPetVal:top1": 0.8722826086956522, 58 | "Flowers102Val:top1": 0.6862745098039216, 59 | "CIFAR100Val:top1": 0.6816, 60 | "PCAMVal:top1": 0.4982, 61 | "FER2013Val:top1": 0.35958188153310106, 62 | "CIFAR10Val:top1": 0.9068, 63 | "Food101Val:top1": 0.8448, 64 | "FashionMNISTVal:top1": 0.6772, 65 | "RenderedSST2Val:top1": 0.5953757225433526, 66 | "EMNISTVal:top1": 0.1952, 67 | "KMNISTVal:top1": 0.1006, 68 | "CarsVal:normalized_top1": 0.7400568181818181, 69 | "DTDVal:normalized_top1": 0.5487012987012987, 70 | "EuroSATVal:normalized_top1": 0.5522707660839273, 71 | "GTSRBVal:normalized_top1": 0.4286250939143501, 72 | "MNISTVal:normalized_top1": 0.5077386934673367, 73 | "RESISC45Val:normalized_top1": 0.6962922573609597, 74 | "SVHNVal:normalized_top1": 0.4605208763952046, 75 | "SUN397Val:normalized_top1": 0.8441558441558441, 76 | "STL10Val:normalized_top1": 0.9819277108433735, 77 | "OxfordIIITPetVal:normalized_top1": 0.9197707736389685, 78 | "Flowers102Val:normalized_top1": 0.7070707070707071, 79 | "CIFAR100Val:normalized_top1": 0.7483530961791831, 80 | "PCAMVal:normalized_top1": 0.5090946249744532, 81 | "FER2013Val:normalized_top1": 0.47777777777777775, 82 | "CIFAR10Val:normalized_top1": 0.9213574476732372, 83 | "Food101Val:normalized_top1": 0.9432782492184011, 84 | "FashionMNISTVal:normalized_top1": 0.7107472712006717, 85 | "RenderedSST2Val:normalized_top1": 0.7700934579439253, 86 | "EMNISTVal:normalized_top1": 0.19563038685107237, 87 | "KMNISTVal:normalized_top1": 0.10086224182875476, 88 | "avg_normalized_top1": 0.6382162696730633, 89 | "avg_top1": 0.5871644168503751 90 | } 91 | }, 92 | "val_best": { 93 | "CarsVal:top1": 0.64004914004914, 94 | "DTDVal:top1": 0.449468085106383, 95 | "EuroSATVal:top1": 0.5481798715203426, 96 | "GTSRBVal:top1": 0.4283033033033033, 97 | "MNISTVal:top1": 0.5052, 98 | "RESISC45Val:top1": 0.6756613756613756, 99 | "SVHNVal:top1": 0.4456, 100 | "SUN397Val:top1": 0.654911838790932, 101 | "STL10Val:top1": 0.978, 102 | "OxfordIIITPetVal:top1": 0.8722826086956522, 103 | "Flowers102Val:top1": 0.6862745098039216, 104 | "CIFAR100Val:top1": 0.6816, 105 | "PCAMVal:top1": 0.4982, 106 | "FER2013Val:top1": 0.35958188153310106, 107 | "CIFAR10Val:top1": 0.9068, 108 | "Food101Val:top1": 0.8448, 109 | "FashionMNISTVal:top1": 0.6772, 110 | "RenderedSST2Val:top1": 0.5953757225433526, 111 | "EMNISTVal:top1": 0.1952, 112 | "KMNISTVal:top1": 0.1006, 113 | 
"CarsVal:normalized_top1": 0.7400568181818181, 114 | "DTDVal:normalized_top1": 0.5487012987012987, 115 | "EuroSATVal:normalized_top1": 0.5522707660839273, 116 | "GTSRBVal:normalized_top1": 0.4286250939143501, 117 | "MNISTVal:normalized_top1": 0.5077386934673367, 118 | "RESISC45Val:normalized_top1": 0.6962922573609597, 119 | "SVHNVal:normalized_top1": 0.4605208763952046, 120 | "SUN397Val:normalized_top1": 0.8441558441558441, 121 | "STL10Val:normalized_top1": 0.9819277108433735, 122 | "OxfordIIITPetVal:normalized_top1": 0.9197707736389685, 123 | "Flowers102Val:normalized_top1": 0.7070707070707071, 124 | "CIFAR100Val:normalized_top1": 0.7483530961791831, 125 | "PCAMVal:normalized_top1": 0.5090946249744532, 126 | "FER2013Val:normalized_top1": 0.47777777777777775, 127 | "CIFAR10Val:normalized_top1": 0.9213574476732372, 128 | "Food101Val:normalized_top1": 0.9432782492184011, 129 | "FashionMNISTVal:normalized_top1": 0.7107472712006717, 130 | "RenderedSST2Val:normalized_top1": 0.7700934579439253, 131 | "EMNISTVal:normalized_top1": 0.19563038685107237, 132 | "KMNISTVal:normalized_top1": 0.10086224182875476, 133 | "avg_normalized_top1": 0.6382162696730633, 134 | "avg_top1": 0.5871644168503751 135 | } 136 | } -------------------------------------------------------------------------------- /results/zero_shot/ViT-B-32_20tasks_zeroshot.json: -------------------------------------------------------------------------------- 1 | { 2 | "test": { 3 | "Cars:top1": 0.5973137669444099, 4 | "DTD:top1": 0.4398936170212766, 5 | "EuroSAT:top1": 0.4650188521156263, 6 | "GTSRB:top1": 0.32604908946951705, 7 | "MNIST:top1": 0.4826, 8 | "RESISC45:top1": 0.6066666666666667, 9 | "SVHN:top1": 0.3162645974185618, 10 | "SUN397:top1": 0.6316372795969774, 11 | "STL10:top1": 0.97125, 12 | "OxfordIIITPet:top1": 0.874897792313982, 13 | "Flowers102:top1": 0.6648235485444788, 14 | "CIFAR100:top1": 0.6423, 15 | "PCAM:top1": 0.606353759765625, 16 | "FER2013:top1": 0.3900808024519365, 17 | "CIFAR10:top1": 0.8983, 18 | "Food101:top1": 0.8257425742574257, 19 | "FashionMNIST:top1": 0.6301, 20 | "RenderedSST2:top1": 0.586490939044481, 21 | "EMNIST:top1": 0.1719, 22 | "KMNIST:top1": 0.0977, 23 | "Cars:normalized_top1": 0.7639573723556545, 24 | "DTD:normalized_top1": 0.5554063129617193, 25 | "EuroSAT:normalized_top1": 0.47006772770954364, 26 | "GTSRB:normalized_top1": 0.3288612042804664, 27 | "MNIST:normalized_top1": 0.4841007122078443, 28 | "RESISC45:normalized_top1": 0.6333057166528583, 29 | "SVHN:normalized_top1": 0.3241466199456672, 30 | "SUN397:normalized_top1": 0.8383257555496122, 31 | "STL10:normalized_top1": 0.9913243174279153, 32 | "OxfordIIITPet:normalized_top1": 0.9602153754113072, 33 | "Flowers102:normalized_top1": 0.7345911949685535, 34 | "CIFAR100:normalized_top1": 0.7195026324633135, 35 | "PCAM:normalized_top1": 0.6900156277131446, 36 | "FER2013:normalized_top1": 0.5347593582887701, 37 | "CIFAR10:normalized_top1": 0.9174752323562455, 38 | "Food101:normalized_top1": 0.9267902386984931, 39 | "FashionMNIST:normalized_top1": 0.6611752360965373, 40 | "RenderedSST2:normalized_top1": 0.7887740029542096, 41 | "EMNIST:normalized_top1": 0.1722746974669907, 42 | "KMNIST:normalized_top1": 0.09905708202372503, 43 | "avg_normalized_top1": 0.6297063208766286, 44 | "avg_top1": 0.5612691642805483 45 | }, 46 | "val": { 47 | "0.0": { 48 | "CarsVal:top1": 0.5958230958230958, 49 | "DTDVal:top1": 0.425531914893617, 50 | "EuroSATVal:top1": 0.4586723768736617, 51 | "GTSRBVal:top1": 0.32995495495495497, 52 | "MNISTVal:top1": 0.4816, 53 | 
"RESISC45Val:top1": 0.6142857142857143, 54 | "SVHNVal:top1": 0.2916, 55 | "SUN397Val:top1": 0.6347607052896725, 56 | "STL10Val:top1": 0.956, 57 | "OxfordIIITPetVal:top1": 0.8260869565217391, 58 | "Flowers102Val:top1": 0.7352941176470589, 59 | "CIFAR100Val:top1": 0.6498, 60 | "PCAMVal:top1": 0.6014, 61 | "FER2013Val:top1": 0.39930313588850175, 62 | "CIFAR10Val:top1": 0.8936, 63 | "Food101Val:top1": 0.7824, 64 | "FashionMNISTVal:top1": 0.628, 65 | "RenderedSST2Val:top1": 0.5708092485549133, 66 | "EMNISTVal:top1": 0.1586, 67 | "KMNISTVal:top1": 0.0914, 68 | "CarsVal:normalized_top1": 0.7637795275590551, 69 | "DTDVal:normalized_top1": 0.54421768707483, 70 | "EuroSATVal:normalized_top1": 0.4631321681222463, 71 | "GTSRBVal:normalized_top1": 0.3302028549962434, 72 | "MNISTVal:normalized_top1": 0.4837284049819204, 73 | "RESISC45Val:normalized_top1": 0.6382627817482134, 74 | "SVHNVal:normalized_top1": 0.30255239676281387, 75 | "SUN397Val:normalized_top1": 0.841683366733467, 76 | "STL10Val:normalized_top1": 0.9715447154471545, 77 | "OxfordIIITPetVal:normalized_top1": 0.8941176470588236, 78 | "Flowers102Val:normalized_top1": 0.7731958762886599, 79 | "CIFAR100Val:normalized_top1": 0.7258713136729223, 80 | "PCAMVal:normalized_top1": 0.6177074774034511, 81 | "FER2013Val:normalized_top1": 0.5431279620853081, 82 | "CIFAR10Val:normalized_top1": 0.9129546383326522, 83 | "Food101Val:normalized_top1": 0.9246041125029544, 84 | "FashionMNISTVal:normalized_top1": 0.6574539363484087, 85 | "RenderedSST2Val:normalized_top1": 0.7995951417004049, 86 | "EMNISTVal:normalized_top1": 0.15888599479062313, 87 | "KMNISTVal:normalized_top1": 0.09185929648241206, 88 | "avg_normalized_top1": 0.6219238650046282, 89 | "avg_top1": 0.5562461110366466 90 | } 91 | }, 92 | "val_best": { 93 | "CarsVal:top1": 0.5958230958230958, 94 | "DTDVal:top1": 0.425531914893617, 95 | "EuroSATVal:top1": 0.4586723768736617, 96 | "GTSRBVal:top1": 0.32995495495495497, 97 | "MNISTVal:top1": 0.4816, 98 | "RESISC45Val:top1": 0.6142857142857143, 99 | "SVHNVal:top1": 0.2916, 100 | "SUN397Val:top1": 0.6347607052896725, 101 | "STL10Val:top1": 0.956, 102 | "OxfordIIITPetVal:top1": 0.8260869565217391, 103 | "Flowers102Val:top1": 0.7352941176470589, 104 | "CIFAR100Val:top1": 0.6498, 105 | "PCAMVal:top1": 0.6014, 106 | "FER2013Val:top1": 0.39930313588850175, 107 | "CIFAR10Val:top1": 0.8936, 108 | "Food101Val:top1": 0.7824, 109 | "FashionMNISTVal:top1": 0.628, 110 | "RenderedSST2Val:top1": 0.5708092485549133, 111 | "EMNISTVal:top1": 0.1586, 112 | "KMNISTVal:top1": 0.0914, 113 | "Cars:top1": 0.5973137669444099, 114 | "DTD:top1": 0.4398936170212766, 115 | "EuroSAT:top1": 0.4650188521156263, 116 | "GTSRB:top1": 0.32604908946951705, 117 | "MNIST:top1": 0.4826, 118 | "RESISC45:top1": 0.6066666666666667, 119 | "SVHN:top1": 0.3162645974185618, 120 | "SUN397:top1": 0.6316372795969774, 121 | "STL10:top1": 0.97125, 122 | "OxfordIIITPet:top1": 0.874897792313982, 123 | "Flowers102:top1": 0.6648235485444788, 124 | "CIFAR100:top1": 0.6423, 125 | "PCAM:top1": 0.606353759765625, 126 | "FER2013:top1": 0.3900808024519365, 127 | "CIFAR10:top1": 0.8983, 128 | "Food101:top1": 0.8257425742574257, 129 | "FashionMNIST:top1": 0.6301, 130 | "RenderedSST2:top1": 0.586490939044481, 131 | "EMNIST:top1": 0.1719, 132 | "KMNIST:top1": 0.0977, 133 | "CarsVal:normalized_top1": 0.7637795275590551, 134 | "DTDVal:normalized_top1": 0.54421768707483, 135 | "EuroSATVal:normalized_top1": 0.4631321681222463, 136 | "GTSRBVal:normalized_top1": 0.3302028549962434, 137 | "MNISTVal:normalized_top1": 
0.4837284049819204, 138 | "RESISC45Val:normalized_top1": 0.6382627817482134, 139 | "SVHNVal:normalized_top1": 0.30255239676281387, 140 | "SUN397Val:normalized_top1": 0.841683366733467, 141 | "STL10Val:normalized_top1": 0.9715447154471545, 142 | "OxfordIIITPetVal:normalized_top1": 0.8941176470588236, 143 | "Flowers102Val:normalized_top1": 0.7731958762886599, 144 | "CIFAR100Val:normalized_top1": 0.7258713136729223, 145 | "PCAMVal:normalized_top1": 0.6177074774034511, 146 | "FER2013Val:normalized_top1": 0.5431279620853081, 147 | "CIFAR10Val:normalized_top1": 0.9129546383326522, 148 | "Food101Val:normalized_top1": 0.9246041125029544, 149 | "FashionMNISTVal:normalized_top1": 0.6574539363484087, 150 | "RenderedSST2Val:normalized_top1": 0.7995951417004049, 151 | "EMNISTVal:normalized_top1": 0.15888599479062313, 152 | "KMNISTVal:normalized_top1": 0.09185929648241206, 153 | "avg_normalized_top1": 0.6219238650046282, 154 | "avg_top1": 0.5562461110366466 155 | } 156 | } -------------------------------------------------------------------------------- /results/zero_shot/ViT-L-14_20tasks_zeroshot.json: -------------------------------------------------------------------------------- 1 | { 2 | "test": { 3 | "Cars:top1": 0.7775152344235792, 4 | "DTD:top1": 0.5531914893617021, 5 | "EuroSAT:top1": 0.6099706744868035, 6 | "GTSRB:top1": 0.5055423594615994, 7 | "MNIST:top1": 0.7636, 8 | "RESISC45:top1": 0.7104761904761905, 9 | "SVHN:top1": 0.5844729563614014, 10 | "SUN397:top1": 0.6827707808564232, 11 | "STL10:top1": 0.993625, 12 | "OxfordIIITPet:top1": 0.9359498500953939, 13 | "Flowers102:top1": 0.7905350463489998, 14 | "CIFAR100:top1": 0.7582, 15 | "PCAM:top1": 0.5120849609375, 16 | "FER2013:top1": 0.3815826135413764, 17 | "CIFAR10:top1": 0.9557, 18 | "Food101:top1": 0.9232475247524753, 19 | "FashionMNIST:top1": 0.6694, 20 | "RenderedSST2:top1": 0.6891817682591982, 21 | "EMNIST:top1": 0.15635, 22 | "KMNIST:top1": 0.1039, 23 | "Cars:normalized_top1": 0.8427011726647796, 24 | "DTD:normalized_top1": 0.6565656565656566, 25 | "EuroSAT:normalized_top1": 0.6145226944456603, 26 | "GTSRB:normalized_top1": 0.5098618541882936, 27 | "MNIST:normalized_top1": 0.7652069345625814, 28 | "RESISC45:normalized_top1": 0.7300603490458327, 29 | "SVHN:normalized_top1": 0.5956155803484048, 30 | "SUN397:normalized_top1": 0.8317785687983307, 31 | "STL10:normalized_top1": 0.9984926516769249, 32 | "OxfordIIITPet:normalized_top1": 0.9797432239657632, 33 | "Flowers102:normalized_top1": 0.8077434363575938, 34 | "CIFAR100:normalized_top1": 0.8123861566484517, 35 | "PCAM:normalized_top1": 0.5672943642449034, 36 | "FER2013:normalized_top1": 0.4954775687409551, 37 | "CIFAR10:normalized_top1": 0.9634072580645161, 38 | "Food101:normalized_top1": 0.9671423830069699, 39 | "FashionMNIST:normalized_top1": 0.6986744598684898, 40 | "RenderedSST2:normalized_top1": 0.8065552699228792, 41 | "EMNIST:normalized_top1": 0.15664763049794608, 42 | "KMNIST:normalized_top1": 0.10521518987341773, 43 | "avg_normalized_top1": 0.6952546201744176, 44 | "avg_top1": 0.6528648224681322 45 | }, 46 | "val": { 47 | "0.0": { 48 | "CarsVal:top1": 0.7960687960687961, 49 | "DTDVal:top1": 0.550531914893617, 50 | "EuroSATVal:top1": 0.6059957173447538, 51 | "GTSRBVal:top1": 0.5082582582582582, 52 | "MNISTVal:top1": 0.771, 53 | "RESISC45Val:top1": 0.7375661375661375, 54 | "SVHNVal:top1": 0.512, 55 | "SUN397Val:top1": 0.6861460957178841, 56 | "STL10Val:top1": 0.992, 57 | "OxfordIIITPetVal:top1": 0.9375, 58 | "Flowers102Val:top1": 0.803921568627451, 59 | "CIFAR100Val:top1": 0.7596, 
60 | "PCAMVal:top1": 0.4958, 61 | "FER2013Val:top1": 0.38989547038327527, 62 | "CIFAR10Val:top1": 0.9552, 63 | "Food101Val:top1": 0.8914, 64 | "FashionMNISTVal:top1": 0.6636, 65 | "RenderedSST2Val:top1": 0.680635838150289, 66 | "EMNISTVal:top1": 0.1534, 67 | "KMNISTVal:top1": 0.0902, 68 | "Cars:top1": 0.7775152344235792, 69 | "DTD:top1": 0.5531914893617021, 70 | "EuroSAT:top1": 0.6099706744868035, 71 | "GTSRB:top1": 0.5055423594615994, 72 | "MNIST:top1": 0.7636, 73 | "RESISC45:top1": 0.7104761904761905, 74 | "SVHN:top1": 0.5844729563614014, 75 | "SUN397:top1": 0.6827707808564232, 76 | "STL10:top1": 0.993625, 77 | "OxfordIIITPet:top1": 0.9359498500953939, 78 | "Flowers102:top1": 0.7905350463489998, 79 | "CIFAR100:top1": 0.7582, 80 | "PCAM:top1": 0.5120849609375, 81 | "FER2013:top1": 0.3815826135413764, 82 | "CIFAR10:top1": 0.9557, 83 | "Food101:top1": 0.9232475247524753, 84 | "FashionMNIST:top1": 0.6694, 85 | "RenderedSST2:top1": 0.6891817682591982, 86 | "EMNIST:top1": 0.15635, 87 | "KMNIST:top1": 0.1039, 88 | "CarsVal:normalized_top1": 0.8571428571428572, 89 | "DTDVal:normalized_top1": 0.6550632911392404, 90 | "EuroSATVal:normalized_top1": 0.6102903531633104, 91 | "GTSRBVal:normalized_top1": 0.5084491175366128, 92 | "MNISTVal:normalized_top1": 0.773319959879639, 93 | "RESISC45Val:normalized_top1": 0.7580206634040239, 94 | "SVHNVal:normalized_top1": 0.5265322912381736, 95 | "SUN397Val:normalized_top1": 0.8325183374083129, 96 | "STL10Val:normalized_top1": 0.9979879275653923, 97 | "OxfordIIITPetVal:normalized_top1": 0.9773371104815864, 98 | "Flowers102Val:normalized_top1": 0.8118811881188119, 99 | "CIFAR100Val:normalized_top1": 0.8108454312553374, 100 | "PCAMVal:normalized_top1": 0.5057119543043656, 101 | "FER2013Val:normalized_top1": 0.5040540540540541, 102 | "CIFAR10Val:normalized_top1": 0.9642640823743186, 103 | "Food101Val:normalized_top1": 0.9578766387277026, 104 | "FashionMNISTVal:normalized_top1": 0.6937068785281204, 105 | "RenderedSST2Val:normalized_top1": 0.8248686514886164, 106 | "EMNISTVal:normalized_top1": 0.1537382240930046, 107 | "KMNISTVal:normalized_top1": 0.09043513134148788, 108 | "avg_normalized_top1": 0.6907022071622484, 109 | "avg_top1": 0.6490359898505231 110 | } 111 | }, 112 | "val_best": { 113 | "CarsVal:top1": 0.7960687960687961, 114 | "DTDVal:top1": 0.550531914893617, 115 | "EuroSATVal:top1": 0.6059957173447538, 116 | "GTSRBVal:top1": 0.5082582582582582, 117 | "MNISTVal:top1": 0.771, 118 | "RESISC45Val:top1": 0.7375661375661375, 119 | "SVHNVal:top1": 0.512, 120 | "SUN397Val:top1": 0.6861460957178841, 121 | "STL10Val:top1": 0.992, 122 | "OxfordIIITPetVal:top1": 0.9375, 123 | "Flowers102Val:top1": 0.803921568627451, 124 | "CIFAR100Val:top1": 0.7596, 125 | "PCAMVal:top1": 0.4958, 126 | "FER2013Val:top1": 0.38989547038327527, 127 | "CIFAR10Val:top1": 0.9552, 128 | "Food101Val:top1": 0.8914, 129 | "FashionMNISTVal:top1": 0.6636, 130 | "RenderedSST2Val:top1": 0.680635838150289, 131 | "EMNISTVal:top1": 0.1534, 132 | "KMNISTVal:top1": 0.0902, 133 | "Cars:top1": 0.7775152344235792, 134 | "DTD:top1": 0.5531914893617021, 135 | "EuroSAT:top1": 0.6099706744868035, 136 | "GTSRB:top1": 0.5055423594615994, 137 | "MNIST:top1": 0.7636, 138 | "RESISC45:top1": 0.7104761904761905, 139 | "SVHN:top1": 0.5844729563614014, 140 | "SUN397:top1": 0.6827707808564232, 141 | "STL10:top1": 0.993625, 142 | "OxfordIIITPet:top1": 0.9359498500953939, 143 | "Flowers102:top1": 0.7905350463489998, 144 | "CIFAR100:top1": 0.7582, 145 | "PCAM:top1": 0.5120849609375, 146 | "FER2013:top1": 0.3815826135413764, 
147 | "CIFAR10:top1": 0.9557, 148 | "Food101:top1": 0.9232475247524753, 149 | "FashionMNIST:top1": 0.6694, 150 | "RenderedSST2:top1": 0.6891817682591982, 151 | "EMNIST:top1": 0.15635, 152 | "KMNIST:top1": 0.1039, 153 | "CarsVal:normalized_top1": 0.8571428571428572, 154 | "DTDVal:normalized_top1": 0.6550632911392404, 155 | "EuroSATVal:normalized_top1": 0.6102903531633104, 156 | "GTSRBVal:normalized_top1": 0.5084491175366128, 157 | "MNISTVal:normalized_top1": 0.773319959879639, 158 | "RESISC45Val:normalized_top1": 0.7580206634040239, 159 | "SVHNVal:normalized_top1": 0.5265322912381736, 160 | "SUN397Val:normalized_top1": 0.8325183374083129, 161 | "STL10Val:normalized_top1": 0.9979879275653923, 162 | "OxfordIIITPetVal:normalized_top1": 0.9773371104815864, 163 | "Flowers102Val:normalized_top1": 0.8118811881188119, 164 | "CIFAR100Val:normalized_top1": 0.8108454312553374, 165 | "PCAMVal:normalized_top1": 0.5057119543043656, 166 | "FER2013Val:normalized_top1": 0.5040540540540541, 167 | "CIFAR10Val:normalized_top1": 0.9642640823743186, 168 | "Food101Val:normalized_top1": 0.9578766387277026, 169 | "FashionMNISTVal:normalized_top1": 0.6937068785281204, 170 | "RenderedSST2Val:normalized_top1": 0.8248686514886164, 171 | "EMNISTVal:normalized_top1": 0.1537382240930046, 172 | "KMNISTVal:normalized_top1": 0.09043513134148788, 173 | "avg_normalized_top1": 0.6907022071622484, 174 | "avg_top1": 0.6490359898505231 175 | } 176 | } -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import get_dataloader, maybe_dictionarize 2 | from .registry import get_dataset 3 | 4 | __all__ = ["get_dataloader", "maybe_dictionarize", "get_dataset"] 5 | -------------------------------------------------------------------------------- /src/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from typing import Any, Callable, Optional, Tuple 4 | 5 | import torch 6 | import torchvision.datasets as datasets 7 | from PIL import Image 8 | from torchvision.datasets.utils import ( 9 | download_and_extract_archive, 10 | download_url, 11 | verify_str_arg, 12 | ) 13 | from torchvision.datasets.vision import VisionDataset 14 | from src.utils.variables_and_paths import DATA_DIR 15 | 16 | 17 | class PytorchStanfordCars(VisionDataset): 18 | """`Stanford Cars `_ Dataset 19 | 20 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 21 | split into 8,144 training images and 8,041 testing images, where each class 22 | has been split roughly in a 50-50 split 23 | 24 | .. note:: 25 | 26 | This class needs `scipy `_ to load target files from `.mat` format. 27 | 28 | Args: 29 | root (string): Root directory of dataset 30 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 31 | transform (callable, optional): A function/transform that takes in an PIL image 32 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 33 | target_transform (callable, optional): A function/transform that takes in the 34 | target and transforms it. 35 | download (bool, optional): If True, downloads the dataset from the internet and 36 | puts it in root directory. 
If dataset is already downloaded, it is not 37 | downloaded again.""" 38 | 39 | def __init__( 40 | self, 41 | root: str, 42 | split: str = "train", 43 | transform: Optional[Callable] = None, 44 | target_transform: Optional[Callable] = None, 45 | download: bool = False, 46 | ) -> None: 47 | 48 | try: 49 | import scipy.io as sio 50 | except ImportError: 51 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 52 | 53 | super().__init__(root, transform=transform, target_transform=target_transform) 54 | 55 | self._split = verify_str_arg(split, "split", ("train", "test")) 56 | self._base_folder = pathlib.Path(root) / "stanford_cars" 57 | devkit = self._base_folder / "devkit" 58 | 59 | if self._split == "train": 60 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 61 | self._images_base_path = self._base_folder / "cars_train" 62 | else: 63 | self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" 64 | self._images_base_path = self._base_folder / "cars_test" 65 | 66 | if download: 67 | self.download() 68 | 69 | if not self._check_exists(): 70 | raise RuntimeError("Dataset not found. You can use download=True to download it") 71 | 72 | self._samples = [ 73 | ( 74 | str(self._images_base_path / annotation["fname"]), 75 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 76 | ) 77 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 78 | ] 79 | 80 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 81 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 82 | 83 | def __len__(self) -> int: 84 | return len(self._samples) 85 | 86 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 87 | """Returns pil_image and class_id for given index""" 88 | image_path, target = self._samples[idx] 89 | pil_image = Image.open(image_path).convert("RGB") 90 | 91 | if self.transform is not None: 92 | pil_image = self.transform(pil_image) 93 | if self.target_transform is not None: 94 | target = self.target_transform(target) 95 | return pil_image, target 96 | 97 | def download(self) -> None: 98 | if self._check_exists(): 99 | return 100 | 101 | download_and_extract_archive( 102 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 103 | download_root=str(self._base_folder), 104 | md5="c3b158d763b6e2245038c8ad08e45376", 105 | ) 106 | if self._split == "train": 107 | download_and_extract_archive( 108 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 109 | download_root=str(self._base_folder), 110 | md5="065e5b463ae28d29e77c1b4b166cfe61", 111 | ) 112 | else: 113 | download_and_extract_archive( 114 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 115 | download_root=str(self._base_folder), 116 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 117 | ) 118 | download_url( 119 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 120 | root=str(self._base_folder), 121 | md5="b0a2b23655a3edd16d84508592a98d10", 122 | ) 123 | 124 | def _check_exists(self) -> bool: 125 | if not (self._base_folder / "devkit").is_dir(): 126 | return False 127 | 128 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 129 | 130 | 131 | class Cars: 132 | def __init__(self, preprocess, location=DATA_DIR, batch_size=32, num_workers=16): 133 | # Data loading code 134 | 135 | self.train_dataset = PytorchStanfordCars(location, "train", preprocess, download=True) 
136 | self.train_loader = torch.utils.data.DataLoader( 137 | self.train_dataset, 138 | shuffle=True, 139 | batch_size=batch_size, 140 | num_workers=num_workers, 141 | ) 142 | 143 | self.test_dataset = PytorchStanfordCars(location, "test", preprocess, download=True) 144 | self.test_loader = torch.utils.data.DataLoader( 145 | self.test_dataset, batch_size=batch_size, num_workers=num_workers 146 | ) 147 | idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items()) 148 | self.classnames = [idx_to_class[i].replace("_", " ") for i in range(len(idx_to_class))] 149 | -------------------------------------------------------------------------------- /src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 8 | from torchvision.datasets import VisionDataset 9 | from src.utils.variables_and_paths import DATA_DIR 10 | 11 | cifar_classnames = [ 12 | "airplane", 13 | "automobile", 14 | "bird", 15 | "cat", 16 | "deer", 17 | "dog", 18 | "frog", 19 | "horse", 20 | "ship", 21 | "truck", 22 | ] 23 | 24 | 25 | class CIFAR10: 26 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 27 | 28 | self.train_dataset = PyTorchCIFAR10(root=location, download=True, train=True, transform=preprocess) 29 | 30 | self.train_loader = torch.utils.data.DataLoader( 31 | self.train_dataset, 32 | batch_size=batch_size, 33 | shuffle=True, 34 | num_workers=num_workers, 35 | ) 36 | 37 | self.test_dataset = PyTorchCIFAR10(root=location, download=True, train=False, transform=preprocess) 38 | 39 | self.test_loader = torch.utils.data.DataLoader( 40 | self.test_dataset, 41 | batch_size=batch_size, 42 | shuffle=False, 43 | num_workers=num_workers, 44 | ) 45 | 46 | self.classnames = self.test_dataset.classes 47 | 48 | 49 | def convert(x): 50 | if isinstance(x, np.ndarray): 51 | return torchvision.transforms.functional.to_pil_image(x) 52 | return x 53 | 54 | 55 | class BasicVisionDataset(VisionDataset): 56 | def __init__(self, images, targets, transform=None, target_transform=None): 57 | if transform is not None: 58 | transform.transforms.insert(0, convert) 59 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) 60 | assert len(images) == len(targets) 61 | 62 | self.images = images 63 | self.targets = targets 64 | 65 | def __getitem__(self, index): 66 | return self.transform(self.images[index]), self.targets[index] 67 | 68 | def __len__(self): 69 | return len(self.targets) 70 | -------------------------------------------------------------------------------- /src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class CIFAR100: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | 10 | self.train_dataset = PyTorchCIFAR100(root=location, download=True, train=True, transform=preprocess) 11 | 12 | self.train_loader = torch.utils.data.DataLoader( 13 | self.train_dataset, batch_size=batch_size, num_workers=num_workers 14 | ) 15 | 16 | self.test_dataset = PyTorchCIFAR100(root=location, download=True, train=False, transform=preprocess) 17 | 18 | 
self.test_loader = torch.utils.data.DataLoader( 19 | self.test_dataset, 20 | batch_size=batch_size, 21 | shuffle=False, 22 | num_workers=num_workers, 23 | ) 24 | 25 | self.classnames = self.test_dataset.classes 26 | -------------------------------------------------------------------------------- /src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | 27 | class ImageFolderWithPaths(datasets.ImageFolder): 28 | def __init__(self, path, transform, flip_label_prob=0.0): 29 | super().__init__(path, transform) 30 | self.flip_label_prob = flip_label_prob 31 | if self.flip_label_prob > 0: 32 | print(f"Flipping labels with probability {self.flip_label_prob}") 33 | num_classes = len(self.classes) 34 | for i in range(len(self.samples)): 35 | if random.random() < self.flip_label_prob: 36 | new_label = random.randint(0, num_classes - 1) 37 | self.samples[i] = (self.samples[i][0], new_label) 38 | 39 | def __getitem__(self, index): 40 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 41 | return {"images": image, "labels": label, "image_paths": self.samples[index][0]} 42 | 43 | 44 | def maybe_dictionarize(batch): 45 | if isinstance(batch, dict): 46 | return batch 47 | 48 | if len(batch) == 2: 49 | batch = {"images": batch[0], "labels": batch[1]} 50 | elif len(batch) == 3: 51 | batch = {"images": batch[0], "labels": batch[1], "metadata": batch[2]} 52 | else: 53 | raise ValueError(f"Unexpected number of elements: {len(batch)}") 54 | 55 | return batch 56 | 57 | 58 | def get_features_helper(image_encoder, dataloader, device): 59 | all_data = collections.defaultdict(list) 60 | 61 | image_encoder = image_encoder.to(device) 62 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())]) 63 | image_encoder.eval() 64 | 65 | with torch.no_grad(): 66 | for batch in tqdm(dataloader): 67 | batch = maybe_dictionarize(batch) 68 | features = image_encoder(batch["images"].cuda()) 69 | 70 | all_data["features"].append(features.cpu()) 71 | 72 | for key, val in batch.items(): 73 | if key == "images": 74 | continue 75 | if hasattr(val, "cpu"): 76 | val = val.cpu() 77 | all_data[key].append(val) 78 | else: 79 | all_data[key].extend(val) 80 | 81 | for key, val in all_data.items(): 82 | if torch.is_tensor(val[0]): 83 | all_data[key] = torch.cat(val).numpy() 84 | 85 | return all_data 86 | 87 | 88 | def get_features(is_train, image_encoder, dataset, device): 89 | split = "train" if is_train else "val" 90 | dname = type(dataset).__name__ 91 | if image_encoder.cache_dir is not None: 92 | cache_dir = f"{image_encoder.cache_dir}/{dname}/{split}" 93 | cached_files = glob.glob(f"{cache_dir}/*") 94 | if image_encoder.cache_dir is not None and len(cached_files) > 0: 95 | print(f"Getting features from {cache_dir}") 96 | data = {} 97 | for cached_file in cached_files: 98 | name = os.path.splitext(os.path.basename(cached_file))[0] 99 | data[name] = torch.load(cached_file) 100 | else: 101 | print(f"Did not find 
cached features at {cache_dir}. Building from scratch.") 102 | loader = dataset.train_loader if is_train else dataset.test_loader 103 | data = get_features_helper(image_encoder, loader, device) 104 | if image_encoder.cache_dir is None: 105 | print("Not caching because no cache directory was passed.") 106 | else: 107 | os.makedirs(cache_dir, exist_ok=True) 108 | print(f"Caching data at {cache_dir}") 109 | for name, val in data.items(): 110 | torch.save(val, f"{cache_dir}/{name}.pt") 111 | return data 112 | 113 | 114 | class FeatureDataset(Dataset): 115 | def __init__(self, is_train, image_encoder, dataset, device): 116 | self.data = get_features(is_train, image_encoder, dataset, device) 117 | 118 | def __len__(self): 119 | return len(self.data["features"]) 120 | 121 | def __getitem__(self, idx): 122 | data = {k: v[idx] for k, v in self.data.items()} 123 | data["features"] = torch.from_numpy(data["features"]).float() 124 | return data 125 | 126 | 127 | def get_dataloader(dataset, is_train, args, image_encoder=None): 128 | if image_encoder is not None: 129 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) 130 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 131 | else: 132 | dataloader = dataset.train_loader if is_train else dataset.test_loader 133 | return dataloader 134 | -------------------------------------------------------------------------------- /src/datasets/country211.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class Country211: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | 11 | location = os.path.join("~/data", "Country211") 12 | self.train_dataset = datasets.Country211(root=location, download=True, split="train", transform=preprocess) 13 | 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.Country211(root=location, download=True, split="test", transform=preprocess) 22 | 23 | self.test_loader = torch.utils.data.DataLoader( 24 | self.test_dataset, 25 | batch_size=batch_size, 26 | shuffle=False, 27 | num_workers=num_workers, 28 | ) 29 | 30 | self.classnames = self.train_dataset.classes 31 | -------------------------------------------------------------------------------- /src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class DTD: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=32, num_workers=16): 9 | # Data loading code 10 | traindir = os.path.join(location, "dtd", "train") 11 | valdir = os.path.join(location, "dtd", "val") 12 | 13 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | shuffle=True, 17 | batch_size=batch_size, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 22 | self.test_loader = torch.utils.data.DataLoader( 23 | self.test_dataset, batch_size=batch_size, num_workers=num_workers 24 | ) 25 | idx_to_class = dict((v, k) for k, v in 
self.train_dataset.class_to_idx.items()) 26 | self.classnames = [idx_to_class[i].replace("_", " ") for i in range(len(idx_to_class))] 27 | -------------------------------------------------------------------------------- /src/datasets/emnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class EMNIST: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | location = os.path.join(location, "EMNIST") 11 | self.train_dataset = datasets.EMNIST( 12 | root=location, 13 | download=True, 14 | split="digits", 15 | transform=preprocess, 16 | train=True, 17 | ) 18 | 19 | self.train_loader = torch.utils.data.DataLoader( 20 | self.train_dataset, 21 | batch_size=batch_size, 22 | shuffle=True, 23 | num_workers=num_workers, 24 | ) 25 | 26 | self.test_dataset = datasets.EMNIST( 27 | root=location, 28 | download=True, 29 | split="digits", 30 | transform=preprocess, 31 | train=False, 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers, 39 | ) 40 | 41 | self.classnames = self.train_dataset.classes 42 | -------------------------------------------------------------------------------- /src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | def pretify_classname(classname): 9 | l = re.findall(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))", classname) 10 | l = [i.lower() for i in l] 11 | out = " ".join(l) 12 | if out.endswith("al"): 13 | return out + " area" 14 | return out 15 | 16 | 17 | class EuroSATBase: 18 | def __init__(self, preprocess, test_split, location=DATA_DIR, batch_size=32, num_workers=16): 19 | # Data loading code 20 | traindir = os.path.join(location, "EuroSAT_splits", "train") 21 | testdir = os.path.join(location, "EuroSAT_splits", test_split) 22 | 23 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 24 | self.train_loader = torch.utils.data.DataLoader( 25 | self.train_dataset, 26 | shuffle=True, 27 | batch_size=batch_size, 28 | num_workers=num_workers, 29 | ) 30 | 31 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) 32 | self.test_loader = torch.utils.data.DataLoader( 33 | self.test_dataset, batch_size=batch_size, num_workers=num_workers 34 | ) 35 | idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items()) 36 | self.classnames = [idx_to_class[i].replace("_", " ") for i in range(len(idx_to_class))] 37 | self.classnames = [pretify_classname(c) for c in self.classnames] 38 | ours_to_open_ai = { 39 | "annual crop": "annual crop land", 40 | "forest": "forest", 41 | "herbaceous vegetation": "brushland or shrubland", 42 | "highway": "highway or road", 43 | "industrial area": "industrial buildings or commercial buildings", 44 | "pasture": "pasture land", 45 | "permanent crop": "permanent crop land", 46 | "residential area": "residential buildings or homes or apartments", 47 | "river": "river", 48 | "sea lake": "lake or sea", 49 | } 50 | for i in range(len(self.classnames)): 51 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 52 | 53 | 54 | class EuroSAT(EuroSATBase): 55 | def __init__(self, preprocess, 
location="~/datasets", batch_size=32, num_workers=16): 56 | super().__init__(preprocess, "test", location, batch_size, num_workers) 57 | 58 | 59 | class EuroSATVal(EuroSATBase): 60 | def __init__(self, preprocess, location="~/datasets", batch_size=32, num_workers=16): 61 | super().__init__(preprocess, "val", location, batch_size, num_workers) 62 | -------------------------------------------------------------------------------- /src/datasets/fashionmnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class FashionMNIST: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | 11 | location = os.path.join(location, "FashionMNIST") 12 | self.train_dataset = datasets.FashionMNIST(root=location, download=True, train=True, transform=preprocess) 13 | 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.FashionMNIST(root=location, download=True, train=False, transform=preprocess) 22 | 23 | self.test_loader = torch.utils.data.DataLoader( 24 | self.test_dataset, 25 | batch_size=batch_size, 26 | shuffle=False, 27 | num_workers=num_workers, 28 | ) 29 | 30 | self.classnames = self.train_dataset.classes 31 | -------------------------------------------------------------------------------- /src/datasets/fer2013.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import DataLoader, Dataset 7 | from torchvision import transforms 8 | 9 | from datasets import load_dataset 10 | from src.utils.variables_and_paths import DATA_DIR 11 | 12 | 13 | class CustomFER2013Dataset(Dataset): 14 | def __init__(self, hf_dataset, transform=None): 15 | self.hf_dataset = hf_dataset 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.hf_dataset) 20 | 21 | def __getitem__(self, idx): 22 | sample = self.hf_dataset[idx] 23 | image = Image.open(io.BytesIO(sample["img_bytes"])).convert("L") # Convert to PIL image 24 | label = sample["labels"] 25 | 26 | if self.transform: 27 | image = self.transform(image) 28 | 29 | return image, label 30 | 31 | 32 | class FER2013: 33 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 34 | 35 | # location = os.path.join("~/data", "FER2013") 36 | 37 | # Load the FER2013 dataset using Hugging Face datasets library 38 | fer2013 = load_dataset("Jeneral/fer-2013", split="train") 39 | 40 | # Instantiate the custom PyTorch training dataset 41 | self.train_dataset = CustomFER2013Dataset(fer2013, transform=preprocess) 42 | 43 | # Use PyTorch DataLoader to create an iterator over training batches 44 | self.train_loader = DataLoader( 45 | self.train_dataset, 46 | batch_size=batch_size, 47 | shuffle=True, 48 | num_workers=num_workers, 49 | ) 50 | 51 | # Load the FER2013 test dataset using Hugging Face datasets library 52 | fer2013_test = load_dataset("Jeneral/fer-2013", split="test") 53 | 54 | # Instantiate the custom PyTorch test dataset 55 | self.test_dataset = CustomFER2013Dataset(fer2013_test, transform=preprocess) 56 | 57 | # Use PyTorch DataLoader to create an iterator over test batches 58 | self.test_loader = DataLoader( 59 | self.test_dataset, 60 | 
batch_size=batch_size, 61 | shuffle=False, 62 | num_workers=num_workers, 63 | ) 64 | 65 | self.classnames = [ 66 | ["angry"], 67 | ["disgusted"], 68 | ["fearful"], 69 | ["happy", "smiling"], 70 | ["sad", "depressed"], 71 | ["surprised", "shocked", "spooked"], 72 | ["neutral", "bored"], 73 | ] 74 | -------------------------------------------------------------------------------- /src/datasets/flowers102.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class Flowers102: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | 10 | location = os.path.join(location, "flowers102") 11 | self.train_dataset = datasets.Flowers102(root=location, download=True, split="train", transform=preprocess) 12 | 13 | self.train_loader = torch.utils.data.DataLoader( 14 | self.train_dataset, 15 | batch_size=batch_size, 16 | shuffle=True, 17 | num_workers=num_workers, 18 | ) 19 | 20 | self.test_dataset = datasets.Flowers102(root=location, download=True, split="test", transform=preprocess) 21 | 22 | self.test_loader = torch.utils.data.DataLoader( 23 | self.test_dataset, 24 | batch_size=batch_size, 25 | shuffle=False, 26 | num_workers=num_workers, 27 | ) 28 | 29 | self.classnames = [ 30 | "pink primrose", 31 | "hard-leaved pocket orchid", 32 | "canterbury bells", 33 | "sweet pea", 34 | "english marigold", 35 | "tiger lily", 36 | "moon orchid", 37 | "bird of paradise", 38 | "monkshood", 39 | "globe thistle", 40 | "snapdragon", 41 | "colt's foot", 42 | "king protea", 43 | "spear thistle", 44 | "yellow iris", 45 | "globe flower", 46 | "purple coneflower", 47 | "peruvian lily", 48 | "balloon flower", 49 | "giant white arum lily", 50 | "fire lily", 51 | "pincushion flower", 52 | "fritillary", 53 | "red ginger", 54 | "grape hyacinth", 55 | "corn poppy", 56 | "prince of wales feathers", 57 | "stemless gentian", 58 | "artichoke", 59 | "sweet william", 60 | "carnation", 61 | "garden phlox", 62 | "love in the mist", 63 | "mexican aster", 64 | "alpine sea holly", 65 | "ruby-lipped cattleya", 66 | "cape flower", 67 | "great masterwort", 68 | "siam tulip", 69 | "lenten rose", 70 | "barbeton daisy", 71 | "daffodil", 72 | "sword lily", 73 | "poinsettia", 74 | "bolero deep blue", 75 | "wallflower", 76 | "marigold", 77 | "buttercup", 78 | "oxeye daisy", 79 | "common dandelion", 80 | "petunia", 81 | "wild pansy", 82 | "primula", 83 | "sunflower", 84 | "pelargonium", 85 | "bishop of llandaff", 86 | "gaura", 87 | "geranium", 88 | "orange dahlia", 89 | "pink and yellow dahlia", 90 | "cautleya spicata", 91 | "japanese anemone", 92 | "black-eyed susan", 93 | "silverbush", 94 | "californian poppy", 95 | "osteospermum", 96 | "spring crocus", 97 | "bearded iris", 98 | "windflower", 99 | "tree poppy", 100 | "gazania", 101 | "azalea", 102 | "water lily", 103 | "rose", 104 | "thorn apple", 105 | "morning glory", 106 | "passion flower", 107 | "lotus", 108 | "toad lily", 109 | "anthurium", 110 | "frangipani", 111 | "clematis", 112 | "hibiscus", 113 | "columbine", 114 | "desert-rose", 115 | "tree mallow", 116 | "magnolia", 117 | "cyclamen", 118 | "watercress", 119 | "canna lily", 120 | "hippeastrum", 121 | "bee balm", 122 | "air plant", 123 | "foxglove", 124 | "bougainvillea", 125 | "camellia", 126 | "mallow", 127 | "mexican petunia", 128 | "bromelia", 129 | "blanket flower", 130 | "trumpet creeper", 131 | "blackberry lily", 132 | ] 133 | 
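The hand-written class-name list above (torchvision's Flowers102 exposes only integer labels) exists so that a CLIP-style zero-shot text head can be built for the dataset. The sketch below only illustrates that pattern and is not code from this repository: the actual prompt templates live in src/datasets/templates.py and the head construction in src/models/heads.py; flowers_template here is a hypothetical stand-in.

# Illustration only: turning a classnames list plus prompt templates (in the style
# of src/datasets/templates.py) into per-class text prompts for a CLIP text encoder.
flowers_template = [
    lambda c: f"a photo of a {c}.",
    lambda c: f"a close-up photo of a {c}, a type of flower.",
]


def build_prompts(classnames, templates):
    # One list of prompt strings per class; such lists are typically encoded by the
    # text tower and averaged into a single classifier weight per class.
    return [[t(c) for t in templates] for c in classnames]


prompts = build_prompts(["pink primrose", "sunflower"], flowers_template)
# prompts[0] == ['a photo of a pink primrose.',
#                'a close-up photo of a pink primrose, a type of flower.']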
-------------------------------------------------------------------------------- /src/datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class Food101: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | # 10 | location = os.path.join(location, "food101") 11 | self.train_dataset = datasets.Food101(root=location, download=True, split="train", transform=preprocess) 12 | 13 | self.train_loader = torch.utils.data.DataLoader( 14 | self.train_dataset, 15 | batch_size=batch_size, 16 | shuffle=True, 17 | num_workers=num_workers, 18 | ) 19 | 20 | self.test_dataset = datasets.Food101(root=location, download=True, split="test", transform=preprocess) 21 | 22 | self.test_loader = torch.utils.data.DataLoader( 23 | self.test_dataset, 24 | batch_size=batch_size, 25 | shuffle=False, 26 | num_workers=num_workers, 27 | ) 28 | 29 | self.classnames = self.train_dataset.classes 30 | -------------------------------------------------------------------------------- /src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 11 | from torchvision.datasets.vision import VisionDataset 12 | from src.utils.variables_and_paths import DATA_DIR 13 | 14 | 15 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 16 | """Finds the class folders in a dataset. 17 | 18 | See :class:`DatasetFolder` for details. 19 | """ 20 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 21 | if not classes: 22 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 23 | 24 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 25 | return classes, class_to_idx 26 | 27 | 28 | class PyTorchGTSRB(VisionDataset): 29 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. 30 | 31 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 32 | 33 | Args: 34 | root (string): Root directory of the dataset. 35 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 36 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 37 | version. E.g, ``transforms.RandomCrop``. 38 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 39 | download (bool, optional): If True, downloads the dataset from the internet and 40 | puts it in root directory. If dataset is already downloaded, it is not 41 | downloaded again. 
42 | """ 43 | 44 | def __init__( 45 | self, 46 | root: str, 47 | split: str = "train", 48 | transform: Optional[Callable] = None, 49 | target_transform: Optional[Callable] = None, 50 | download: bool = False, 51 | ) -> None: 52 | 53 | super().__init__(root, transform=transform, target_transform=target_transform) 54 | 55 | self._split = verify_str_arg(split, "split", ("train", "test")) 56 | self._base_folder = pathlib.Path(root) / "gtsrb" 57 | self._target_folder = ( 58 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 59 | ) 60 | 61 | if download: 62 | self.download() 63 | 64 | if not self._check_exists(): 65 | raise RuntimeError("Dataset not found. You can use download=True to download it") 66 | 67 | if self._split == "train": 68 | _, class_to_idx = find_classes(str(self._target_folder)) 69 | samples = make_dataset( 70 | str(self._target_folder), 71 | extensions=(".ppm",), 72 | class_to_idx=class_to_idx, 73 | ) 74 | else: 75 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 76 | samples = [ 77 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 78 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 79 | ] 80 | 81 | self._samples = samples 82 | self.transform = transform 83 | self.target_transform = target_transform 84 | 85 | def __len__(self) -> int: 86 | return len(self._samples) 87 | 88 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 89 | 90 | path, target = self._samples[index] 91 | sample = PIL.Image.open(path).convert("RGB") 92 | 93 | if self.transform is not None: 94 | sample = self.transform(sample) 95 | 96 | if self.target_transform is not None: 97 | target = self.target_transform(target) 98 | 99 | return sample, target 100 | 101 | def _check_exists(self) -> bool: 102 | return self._target_folder.is_dir() 103 | 104 | def download(self) -> None: 105 | if self._check_exists(): 106 | return 107 | 108 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 109 | 110 | if self._split == "train": 111 | download_and_extract_archive( 112 | f"{base_url}GTSRB-Training_fixed.zip", 113 | download_root=str(self._base_folder), 114 | md5="513f3c79a4c5141765e10e952eaa2478", 115 | ) 116 | else: 117 | download_and_extract_archive( 118 | f"{base_url}GTSRB_Final_Test_Images.zip", 119 | download_root=str(self._base_folder), 120 | md5="c7e4e6327067d32654124b0fe9e82185", 121 | ) 122 | download_and_extract_archive( 123 | f"{base_url}GTSRB_Final_Test_GT.zip", 124 | download_root=str(self._base_folder), 125 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 126 | ) 127 | 128 | 129 | class GTSRB: 130 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 131 | location = os.path.join(location, "gtsrb") 132 | 133 | # to fit with repo conventions for location 134 | self.train_dataset = PyTorchGTSRB(root=location, download=True, split="train", transform=preprocess) 135 | 136 | self.train_loader = torch.utils.data.DataLoader( 137 | self.train_dataset, 138 | batch_size=batch_size, 139 | shuffle=True, 140 | num_workers=num_workers, 141 | ) 142 | 143 | self.test_dataset = PyTorchGTSRB(root=location, download=True, split="test", transform=preprocess) 144 | 145 | self.test_loader = torch.utils.data.DataLoader( 146 | self.test_dataset, 147 | batch_size=batch_size, 148 | shuffle=False, 149 | num_workers=num_workers, 150 | ) 151 | 152 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 153 | self.classnames = [ 
154 | "red and white circle 20 kph speed limit", 155 | "red and white circle 30 kph speed limit", 156 | "red and white circle 50 kph speed limit", 157 | "red and white circle 60 kph speed limit", 158 | "red and white circle 70 kph speed limit", 159 | "red and white circle 80 kph speed limit", 160 | "end / de-restriction of 80 kph speed limit", 161 | "red and white circle 100 kph speed limit", 162 | "red and white circle 120 kph speed limit", 163 | "red and white circle red car and black car no passing", 164 | "red and white circle red truck and black car no passing", 165 | "red and white triangle road intersection warning", 166 | "white and yellow diamond priority road", 167 | "red and white upside down triangle yield right-of-way", 168 | "stop", 169 | "empty red and white circle", 170 | "red and white circle no truck entry", 171 | "red circle with white horizonal stripe no entry", 172 | "red and white triangle with exclamation mark warning", 173 | "red and white triangle with black left curve approaching warning", 174 | "red and white triangle with black right curve approaching warning", 175 | "red and white triangle with black double curve approaching warning", 176 | "red and white triangle rough / bumpy road warning", 177 | "red and white triangle car skidding / slipping warning", 178 | "red and white triangle with merging / narrow lanes warning", 179 | "red and white triangle with person digging / construction / road work warning", 180 | "red and white triangle with traffic light approaching warning", 181 | "red and white triangle with person walking warning", 182 | "red and white triangle with child and person walking warning", 183 | "red and white triangle with bicyle warning", 184 | "red and white triangle with snowflake / ice warning", 185 | "red and white triangle with deer warning", 186 | "white circle with gray strike bar no speed limit", 187 | "blue circle with white right turn arrow mandatory", 188 | "blue circle with white left turn arrow mandatory", 189 | "blue circle with white forward arrow mandatory", 190 | "blue circle with white forward or right turn arrow mandatory", 191 | "blue circle with white forward or left turn arrow mandatory", 192 | "blue circle with white keep right arrow mandatory", 193 | "blue circle with white keep left arrow mandatory", 194 | "blue circle with white arrows indicating a traffic circle", 195 | "white circle with gray strike bar indicating no passing for cars has ended", 196 | "white circle with gray strike bar indicating no passing for trucks has ended", 197 | ] 198 | -------------------------------------------------------------------------------- /src/datasets/kmnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class KMNIST: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | 11 | location = os.path.join("~/data", "KMNIST") 12 | self.train_dataset = datasets.KMNIST(root=location, download=True, train=True, transform=preprocess) 13 | 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.KMNIST(root=location, download=True, train=False, transform=preprocess) 22 | 23 | self.test_loader = torch.utils.data.DataLoader( 24 | self.test_dataset, 25 | batch_size=batch_size, 26 | 
shuffle=False, 27 | num_workers=num_workers, 28 | ) 29 | 30 | self.classnames = self.train_dataset.classes 31 | -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class MNIST: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | self.train_dataset = datasets.MNIST(root=location, download=True, train=True, transform=preprocess) 10 | 11 | self.train_loader = torch.utils.data.DataLoader( 12 | self.train_dataset, 13 | batch_size=batch_size, 14 | shuffle=True, 15 | num_workers=num_workers, 16 | ) 17 | 18 | self.test_dataset = datasets.MNIST(root=location, download=True, train=False, transform=preprocess) 19 | 20 | self.test_loader = torch.utils.data.DataLoader( 21 | self.test_dataset, 22 | batch_size=batch_size, 23 | shuffle=False, 24 | num_workers=num_workers, 25 | ) 26 | 27 | self.classnames = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 28 | -------------------------------------------------------------------------------- /src/datasets/oxfordpets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class OxfordIIITPet: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | 10 | location = os.path.join(location, "OxfordIIITPet") 11 | self.train_dataset = datasets.OxfordIIITPet( 12 | root=location, download=True, split="trainval", transform=preprocess 13 | ) 14 | 15 | self.train_loader = torch.utils.data.DataLoader( 16 | self.train_dataset, 17 | batch_size=batch_size, 18 | shuffle=True, 19 | num_workers=num_workers, 20 | ) 21 | 22 | self.test_dataset = datasets.OxfordIIITPet(root=location, download=True, split="test", transform=preprocess) 23 | 24 | self.test_loader = torch.utils.data.DataLoader( 25 | self.test_dataset, 26 | batch_size=batch_size, 27 | shuffle=False, 28 | num_workers=num_workers, 29 | ) 30 | 31 | self.classnames = self.train_dataset.classes 32 | -------------------------------------------------------------------------------- /src/datasets/pcam.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class PCAM: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | # # location = os.path.join("~/data", "PCAM") 11 | location = os.path.join(location, "PCAM") 12 | self.train_dataset = datasets.PCAM(root=location, download=True, split="train", transform=preprocess) 13 | 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.PCAM(root=location, download=True, split="test", transform=preprocess) 22 | 23 | self.test_loader = torch.utils.data.DataLoader( 24 | self.test_dataset, 25 | batch_size=batch_size, 26 | shuffle=False, 27 | num_workers=num_workers, 28 | ) 29 | 30 | self.classnames = [ 31 | "lymph node", 32 | "lymph node containing metastatic tumor tissue", 33 | ] 34 | 
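Every dataset wrapper above shares the same constructor signature (preprocess, location, batch_size, num_workers) and exposes train_dataset/train_loader, test_dataset/test_loader and classnames; the registry in the next file resolves dataset names, including the derived "<Name>Val" splits, to these classes. Below is a minimal usage sketch, not repository code; the open_clip call is an assumption about where the preprocess transform comes from.

import open_clip

from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.utils.variables_and_paths import DATA_DIR

# Assumption: the CLIP visual preprocessing comes from open_clip's OpenAI weights.
_, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")

# "MNISTVal" is carved out of the MNIST training split by split_train_into_train_val.
dataset = get_dataset("MNISTVal", preprocess, location=DATA_DIR, batch_size=128, num_workers=4)

# args is only consulted when an image_encoder is passed for feature caching, so None suffices.
loader = get_dataloader(dataset, is_train=False, args=None)
batch = maybe_dictionarize(next(iter(loader)))
print(batch["images"].shape, batch["labels"].shape, len(dataset.classnames))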
-------------------------------------------------------------------------------- /src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | import random 4 | import sys 5 | 6 | import torch 7 | from torch.utils.data.dataset import random_split 8 | 9 | from src.datasets.cars import Cars 10 | from src.datasets.cifar10 import CIFAR10 11 | from src.datasets.cifar100 import CIFAR100 12 | from src.datasets.dtd import DTD 13 | from src.datasets.emnist import EMNIST 14 | from src.datasets.eurosat import EuroSAT, EuroSATVal 15 | from src.datasets.fashionmnist import FashionMNIST 16 | from src.datasets.fer2013 import FER2013 17 | from src.datasets.flowers102 import Flowers102 18 | from src.datasets.food101 import Food101 19 | from src.datasets.gtsrb import GTSRB 20 | from src.datasets.imagenet import ImageNet 21 | from src.datasets.kmnist import KMNIST 22 | from src.datasets.mnist import MNIST 23 | from src.datasets.oxfordpets import OxfordIIITPet 24 | from src.datasets.pcam import PCAM 25 | from src.datasets.resisc45 import RESISC45 26 | from src.datasets.sst2 import RenderedSST2 27 | from src.datasets.stl10 import STL10 28 | from src.datasets.sun397 import SUN397 29 | from src.datasets.svhn import SVHN 30 | 31 | registry = {name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)} 32 | 33 | 34 | class GenericDataset(object): 35 | def __init__(self): 36 | self.train_dataset = None 37 | self.train_loader = None 38 | self.test_dataset = None 39 | self.test_loader = None 40 | self.classnames = None 41 | 42 | 43 | def split_train_into_train_val( 44 | dataset, 45 | new_dataset_class_name, 46 | batch_size, 47 | num_workers, 48 | val_fraction, 49 | max_val_samples=None, 50 | seed=0, 51 | ): 52 | assert val_fraction > 0.0 and val_fraction < 1.0 53 | total_size = len(dataset.train_dataset) 54 | val_size = int(total_size * val_fraction) 55 | if max_val_samples is not None: 56 | val_size = min(val_size, max_val_samples) 57 | train_size = total_size - val_size 58 | 59 | assert val_size > 0 60 | assert train_size > 0 61 | 62 | lengths = [train_size, val_size] 63 | 64 | trainset, valset = random_split(dataset.train_dataset, lengths, generator=torch.Generator().manual_seed(seed)) 65 | if new_dataset_class_name == "MNISTVal": 66 | assert trainset.indices[0] == 36044 67 | 68 | new_dataset = None 69 | 70 | new_dataset_class = type(new_dataset_class_name, (GenericDataset,), {}) 71 | new_dataset = new_dataset_class() 72 | 73 | new_dataset.train_dataset = trainset 74 | new_dataset.train_loader = torch.utils.data.DataLoader( 75 | new_dataset.train_dataset, 76 | shuffle=True, 77 | batch_size=batch_size, 78 | num_workers=num_workers, 79 | ) 80 | 81 | new_dataset.test_dataset = valset 82 | new_dataset.test_loader = torch.utils.data.DataLoader( 83 | new_dataset.test_dataset, batch_size=batch_size, num_workers=num_workers 84 | ) 85 | 86 | new_dataset.classnames = copy.copy(dataset.classnames) 87 | 88 | return new_dataset 89 | 90 | 91 | def get_dataset( 92 | dataset_name, 93 | preprocess, 94 | location, 95 | batch_size=128, 96 | num_workers=16, 97 | val_fraction=0.1, 98 | max_val_samples=5000, 99 | ): 100 | if dataset_name.endswith("Val"): 101 | # Handle val splits 102 | if dataset_name in registry: 103 | dataset_class = registry[dataset_name] 104 | else: 105 | base_dataset_name = dataset_name.split("Val")[0] 106 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) 107 
| dataset = split_train_into_train_val( 108 | base_dataset, 109 | dataset_name, 110 | batch_size, 111 | num_workers, 112 | val_fraction, 113 | max_val_samples, 114 | ) 115 | return dataset 116 | else: 117 | assert ( 118 | dataset_name in registry 119 | ), f"Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}" 120 | dataset_class = registry[dataset_name] 121 | 122 | dataset = dataset_class(preprocess, location=location, batch_size=batch_size, num_workers=num_workers) 123 | return dataset 124 | -------------------------------------------------------------------------------- /src/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import abc 5 | import os 6 | from typing import Any, Callable, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | from torchvision.datasets import ImageFolder 13 | from torchvision.datasets.folder import default_loader as pil_loader 14 | from src.utils.variables_and_paths import DATA_DIR 15 | 16 | 17 | # modified from: https://github.com/microsoft/torchgeo 18 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): 19 | """Abstract base class for datasets lacking geospatial information. 20 | This base class is designed for datasets with pre-defined image chips. 21 | """ 22 | 23 | @abc.abstractmethod 24 | def __getitem__(self, index: int) -> Dict[str, Any]: 25 | """Return an index within the dataset. 26 | Args: 27 | index: index to return 28 | Returns: 29 | data and labels at that index 30 | Raises: 31 | IndexError: if index is out of range of the dataset 32 | """ 33 | 34 | @abc.abstractmethod 35 | def __len__(self) -> int: 36 | """Return the length of the dataset. 37 | Returns: 38 | length of the dataset 39 | """ 40 | 41 | def __str__(self) -> str: 42 | """Return the informal string representation of the object. 43 | Returns: 44 | informal string representation 45 | """ 46 | return f"""\ 47 | {self.__class__.__name__} Dataset 48 | type: VisionDataset 49 | size: {len(self)}""" 50 | 51 | 52 | class VisionClassificationDataset(VisionDataset, ImageFolder): 53 | """Abstract base class for classification datasets lacking geospatial information. 54 | This base class is designed for datasets with pre-defined image chips which 55 | are separated into separate folders per class. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | root: str, 61 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 62 | loader: Optional[Callable[[str], Any]] = pil_loader, 63 | is_valid_file: Optional[Callable[[str], bool]] = None, 64 | ) -> None: 65 | """Initialize a new VisionClassificationDataset instance. 
66 | Args: 67 | root: root directory where dataset can be found 68 | transforms: a function/transform that takes input sample and its target as 69 | entry and returns a transformed version 70 | loader: a callable function which takes as input a path to an image and 71 | returns a PIL Image or numpy array 72 | is_valid_file: A function that takes the path of an Image file and checks if 73 | the file is a valid file 74 | """ 75 | # When transform & target_transform are None, ImageFolder.__getitem__(index) 76 | # returns a PIL.Image and int for image and label, respectively 77 | super().__init__( 78 | root=root, 79 | transform=None, 80 | target_transform=None, 81 | loader=loader, 82 | is_valid_file=is_valid_file, 83 | ) 84 | 85 | # Must be set after calling super().__init__() 86 | self.transforms = transforms 87 | 88 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 89 | """Return an index within the dataset. 90 | Args: 91 | index: index to return 92 | Returns: 93 | data and label at that index 94 | """ 95 | image, label = self._load_image(index) 96 | 97 | if self.transforms is not None: 98 | return self.transforms(image), label 99 | 100 | return image, label 101 | 102 | def __len__(self) -> int: 103 | """Return the number of data points in the dataset. 104 | Returns: 105 | length of the dataset 106 | """ 107 | return len(self.imgs) 108 | 109 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: 110 | """Load a single image and it's class label. 111 | Args: 112 | index: index to return 113 | Returns: 114 | the image 115 | the image class label 116 | """ 117 | img, label = ImageFolder.__getitem__(self, index) 118 | label = torch.tensor(label) 119 | return img, label 120 | 121 | 122 | class RESISC45Dataset(VisionClassificationDataset): 123 | """RESISC45 dataset. 124 | The `RESISC45 `_ 125 | dataset is a dataset for remote sensing image scene classification. 126 | Dataset features: 127 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) 128 | * three spectral bands - RGB 129 | * 45 scene classes, 700 images per class 130 | * images extracted from Google Earth from over 100 countries 131 | * images conditions with high variability (resolution, weather, illumination) 132 | Dataset format: 133 | * images are three-channel jpgs 134 | Dataset classes: 135 | 0. airplane 136 | 1. airport 137 | 2. baseball_diamond 138 | 3. basketball_court 139 | 4. beach 140 | 5. bridge 141 | 6. chaparral 142 | 7. church 143 | 8. circular_farmland 144 | 9. cloud 145 | 10. commercial_area 146 | 11. dense_residential 147 | 12. desert 148 | 13. forest 149 | 14. freeway 150 | 15. golf_course 151 | 16. ground_track_field 152 | 17. harbor 153 | 18. industrial_area 154 | 19. intersection 155 | 20. island 156 | 21. lake 157 | 22. meadow 158 | 23. medium_residential 159 | 24. mobile_home_park 160 | 25. mountain 161 | 26. overpass 162 | 27. palace 163 | 28. parking_lot 164 | 29. railway 165 | 30. railway_station 166 | 31. rectangular_farmland 167 | 32. river 168 | 33. roundabout 169 | 34. runway 170 | 35. sea_ice 171 | 36. ship 172 | 37. snowberg 173 | 38. sparse_residential 174 | 39. stadium 175 | 40. storage_tank 176 | 41. tennis_court 177 | 42. terrace 178 | 43. thermal_power_station 179 | 44. 
wetland 180 | This dataset uses the train/val/test splits defined in the "In-domain representation 181 | learning for remote sensing" paper: 182 | * https://arxiv.org/abs/1911.06721 183 | If you use this dataset in your research, please cite the following paper: 184 | * https://doi.org/10.1109/jproc.2017.2675998 185 | """ 186 | 187 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" 188 | # md5 = "d824acb73957502b00efd559fc6cfbbb" 189 | # filename = "NWPU-RESISC45.rar" 190 | directory = "resisc45/NWPU-RESISC45" 191 | 192 | splits = ["train", "val", "test"] 193 | split_urls = { 194 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 195 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 196 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 197 | } 198 | split_md5s = { 199 | "train": "b5a4c05a37de15e4ca886696a85c403e", 200 | "val": "a0770cee4c5ca20b8c32bbd61e114805", 201 | "test": "3dda9e4988b47eb1de9f07993653eb08", 202 | } 203 | classes = [ 204 | "airplane", 205 | "airport", 206 | "baseball_diamond", 207 | "basketball_court", 208 | "beach", 209 | "bridge", 210 | "chaparral", 211 | "church", 212 | "circular_farmland", 213 | "cloud", 214 | "commercial_area", 215 | "dense_residential", 216 | "desert", 217 | "forest", 218 | "freeway", 219 | "golf_course", 220 | "ground_track_field", 221 | "harbor", 222 | "industrial_area", 223 | "intersection", 224 | "island", 225 | "lake", 226 | "meadow", 227 | "medium_residential", 228 | "mobile_home_park", 229 | "mountain", 230 | "overpass", 231 | "palace", 232 | "parking_lot", 233 | "railway", 234 | "railway_station", 235 | "rectangular_farmland", 236 | "river", 237 | "roundabout", 238 | "runway", 239 | "sea_ice", 240 | "ship", 241 | "snowberg", 242 | "sparse_residential", 243 | "stadium", 244 | "storage_tank", 245 | "tennis_court", 246 | "terrace", 247 | "thermal_power_station", 248 | "wetland", 249 | ] 250 | 251 | def __init__( 252 | self, 253 | root: str = "data", 254 | split: str = "train", 255 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 256 | ) -> None: 257 | """Initialize a new RESISC45 dataset instance. 
258 | Args: 259 | root: root directory where dataset can be found 260 | split: one of "train", "val", or "test" 261 | transforms: a function/transform that takes input sample and its target as 262 | entry and returns a transformed version 263 | """ 264 | assert split in self.splits 265 | self.root = root 266 | 267 | valid_fns = set() 268 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: 269 | for fn in f: 270 | valid_fns.add(fn.strip()) 271 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename(x) in valid_fns 272 | 273 | super().__init__( 274 | root=os.path.join(root, self.directory), 275 | transforms=transforms, 276 | is_valid_file=is_in_split, 277 | ) 278 | 279 | 280 | class RESISC45: 281 | def __init__(self, preprocess, location=DATA_DIR, batch_size=32, num_workers=16): 282 | 283 | self.train_dataset = RESISC45Dataset(root=location, split="train", transforms=preprocess) 284 | self.train_loader = torch.utils.data.DataLoader( 285 | self.train_dataset, 286 | shuffle=True, 287 | batch_size=batch_size, 288 | num_workers=num_workers, 289 | ) 290 | 291 | self.test_dataset = RESISC45Dataset(root=location, split="test", transforms=preprocess) 292 | self.test_loader = torch.utils.data.DataLoader( 293 | self.test_dataset, batch_size=batch_size, num_workers=num_workers 294 | ) 295 | 296 | # class names have _ so split on this for better zero-shot head 297 | self.classnames = [" ".join(c.split("_")) for c in RESISC45Dataset.classes] 298 | -------------------------------------------------------------------------------- /src/datasets/sst2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | from src.utils.variables_and_paths import DATA_DIR 6 | 7 | 8 | class RenderedSST2: 9 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 10 | 11 | location = os.path.join(location, "RenderedSST2") 12 | self.train_dataset = datasets.RenderedSST2(root=location, download=True, split="train", transform=preprocess) 13 | 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | batch_size=batch_size, 17 | shuffle=True, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.RenderedSST2(root=location, download=True, split="test", transform=preprocess) 22 | 23 | self.test_loader = torch.utils.data.DataLoader( 24 | self.test_dataset, 25 | batch_size=batch_size, 26 | shuffle=False, 27 | num_workers=num_workers, 28 | ) 29 | 30 | self.classnames = self.train_dataset.classes 31 | -------------------------------------------------------------------------------- /src/datasets/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class STL10: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 9 | 10 | # rely on the location argument (DATA_DIR by default) for the dataset root 11 | 12 | location = os.path.join(location, "stl10") 13 | self.train_dataset = datasets.STL10(root=location, download=True, split="train", transform=preprocess) 14 | 15 | self.train_loader = torch.utils.data.DataLoader( 16 | self.train_dataset, 17 | batch_size=batch_size, 18 | shuffle=True, 19 | num_workers=num_workers, 20 | ) 21 | 22 | self.test_dataset = datasets.STL10(root=location, download=True, split="test", transform=preprocess) 23 | 24 | 
self.test_loader = torch.utils.data.DataLoader( 25 | self.test_dataset, 26 | batch_size=batch_size, 27 | shuffle=False, 28 | num_workers=num_workers, 29 | ) 30 | 31 | self.classnames = self.train_dataset.classes 32 | -------------------------------------------------------------------------------- /src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from src.utils.variables_and_paths import DATA_DIR 5 | 6 | 7 | class SUN397: 8 | def __init__(self, preprocess, location=DATA_DIR, batch_size=32, num_workers=16): 9 | # Data loading code 10 | traindir = os.path.join(location, "sun397", "train") 11 | valdir = os.path.join(location, "sun397", "val") 12 | 13 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 14 | self.train_loader = torch.utils.data.DataLoader( 15 | self.train_dataset, 16 | shuffle=True, 17 | batch_size=batch_size, 18 | num_workers=num_workers, 19 | ) 20 | 21 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 22 | self.test_loader = torch.utils.data.DataLoader( 23 | self.test_dataset, batch_size=batch_size, num_workers=num_workers 24 | ) 25 | idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items()) 26 | self.classnames = [idx_to_class[i][2:].replace("_", " ") for i in range(len(idx_to_class))] 27 | -------------------------------------------------------------------------------- /src/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | 6 | from src.utils.variables_and_paths import DATA_DIR 7 | 8 | 9 | class SVHN: 10 | def __init__(self, preprocess, location=DATA_DIR, batch_size=128, num_workers=16): 11 | # to fit with repo conventions for location 12 | modified_location = os.path.join(location, "svhn") 13 | 14 | self.train_dataset = PyTorchSVHN(root=modified_location, download=True, split="train", transform=preprocess) 15 | 16 | self.train_loader = torch.utils.data.DataLoader( 17 | self.train_dataset, 18 | batch_size=batch_size, 19 | shuffle=True, 20 | num_workers=num_workers, 21 | ) 22 | 23 | self.test_dataset = PyTorchSVHN(root=modified_location, download=True, split="test", transform=preprocess) 24 | 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 27 | batch_size=batch_size, 28 | shuffle=False, 29 | num_workers=num_workers, 30 | ) 31 | 32 | self.classnames = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 33 | -------------------------------------------------------------------------------- /src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f"a photo of a {c}.", 3 | lambda c: f"a photo of the {c}.", 4 | lambda c: f"a photo of my {c}.", 5 | lambda c: f"i love my {c}!", 6 | lambda c: f"a photo of my dirty {c}.", 7 | lambda c: f"a photo of my clean {c}.", 8 | lambda c: f"a photo of my new {c}.", 9 | lambda c: f"a photo of my old {c}.", 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f"a photo of a {c}.", 14 | lambda c: f"a blurry photo of a {c}.", 15 | lambda c: f"a black and white photo of a {c}.", 16 | lambda c: f"a low contrast photo of a {c}.", 17 | lambda c: f"a high contrast photo of a {c}.", 18 | lambda c: f"a bad photo of a {c}.", 19 | lambda c: f"a good photo of a {c}.", 20 | lambda c: f"a 
photo of a small {c}.", 21 | lambda c: f"a photo of a big {c}.", 22 | lambda c: f"a photo of the {c}.", 23 | lambda c: f"a blurry photo of the {c}.", 24 | lambda c: f"a black and white photo of the {c}.", 25 | lambda c: f"a low contrast photo of the {c}.", 26 | lambda c: f"a high contrast photo of the {c}.", 27 | lambda c: f"a bad photo of the {c}.", 28 | lambda c: f"a good photo of the {c}.", 29 | lambda c: f"a photo of the small {c}.", 30 | lambda c: f"a photo of the big {c}.", 31 | ] 32 | 33 | cifar100_template = [ 34 | lambda c: f"a photo of a {c}.", 35 | lambda c: f"a blurry photo of a {c}.", 36 | lambda c: f"a black and white photo of a {c}.", 37 | lambda c: f"a low contrast photo of a {c}.", 38 | lambda c: f"a high contrast photo of a {c}.", 39 | lambda c: f"a bad photo of a {c}.", 40 | lambda c: f"a good photo of a {c}.", 41 | lambda c: f"a photo of a small {c}.", 42 | lambda c: f"a photo of a big {c}.", 43 | lambda c: f"a photo of the {c}.", 44 | lambda c: f"a blurry photo of the {c}.", 45 | lambda c: f"a black and white photo of the {c}.", 46 | lambda c: f"a low contrast photo of the {c}.", 47 | lambda c: f"a high contrast photo of the {c}.", 48 | lambda c: f"a bad photo of the {c}.", 49 | lambda c: f"a good photo of the {c}.", 50 | lambda c: f"a photo of the small {c}.", 51 | lambda c: f"a photo of the big {c}.", 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f"a photo of a {c} texture.", 56 | lambda c: f"a photo of a {c} pattern.", 57 | lambda c: f"a photo of a {c} thing.", 58 | lambda c: f"a photo of a {c} object.", 59 | lambda c: f"a photo of the {c} texture.", 60 | lambda c: f"a photo of the {c} pattern.", 61 | lambda c: f"a photo of the {c} thing.", 62 | lambda c: f"a photo of the {c} object.", 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f"a centered satellite photo of {c}.", 67 | lambda c: f"a centered satellite photo of a {c}.", 68 | lambda c: f"a centered satellite photo of the {c}.", 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f"a photo of {c}, a type of food.", 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f"a bad photo of a {c}.", 87 | lambda c: f"a photo of many {c}.", 88 | lambda c: f"a sculpture of a {c}.", 89 | lambda c: f"a photo of the hard to see {c}.", 90 | lambda c: f"a low resolution photo of the {c}.", 91 | lambda c: f"a rendering of a {c}.", 92 | lambda c: f"graffiti of a {c}.", 93 | lambda c: f"a bad photo of the {c}.", 94 | lambda c: f"a cropped photo of the {c}.", 95 | lambda c: f"a tattoo of a {c}.", 96 | lambda c: f"the embroidered {c}.", 97 | lambda c: f"a photo of a hard to see {c}.", 98 | lambda c: f"a bright photo of a {c}.", 99 | lambda c: f"a photo of a clean {c}.", 100 | lambda c: f"a photo of a dirty {c}.", 101 | lambda c: f"a dark photo of the {c}.", 102 | lambda c: f"a drawing of a {c}.", 103 | lambda c: f"a photo of my {c}.", 104 | lambda c: f"the plastic {c}.", 105 | lambda c: f"a photo of the cool {c}.", 106 | lambda c: f"a close-up photo of a {c}.", 107 | lambda c: f"a black and white photo of the {c}.", 108 | lambda c: f"a painting of the {c}.", 109 | lambda c: f"a painting of a {c}.", 110 | lambda c: f"a pixelated photo of the {c}.", 111 | lambda c: f"a sculpture of the {c}.", 112 | lambda c: f"a bright photo 
of the {c}.", 113 | lambda c: f"a cropped photo of a {c}.", 114 | lambda c: f"a plastic {c}.", 115 | lambda c: f"a photo of the dirty {c}.", 116 | lambda c: f"a jpeg corrupted photo of a {c}.", 117 | lambda c: f"a blurry photo of the {c}.", 118 | lambda c: f"a photo of the {c}.", 119 | lambda c: f"a good photo of the {c}.", 120 | lambda c: f"a rendering of the {c}.", 121 | lambda c: f"a {c} in a video game.", 122 | lambda c: f"a photo of one {c}.", 123 | lambda c: f"a doodle of a {c}.", 124 | lambda c: f"a close-up photo of the {c}.", 125 | lambda c: f"a photo of a {c}.", 126 | lambda c: f"the origami {c}.", 127 | lambda c: f"the {c} in a video game.", 128 | lambda c: f"a sketch of a {c}.", 129 | lambda c: f"a doodle of the {c}.", 130 | lambda c: f"a origami {c}.", 131 | lambda c: f"a low resolution photo of a {c}.", 132 | lambda c: f"the toy {c}.", 133 | lambda c: f"a rendition of the {c}.", 134 | lambda c: f"a photo of the clean {c}.", 135 | lambda c: f"a photo of a large {c}.", 136 | lambda c: f"a rendition of a {c}.", 137 | lambda c: f"a photo of a nice {c}.", 138 | lambda c: f"a photo of a weird {c}.", 139 | lambda c: f"a blurry photo of a {c}.", 140 | lambda c: f"a cartoon {c}.", 141 | lambda c: f"art of a {c}.", 142 | lambda c: f"a sketch of the {c}.", 143 | lambda c: f"a embroidered {c}.", 144 | lambda c: f"a pixelated photo of a {c}.", 145 | lambda c: f"itap of the {c}.", 146 | lambda c: f"a jpeg corrupted photo of the {c}.", 147 | lambda c: f"a good photo of a {c}.", 148 | lambda c: f"a plushie {c}.", 149 | lambda c: f"a photo of the nice {c}.", 150 | lambda c: f"a photo of the small {c}.", 151 | lambda c: f"a photo of the weird {c}.", 152 | lambda c: f"the cartoon {c}.", 153 | lambda c: f"art of the {c}.", 154 | lambda c: f"a drawing of the {c}.", 155 | lambda c: f"a photo of the large {c}.", 156 | lambda c: f"a black and white photo of a {c}.", 157 | lambda c: f"the plushie {c}.", 158 | lambda c: f"a dark photo of a {c}.", 159 | lambda c: f"itap of a {c}.", 160 | lambda c: f"graffiti of the {c}.", 161 | lambda c: f"a toy {c}.", 162 | lambda c: f"itap of my {c}.", 163 | lambda c: f"a photo of a cool {c}.", 164 | lambda c: f"a photo of a small {c}.", 165 | lambda c: f"a tattoo of the {c}.", 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f"satellite imagery of {c}.", 170 | lambda c: f"aerial imagery of {c}.", 171 | lambda c: f"satellite photo of {c}.", 172 | lambda c: f"aerial photo of {c}.", 173 | lambda c: f"satellite view of {c}.", 174 | lambda c: f"aerial view of {c}.", 175 | lambda c: f"satellite imagery of a {c}.", 176 | lambda c: f"aerial imagery of a {c}.", 177 | lambda c: f"satellite photo of a {c}.", 178 | lambda c: f"aerial photo of a {c}.", 179 | lambda c: f"satellite view of a {c}.", 180 | lambda c: f"aerial view of a {c}.", 181 | lambda c: f"satellite imagery of the {c}.", 182 | lambda c: f"aerial imagery of the {c}.", 183 | lambda c: f"satellite photo of the {c}.", 184 | lambda c: f"aerial photo of the {c}.", 185 | lambda c: f"satellite view of the {c}.", 186 | lambda c: f"aerial view of the {c}.", 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f"a photo of a {c}.", 191 | lambda c: f"a photo of the {c}.", 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f"a photo of a {c}.", 196 | lambda c: f"a photo of the {c}.", 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | flowers102_template = [ 204 | lambda c: f"a photo of a {c}, a type of flower.", 205 | ] 206 | 207 | fer2013_template = [ 
208 | lambda c: f"a photo of a {c} looking face.", 209 | lambda c: f"a photo of a face showing the emotion: {c}.", 210 | lambda c: f"a photo of a face looking {c}.", 211 | lambda c: f"a face that looks {c}.", 212 | lambda c: f"they look {c}.", 213 | lambda c: f"look at how {c} they are.", 214 | ] 215 | 216 | pcam_template = [ 217 | lambda c: f"this is a photo of {c}", 218 | ] 219 | 220 | oxfordpets_template = [ 221 | lambda c: f"a photo of a {c}, a type of pet.", 222 | ] 223 | 224 | sst2_template = [ 225 | lambda c: f"a {c} review of a movie.", 226 | ] 227 | 228 | emnist_template = [ 229 | lambda c: f'a photo of the digit character: "{c}".', 230 | ] 231 | 232 | fashionmnist_template = [ 233 | lambda c: f"a photo of a {c}.", 234 | lambda c: f"a photo of the {c}.", 235 | ] 236 | 237 | kmnist_template = [ 238 | lambda c: f"a photo of the character {c}.", 239 | ] 240 | 241 | dataset_to_template = { 242 | "Cars": cars_template, 243 | "CIFAR10": cifar10_template, 244 | "CIFAR100": cifar100_template, 245 | "DTD": dtd_template, 246 | "EuroSAT": eurosat_template, 247 | "Food101": food101_template, 248 | "GTSRB": gtsrb_template, 249 | "MNIST": mnist_template, 250 | "ImageNet": imagenet_template, 251 | "RESISC45": resisc45_template, 252 | "STL10": stl10_template, 253 | "SUN397": sun397_template, 254 | "SVHN": svhn_template, 255 | "Flowers102": flowers102_template, 256 | "FER2013": fer2013_template, 257 | "PCAM": pcam_template, 258 | "OxfordIIITPet": oxfordpets_template, 259 | "RenderedSST2": sst2_template, 260 | "EMNIST": emnist_template, 261 | "FashionMNIST": fashionmnist_template, 262 | "KMNIST": kmnist_template, 263 | } 264 | 265 | 266 | def get_templates(dataset_name): 267 | if dataset_name.endswith("Val"): 268 | return get_templates(dataset_name.replace("Val", "")) 269 | assert dataset_name in dataset_to_template, f"Unsupported dataset: {dataset_name}" 270 | return dataset_to_template[dataset_name] 271 | -------------------------------------------------------------------------------- /src/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-kee/LiNeS/43a7bdd564a16195b164abc2c64ec1856c59c211/src/eval/__init__.py -------------------------------------------------------------------------------- /src/eval/aggregation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | from omegaconf import DictConfig 6 | from omegaconf import open_dict 7 | 8 | from src.models.task_vectors import ImageEncoder, NonLinearTaskVector 9 | from src.utils.tallmask_utils import ( 10 | construct_consensus_mask, 11 | construct_tall_mask, 12 | load_tall_mask, 13 | ) 14 | from src.utils.ties_utils import ties_merging 15 | from src.utils.utils import ( 16 | check_parameterNamesMatch, 17 | check_state_dicts_equal, 18 | state_dict_to_vector, 19 | topk_values_mask, 20 | vector_to_state_dict, 21 | ) 22 | from src.utils.variables_and_paths import get_finetuned_path, get_zeroshot_path 23 | 24 | 25 | def get_all_checkpoints( 26 | config: DictConfig, 27 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: 28 | """ 29 | Retrieves all the checkpoints for the given configuration. 30 | 31 | Args: 32 | config (DictConfig): The configuration object containing the model location, datasets, and model name. 33 | 34 | Returns: 35 | Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple containing two dictionaries. 
36 | The first dictionary contains the checkpoints for each dataset in the configuration's validation datasets. 37 | The second dictionary contains the checkpoint for the zeroshot model. 38 | """ 39 | 40 | model_dir = config.model_location 41 | print("I am getting out all the checkpoints") 42 | print("datasets:", config.DATASETS_VAL) 43 | print("model:", config.model) 44 | for dataset in config.DATASETS_VAL: 45 | path = get_finetuned_path(model_dir, dataset, model=config.model) 46 | if os.path.exists(path): 47 | print(f"{path} exists") 48 | else: 49 | print(f"{path} does not exist") 50 | 51 | params = { 52 | dataset: torch.load( 53 | get_finetuned_path(model_dir, dataset, model=config.model), 54 | map_location="cpu", 55 | weights_only=True, 56 | ) 57 | for dataset in config.DATASETS_VAL 58 | } 59 | 60 | # convert dict to vector 61 | params = list(params.values()) 62 | 63 | try: 64 | ptm_check = torch.load( 65 | get_zeroshot_path(model_dir, "MNISTVal", model=config.model), 66 | map_location="cpu", 67 | weights_only=True, 68 | ) 69 | except: 70 | ptm_check = ImageEncoder(config.model).state_dict() 71 | torch.save(ptm_check, get_zeroshot_path(model_dir, "MNISTVal", model=config.model)) 72 | 73 | return params, ptm_check 74 | 75 | 76 | def create_task_vector( 77 | config: DictConfig, 78 | ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: 79 | """ 80 | Creates a task vector based on the given configuration. 81 | 82 | Args: 83 | config (DictConfig): The configuration for creating the task vector. 84 | 85 | Returns: 86 | Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: A tuple containing the task vector and evaluation masks 87 | (if applicable). 88 | """ 89 | 90 | ft_checks, ptm_check = get_all_checkpoints(config) 91 | check_parameterNamesMatch(ft_checks + [ptm_check]) 92 | 93 | remove_keys = [] 94 | 95 | print(f"Flattening out Checkpoints") 96 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks]) 97 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys) 98 | 99 | # compute the task vector as {\theta_t - \theta_0}. 100 | tv_flat_checks = flat_ft - flat_ptm 101 | 102 | # check if the vectorized state dicts can be converted back to the original state dicts 103 | # covnert back the flat task vectors to state dict and see if the original and converted sd's are equal 104 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check) 105 | assert all( 106 | [ 107 | check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i]) 108 | for i in range(len(ft_checks)) 109 | ] 110 | ) 111 | print(f"MODEL: {config.model}, METHOD {config.method.name}") 112 | 113 | if config.method.name == "ties": 114 | # TIES Merging 115 | merge_func = f"dis-{config.method.agg}" 116 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=config.method.k, merge_func=merge_func) 117 | elif config.method.name in ["sum", "zeroshot", "average"]: 118 | # "sum" corresponds to Task Arithmetic (TA) 119 | # TA, zeroshot, weight average all construct the task vector with sum, but use different scaling factors. 
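# (the scaling coefficient itself is chosen later at evaluation time: 0.0 for zeroshot, 1/num_tasks for weight averaging, and a validation-searched alpha for Task Arithmetic; see evaluate_task_vector in src/eval/eval.py)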
120 | tv_flat_checks, _ = topk_values_mask(tv_flat_checks, K=config.method.k, return_mask=False) 121 | merged_tv = tv_flat_checks.sum(dim=0) 122 | elif config.method.name == "single_task": 123 | # load the single task vector from task_index 124 | tv_flat_checks, _ = topk_values_mask(tv_flat_checks, K=config.method.k, return_mask=False) 125 | # take only the task vector of the task_index-th task 126 | merged_tv = tv_flat_checks[config.method.task_index] 127 | elif config.method.name == "tall_mask": 128 | # construct multi-task vector 129 | if config.method.use_ties: 130 | print(f"Using TIES for constructing multi-task vector") 131 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=20, merge_func=f"dis-sum") 132 | else: 133 | print(f"Using Task Arithmetic for constructing multi-task vector") 134 | tv_flat_checks, _ = topk_values_mask(tv_flat_checks, K=config.method.k, return_mask=False) 135 | merged_tv = tv_flat_checks.sum(dim=0) 136 | # get TALL masks 137 | if config.method.load_mask: 138 | # load tall masks directly from storage 139 | eval_masks = load_tall_mask(remove_keys, ptm_check, config) 140 | else: 141 | print(f"=== Constructing TALL Mask ===") 142 | # construct tall masks 143 | eval_masks = construct_tall_mask( 144 | tv_flat_checks, 145 | flat_ft, 146 | flat_ptm, 147 | merged_tv, 148 | ptm_check, 149 | remove_keys, 150 | config, 151 | ) 152 | elif config.method.name == "consensus": # consensus merging 153 | # construct consensus mask (assuming the TALL masks have already been constructed) 154 | consensus_mask = construct_consensus_mask(ptm_check, config.method.prun_thre_k, config, remove_keys) 155 | # construct multi-task vector 156 | if config.method.use_ties: 157 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=20, merge_func="dis-sum") 158 | else: 159 | tv_flat_checks, _ = topk_values_mask( 160 | tv_flat_checks, K=config.method.k, return_mask=False 161 | ) # top-k mag filtering 162 | merged_tv = tv_flat_checks.sum(dim=0) 163 | # apply the consensus mask to filter multi-task vector 164 | merged_tv = merged_tv * consensus_mask 165 | elif config.method.name == "mag_masking": 166 | # Magnitude masking baseline 167 | print(f"=== Using Magnitude Masking ===") 168 | merged_tv = tv_flat_checks.sum(dim=0) 169 | _, _, eval_masks = topk_values_mask(tv_flat_checks, K=config.method.k, return_mask=True) 170 | eval_masks = [vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys) for mask in eval_masks] 171 | eval_masks = {key: value for key, value in zip(config.DATASETS, eval_masks)} 172 | else: 173 | raise ValueError(f"Method {config.method.name} not defined.") 174 | 175 | # compute the l1 norm of the multi-task vector 176 | with open_dict(config): 177 | config.norm_mtv = (merged_tv).abs().sum().item() 178 | config.norm_summed_tvs = (tv_flat_checks.sum(dim=0)).abs().sum().item() 179 | 180 | merged_tv_state_dict = vector_to_state_dict(merged_tv, ptm_check, remove_keys=remove_keys) 181 | task_vector = NonLinearTaskVector(model_name=config.model, vector=merged_tv_state_dict) 182 | 183 | if config.method.name not in ["tall_mask", "mag_masking"]: 184 | eval_masks = None 185 | 186 | return task_vector, eval_masks 187 | -------------------------------------------------------------------------------- /src/eval/eval.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | 8 | from src.datasets.common import get_dataloader, maybe_dictionarize 9 | from 
src.datasets.registry import get_dataset 10 | from src.models.heads import get_classification_head 11 | from src.models.modeling import ImageClassifier 12 | from src.models.task_vectors import _Checkpoint, _TaskVector 13 | from src.utils import utils 14 | 15 | import wandb 16 | 17 | 18 | def eval_single_dataset(image_encoder, dataset_name, args): 19 | start_time = time.time() 20 | classification_head = get_classification_head(args, dataset_name) 21 | model = ImageClassifier(image_encoder, classification_head) 22 | 23 | model.eval() 24 | 25 | dataset = get_dataset( 26 | dataset_name, 27 | model.val_preprocess, 28 | location=args.data_location, 29 | batch_size=args.batch_size, 30 | ) 31 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 32 | device = args.device 33 | 34 | with torch.no_grad(): 35 | top1, correct, n = 0.0, 0.0, 0.0 36 | for _, data in enumerate(dataloader): 37 | data = maybe_dictionarize(data) 38 | x = data["images"].to(device) 39 | y = data["labels"].to(device) 40 | 41 | logits = utils.get_logits(x, model) 42 | 43 | pred = logits.argmax(dim=1, keepdim=True).to(device) 44 | correct += pred.eq(y.view_as(pred)).sum().item() 45 | n += y.size(0) 46 | 47 | top1 = correct / n 48 | 49 | metrics = {"top1": top1} 50 | dt = time.time() - start_time 51 | print(f"Done evaluating on {dataset_name}.\t Accuracy: {100*top1:.2f}%.\t Total time: {dt:.2f}s") 52 | 53 | return metrics 54 | 55 | 56 | def LiNeS_scaling(task_vector, alpha, beta, num_blocks): 57 | """ 58 | LiNeS: Progressively scales the task vector based on layer depth. 59 | 60 | Parameters: 61 | ----------- 62 | task_vector : dict 63 | A dictionary representing the residual between the fine-tuned checkpoint 64 | and the pre-trained checkpoint. 65 | alpha : float 66 | The minimum scaling factor for the blocks. 67 | beta : float 68 | The maximum scaling coefficient difference between the last and first block. 69 | num_blocks : int 70 | The total number of layer blocks in the model. 71 | Returns: 72 | -------- 73 | scaled_task_vector : dict 74 | A copy of `task_vector` where each key is scaled based on the layer depth. 75 | """ 76 | 77 | scaled_task_vector = copy.deepcopy(task_vector) 78 | 79 | key_blocks = list(f".{i}." 
for i in range(0, num_blocks)) 80 | 81 | layer_scalings_dict = {} 82 | for k in scaled_task_vector.vector.keys(): 83 | for layer, block in enumerate(key_blocks): 84 | if block in k: 85 | layer_scalings_dict[k] = alpha + beta * (layer / (num_blocks - 1)) 86 | break 87 | 88 | print(f"LiNeS: The layers are scaled between {alpha} to {alpha + beta}") 89 | 90 | # apply scaling to the task vector 91 | scaled_task_vector.vector = { 92 | # scale with alpha for layers outside residual blocks 93 | k: scaled_task_vector.vector[k] * layer_scalings_dict.get(k, alpha) 94 | for k in scaled_task_vector.vector.keys() 95 | } 96 | 97 | return scaled_task_vector 98 | 99 | 100 | def evaluate(pretrained_checkpoint, task_vector, args, scaling_coef, eval_masks=None, test=False): 101 | per_dataset_results = {} 102 | eval_datasets = args.eval_datasets if args.control_dataset is None else args.eval_datasets + [args.control_dataset] 103 | 104 | if eval_masks != None: 105 | assert args.method.name in ["tall_mask", "mag_masking"] 106 | else: 107 | if args.method.apply_lines: 108 | # line scaling: this part is the key difference to task arithmetic and other merging methods 109 | num_blocks = 12 if args.model != "ViT-L-14" else 24 110 | if args.method.name == "single_task": 111 | task_vector = LiNeS_scaling( 112 | task_vector, 113 | alpha=scaling_coef, 114 | beta=1 - scaling_coef, 115 | num_blocks=num_blocks, 116 | ) 117 | else: 118 | # for multi-task setting, we scale alpha based on the norm of the task vectors, as well as number of tasks 119 | alpha = (args.norm_summed_tvs / args.norm_mtv) * 1 / args.num_tasks 120 | task_vector = LiNeS_scaling( 121 | task_vector, 122 | alpha=alpha, 123 | beta=scaling_coef, 124 | num_blocks=num_blocks, 125 | ) 126 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 127 | else: 128 | # constant scaling: baseline model merging methods 129 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef) 130 | 131 | for dataset_name in eval_datasets: 132 | 133 | if eval_masks != None: 134 | sparse_task_vector = copy.deepcopy(task_vector) 135 | # remove "Val" from dataset_name 136 | mask = eval_masks[dataset_name[:-3]] if "Val" in dataset_name else eval_masks[dataset_name] 137 | # apply mask to sparsify the task vectors with Hadamard product 138 | sparse_task_vector.vector = {k: sparse_task_vector.vector[k] * mask[k].bool().cpu() for k in mask.keys()} 139 | # reconstruct theta_t^ 140 | image_encoder = sparse_task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 141 | 142 | # evalute performance 143 | results = eval_single_dataset(image_encoder, dataset_name, args) 144 | per_dataset_results[dataset_name + ":top1"] = results["top1"] 145 | 146 | return per_dataset_results 147 | 148 | 149 | def evaluate_task_vector_at_coef( 150 | task_vector: _TaskVector, 151 | pretrained_checkpoint: _Checkpoint, 152 | args, 153 | scaling_coef: float, 154 | eval_masks=None, 155 | test=False, 156 | ): 157 | start_time = time.time() 158 | 159 | coef_info = evaluate(pretrained_checkpoint, task_vector, args, scaling_coef, eval_masks, test) 160 | 161 | coef_info = add_normalized_accuracy(coef_info, args) 162 | coef_info["avg_normalized_top1"] = np.mean( 163 | [coef_info[dataset + ":normalized_top1"] for dataset in args.eval_datasets] 164 | ) 165 | coef_info["avg_top1"] = np.mean([coef_info[dataset + ":top1"] for dataset in args.eval_datasets]) 166 | 167 | if args.method.name == "single_task": 168 | # log both target and control accuracies 169 | coef_info = 
add_normalized_accuracy(coef_info, args, based_on="zeroshot") 170 | coef_info["target_accuracy"] = coef_info[args.eval_datasets[args.method.task_index] + ":top1"] 171 | coef_info["control_accuracy"] = np.mean( 172 | [ 173 | coef_info[dataset + ":top1"] 174 | for dataset in args.eval_datasets 175 | if dataset != args.eval_datasets[args.method.task_index] 176 | ] 177 | ) 178 | # normalize target accuracy with finetuned accuracy; normalize control accuracy with zeroshot accuracy 179 | coef_info["target_normalized_accuracy"] = coef_info[ 180 | args.eval_datasets[args.method.task_index] + ":normalized_top1" 181 | ] 182 | coef_info["control_normalized_accuracy"] = np.mean( 183 | [ 184 | coef_info[dataset + ":normalized_top1_zeroshot"] 185 | for dataset in args.eval_datasets 186 | if dataset != args.eval_datasets[args.method.task_index] 187 | ] 188 | ) 189 | 190 | print(f"Total evaluation time: {time.time() - start_time:.2f}s") 191 | return coef_info 192 | 193 | 194 | def evaluate_task_vector(task_vector, pretrained_checkpoint, args, eval_masks=None): 195 | info = {} 196 | 197 | if args.method.name == "tall_mask" or eval_masks is not None: 198 | scaling_coef_range = [1.0] 199 | elif args.method.name == "single_task": 200 | print(f"Fine-tuned task: {args.eval_datasets[args.method.task_index]}") 201 | if args.method.apply_lines: 202 | scaling_coef_range = np.arange(0.0, 1.1, 0.1)[::-1] 203 | else: 204 | # return the fine-tuned residual directly 205 | scaling_coef_range = [1.0] 206 | elif args.method.name == "zeroshot": 207 | scaling_coef_range = [0.0] 208 | elif args.method.name == "average": 209 | scaling_coef_range = [1 / args.num_tasks] 210 | elif args.specify_lambda != "None": 211 | scaling_coef_range = [args.specify_lambda] 212 | elif args.method.name == "ties": 213 | scaling_coef_range = np.arange(0.1, 1.6, 0.1) 214 | else: 215 | scaling_coef_range = np.linspace(0.0, 1.0, args.n_eval_points // 2 + 1)[1:] 216 | 217 | if args.method.name == "tall_mask": 218 | if args.method.load_mask: 219 | print("=" * 43, f"Evaluating the loaded TALL masks", "=" * 43) 220 | info["loaded_mask"] = evaluate_task_vector_at_coef( 221 | task_vector, 222 | pretrained_checkpoint, 223 | args, 224 | 1.0, 225 | eval_masks, 226 | ) 227 | print( 228 | "\t avg_normalized_top1: {}%\t avg_top1: {}%".format( 229 | round(info["loaded_mask"]["avg_normalized_top1"] * 100, 2), 230 | round(info["loaded_mask"]["avg_top1"] * 100, 2), 231 | ) 232 | ) 233 | else: 234 | for tall_mask_lambda in [0.2, 0.3, 0.4, 0.5, 0.6]: 235 | print("\n" * 2) 236 | print("=" * 43, f"tall_mask_lambda = {tall_mask_lambda:.2f}", "=" * 43) 237 | info[tall_mask_lambda] = evaluate_task_vector_at_coef( 238 | task_vector, 239 | pretrained_checkpoint, 240 | args, 241 | 1.0, 242 | eval_masks[tall_mask_lambda], 243 | ) 244 | print( 245 | "\t avg_normalized_top1: {}%\t avg_top1: {}%".format( 246 | round(info[tall_mask_lambda]["avg_normalized_top1"] * 100, 2), 247 | round(info[tall_mask_lambda]["avg_top1"] * 100, 2), 248 | ) 249 | ) 250 | else: 251 | best_acc = 0.0 252 | for scaling_coef in scaling_coef_range: 253 | print("\n" * 2) 254 | print("=" * 43, f"alpha = {scaling_coef:.2f}", "=" * 43) 255 | info[scaling_coef] = evaluate_task_vector_at_coef( 256 | task_vector, pretrained_checkpoint, args, scaling_coef, eval_masks 257 | ) 258 | if args.method.name == "single_task": 259 | print(f"Fine-tuned task: {args.eval_datasets[args.method.task_index]}") 260 | print( 261 | "\t target_acc: {}%\t target_acc_norm: {}%".format( 262 | round(info[scaling_coef]["target_accuracy"] 
* 100, 2), 263 | round(info[scaling_coef]["target_normalized_accuracy"] * 100, 2), 264 | ) 265 | ) 266 | print( 267 | "\t control_acc: {}%\t control_acc_norm: {}%".format( 268 | round(info[scaling_coef]["control_accuracy"] * 100, 2), 269 | round(info[scaling_coef]["control_normalized_accuracy"] * 100, 2), 270 | ) 271 | ) 272 | else: 273 | print( 274 | "\t avg_normalized_top1: {}%\t avg_top1: {}%".format( 275 | round(info[scaling_coef]["avg_normalized_top1"] * 100, 2), 276 | round(info[scaling_coef]["avg_top1"] * 100, 2), 277 | ) 278 | ) 279 | return info 280 | 281 | 282 | def add_normalized_accuracy(results, args, based_on="finetuned"): 283 | if based_on == "finetuned": 284 | # normalize based on the finetuned accuracy (for target tasks) 285 | for dataset_name in args.eval_datasets: 286 | results[dataset_name + ":normalized_top1"] = ( 287 | results[dataset_name + ":top1"] / args.finetuning_accuracies[dataset_name] 288 | ) 289 | elif based_on == "zeroshot": 290 | # normalize based on the zeroshot accuracy (for control tasks) 291 | for dataset_name in args.eval_datasets: 292 | results[dataset_name + ":normalized_top1_zeroshot"] = ( 293 | results[dataset_name + ":top1"] / args.zeroshot_accuracies[dataset_name + ":top1"] 294 | ) 295 | return results 296 | -------------------------------------------------------------------------------- /src/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import wandb 6 | from omegaconf import open_dict 7 | 8 | from src.eval.eval import evaluate_task_vector, evaluate_task_vector_at_coef 9 | from src.utils.tallmask_utils import find_optimal_mask 10 | from src.utils.utils import find_optimal_coef, find_optimal_coef_tradeoff 11 | from src.utils.logging import log_results 12 | from src.utils.variables_and_paths import ( 13 | get_finetuned_path, 14 | get_zeroshot_path, 15 | get_single_task_accuracies_path, 16 | get_zero_shot_accuracies_path, 17 | ) 18 | 19 | 20 | def perform_eval_with_merged_vector(args, task_vector, eval_masks=None): 21 | assert task_vector is not None, "Task vector should not be None." 
22 | if eval_masks is not None: 23 | assert args.method.name in ["tall_mask", "mag_masking"] 24 | with open_dict(args): 25 | args.save_dir = os.path.join(args.model_location, args.model) 26 | 27 | ft_accuracies_path = get_single_task_accuracies_path(args.model) 28 | zs_accuracies_path = get_zero_shot_accuracies_path(args.model) 29 | pretrained_checkpoint = get_zeroshot_path(args.model_location, "MNIST", args.model) 30 | 31 | with open_dict(args): 32 | with open(ft_accuracies_path) as f: 33 | args.finetuning_accuracies = json.load(f) 34 | with open(zs_accuracies_path) as f: 35 | args.zeroshot_accuracies = json.load(f)["val_best"] 36 | args.eval_datasets = args.DATASETS_VAL 37 | args.control_dataset = None 38 | 39 | # evaluate on validation set 40 | val_metrics = evaluate_task_vector(task_vector, pretrained_checkpoint, args, eval_masks=eval_masks) 41 | 42 | if args.method.name == "tall_mask": 43 | if args.method.load_mask: 44 | best_masks_for_test = eval_masks 45 | best_val_metrics = val_metrics 46 | else: 47 | # find the best mask individually for each task based on validation accuracy 48 | best_masks_for_test, best_val_metrics = find_optimal_mask(val_metrics, eval_masks, args, save_masks=True) 49 | elif args.method.name == "mag_masking": 50 | best_masks_for_test = eval_masks 51 | best_val_metrics = val_metrics[1.0] 52 | elif args.method.name == "single_task": 53 | # for single-task setting, find best hyper-parameter based on a trade-off accuracy on both target and control tasks 54 | # here tradeoff_target_weight is the weight for the target tasks in the trade-off 55 | optimal_coef = find_optimal_coef_tradeoff( 56 | val_metrics, 57 | tradeoff_target_weight=args.method.tradeoff_target_weight, 58 | minimize=False, 59 | ) 60 | best_val_metrics = val_metrics[optimal_coef] 61 | else: 62 | # find scaling factor alpha based on validation accuracy (for Task Arithmetic, TIES, Consensus Merging) 63 | optimal_coef = find_optimal_coef(val_metrics, metric="avg_normalized_top1", minimize=False) 64 | best_val_metrics = val_metrics[optimal_coef] 65 | 66 | print("\n" * 2) 67 | 68 | # Evaluate on the test set with the optimal coefficients / masks 69 | with open_dict(args): 70 | args.eval_datasets = args.DATASETS 71 | 72 | if args.method.name in ["tall_mask", "mag_masking"]: 73 | test_metrics = evaluate_task_vector_at_coef( 74 | task_vector, 75 | pretrained_checkpoint, 76 | args, 77 | 1.0, 78 | eval_masks=best_masks_for_test, 79 | ) 80 | else: 81 | test_metrics = evaluate_task_vector_at_coef( 82 | task_vector, 83 | pretrained_checkpoint, 84 | args, 85 | float(optimal_coef), 86 | eval_masks=None, 87 | test=True, 88 | ) 89 | 90 | print("=" * 100) 91 | 92 | if args.method.name == "single_task": 93 | # for single task, report both target and control task accuracies 94 | print(f"Test target task accuracy: {test_metrics['target_accuracy']:.4f}") 95 | print(f"Test target task normalized accuracy: {test_metrics['target_normalized_accuracy']:.4f}") 96 | print(f"Test control task accuracy: {test_metrics['control_accuracy']:.4f}") 97 | print(f"Test control task normalized accuracy: {test_metrics['control_normalized_accuracy']:.4f}") 98 | else: 99 | # for multi-task, report the average accuracy on all tasks 100 | print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']:.4f}") 101 | print(f"Test absolute accuracy: {test_metrics['avg_top1']:.4f}") 102 | 103 | final_results = { 104 | "test": test_metrics, 105 | "val": val_metrics, 106 | "val_best": best_val_metrics, 107 | } 108 | log_results(final_results, 
args) 109 | 110 | return final_results 111 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .heads import get_classification_head 2 | from .modeling import ImageClassifier, ImageEncoder 3 | 4 | __all__ = [ 5 | "get_classification_head", 6 | "LinearizedImageEncoder", 7 | "ImageClassifier", 8 | "ImageEncoder", 9 | ] 10 | -------------------------------------------------------------------------------- /src/models/heads.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Callable, List 4 | 5 | import open_clip 6 | import torch 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | 10 | from src.datasets.registry import get_dataset 11 | from src.datasets.templates import get_templates 12 | from src.models.modeling import ClassificationHead, ImageEncoder 13 | from src.utils.variables_and_paths import TQDM_BAR_FORMAT 14 | 15 | 16 | def build_classification_head( 17 | model: nn.Module, 18 | dataset_name: str, 19 | template: List[Callable[[str], str]], 20 | data_location: str, 21 | device: torch.device, 22 | ) -> ClassificationHead: 23 | """ 24 | Builds a classification head for a given model and dataset. 25 | 26 | Args: 27 | model (nn.Module): The model to use for text encoding. 28 | dataset_name (str): The name of the dataset to use for zero-shot classification. 29 | template (List[Callable[[str], str]]): A list of functions that generate text templates for each class. 30 | data_location (str): The location of the dataset. 31 | device (torch.device): The device to use for computation. 32 | 33 | Returns: 34 | A ClassificationHead object with normalized weights for zero-shot classification. 35 | """ 36 | template = get_templates(dataset_name) 37 | 38 | logit_scale = model.logit_scale 39 | dataset = get_dataset(dataset_name, None, location=data_location) 40 | model.eval() 41 | model.to(device) 42 | 43 | print("Building classification head.") 44 | with torch.no_grad(): 45 | zeroshot_weights = [] 46 | for classname in tqdm(dataset.classnames, bar_format=TQDM_BAR_FORMAT): 47 | texts = [] 48 | for t in template: 49 | texts.append(t(classname)) 50 | texts = open_clip.tokenize(texts).to(device) # tokenize 51 | embeddings = model.encode_text(texts) # embed with text encoder 52 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 53 | 54 | embeddings = embeddings.mean(dim=0, keepdim=True) 55 | embeddings /= embeddings.norm() 56 | 57 | zeroshot_weights.append(embeddings) 58 | 59 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 60 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 61 | 62 | zeroshot_weights *= logit_scale.exp() 63 | 64 | zeroshot_weights = zeroshot_weights.squeeze().float() 65 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 66 | 67 | print(f"zeroshot shape, P{zeroshot_weights.shape}") 68 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 69 | return classification_head 70 | 71 | 72 | def get_classification_head(args: argparse.Namespace, dataset: str) -> nn.Module: 73 | """ 74 | Retrieves or builds a classification head for a given model and dataset. 75 | 76 | If the classification head file does not exist, it builds one from scratch in the location specified by `args.save_dir`. 77 | 78 | Args: 79 | args (argparse.Namespace): The command-line arguments. 
80 | dataset (str): The name of the dataset. 81 | 82 | Returns: 83 | nn.Module: The classification head module. 84 | 85 | Raises: 86 | FileNotFoundError: If the classification head file does not exist. 87 | 88 | """ 89 | if not dataset.endswith("Val"): 90 | dataset += "Val" 91 | 92 | filename = os.path.join(args.save_dir, f"head_{dataset}.pt") 93 | if os.path.exists(filename): 94 | print(f"Loading classification head for {args.model} on {dataset} from {filename}") 95 | return ClassificationHead.load(filename) 96 | print(f"Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.") 97 | model = ImageEncoder(args.model, keep_lang=True).model 98 | template = get_templates(dataset) 99 | 100 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device) 101 | os.makedirs(args.save_dir, exist_ok=True) 102 | classification_head.save(filename) 103 | return classification_head 104 | -------------------------------------------------------------------------------- /src/models/modeling.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | 4 | from src.utils import utils 5 | from src.utils.variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR 6 | 7 | 8 | class ImageEncoder(torch.nn.Module): 9 | def __init__(self, model_name: str, keep_lang=False): 10 | super().__init__() 11 | assert model_name in MODELS, f"Invalid model name: {model_name}. Valid models are: {MODELS}" 12 | 13 | if "__pretrained__" in model_name: 14 | name, pretrained = model_name.split("__pretrained__") 15 | elif "__init__" in model_name: 16 | print("Using random initialization.") 17 | name, pretrained = model_name.split("__init__")[0], None 18 | else: 19 | name = model_name 20 | pretrained = "openai" 21 | ( 22 | self.model, 23 | self.train_preprocess, 24 | self.val_preprocess, 25 | ) = open_clip.create_model_and_transforms(name, pretrained=pretrained, cache_dir=OPENCLIP_CACHEDIR) 26 | 27 | self.cache_dir = CACHEDIR 28 | 29 | if not keep_lang and hasattr(self.model, "transformer"): 30 | delattr(self.model, "transformer") 31 | 32 | def forward(self, images): 33 | assert self.model is not None 34 | return self.model.encode_image(images) 35 | 36 | def __call__(self, inputs): 37 | return self.forward(inputs) 38 | 39 | def save(self, filename): 40 | print(f"Saving image encoder to {filename}") 41 | utils.torch_save(self, filename) 42 | 43 | @classmethod 44 | def load(cls, model_name, filename): 45 | print(f"Loading image encoder from {filename}") 46 | 47 | state_dict = torch.load(filename, map_location="cpu", weights_only=True) 48 | 49 | model = cls(model_name) 50 | model.load_state_dict(state_dict) 51 | return model 52 | 53 | 54 | class ClassificationHead(torch.nn.Linear): 55 | def __init__(self, normalize, weights, biases=None): 56 | output_size, input_size = weights.shape 57 | super().__init__(input_size, output_size) 58 | self.normalize = normalize 59 | if weights is not None: 60 | self.weight = torch.nn.Parameter(weights.clone()) 61 | if biases is not None: 62 | self.bias = torch.nn.Parameter(biases.clone()) 63 | else: 64 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias)) 65 | 66 | def forward(self, inputs): 67 | if self.normalize: 68 | inputs = inputs / inputs.norm(dim=-1, keepdim=True) 69 | return super().forward(inputs) 70 | 71 | def __call__(self, inputs): 72 | return self.forward(inputs) 73 | 74 | def save(self, filename): 75 | print(f"Saving 
classification head to {filename}") 76 | utils.torch_save(self, filename, save_state_dict=False) 77 | 78 | @classmethod 79 | def load(cls, filename): 80 | # print(f"Loading classification head from {filename}") 81 | return utils.torch_load(filename) 82 | 83 | 84 | class ImageClassifier(torch.nn.Module): 85 | def __init__(self, image_encoder, classification_head): 86 | super().__init__() 87 | self.image_encoder = image_encoder 88 | self.classification_head = classification_head 89 | if self.image_encoder is not None: 90 | self.train_preprocess = self.image_encoder.train_preprocess 91 | self.val_preprocess = self.image_encoder.val_preprocess 92 | 93 | def freeze_head(self): 94 | self.classification_head.weight.requires_grad_(False) 95 | self.classification_head.bias.requires_grad_(False) 96 | 97 | def forward(self, inputs): 98 | features = self.image_encoder(inputs) 99 | outputs = self.classification_head(features) 100 | return outputs 101 | 102 | def __call__(self, inputs): 103 | return self.forward(inputs) 104 | 105 | def save(self, filename): 106 | print(f"Saving image classifier to {filename}") 107 | utils.torch_save(self, filename) 108 | 109 | @classmethod 110 | def load(cls, filename): 111 | print(f"Loading image classifier from {filename}") 112 | return utils.torch_load(filename) 113 | 114 | 115 | class MultiHeadImageClassifier(torch.nn.Module): 116 | def __init__(self, image_encoder, classification_heads): 117 | super().__init__() 118 | self.image_encoder = image_encoder 119 | self.classification_heads = torch.nn.ModuleList(classification_heads) 120 | if self.image_encoder is not None: 121 | self.train_preprocess = self.image_encoder.train_preprocess 122 | self.val_preprocess = self.image_encoder.val_preprocess 123 | 124 | def freeze_head(self): 125 | for idx in range(len(self.classification_heads)): 126 | self.classification_heads[idx].weight.requires_grad_(False) 127 | self.classification_heads[idx].bias.requires_grad_(False) 128 | 129 | def forward(self, inputs, head_idx): 130 | features = self.image_encoder(inputs) 131 | outputs = self.classification_heads[head_idx](features) 132 | return outputs 133 | 134 | def __call__(self, inputs, head_idx): 135 | return self.forward(inputs, head_idx) 136 | 137 | def save(self, filename): 138 | print(f"Saving image classifier to {filename}") 139 | utils.torch_save(self, filename) 140 | 141 | @classmethod 142 | def load(cls, filename): 143 | print(f"Loading image classifier from {filename}") 144 | return utils.torch_load(filename) 145 | -------------------------------------------------------------------------------- /src/models/task_vectors.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import OrderedDict, Union 3 | 4 | import torch 5 | 6 | from src.models.modeling import ImageEncoder 7 | from src.utils.variables_and_paths import MODELS 8 | 9 | _Checkpoint = Union[str, dict, torch.nn.Module] 10 | 11 | 12 | def symmetric_difference(A, B): 13 | """Returns the symmetric difference between two lists.""" 14 | return list(set(A) ^ set(B)) 15 | 16 | 17 | class _TaskVector(abc.ABC): 18 | def __init__( 19 | self, 20 | model_name, 21 | pretrained_checkpoint=None, 22 | finetuned_checkpoint=None, 23 | vector=None, 24 | ): 25 | """Initializes the task vector from a pretrained and a finetuned checkpoints. 
26 | 27 | This can either be done by passing two state dicts (one corresponding to the 28 | pretrained model, and another to the finetuned model), or by directly passying in 29 | the task vector state dict. 30 | """ 31 | assert model_name in MODELS, f"Invalid model name: {model_name}. Valid models are: {MODELS}" 32 | self.model_name = model_name 33 | if vector is not None: 34 | self.vector = vector 35 | else: 36 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None 37 | with torch.no_grad(): 38 | # parse pretrained_checkpoint 39 | pretrained_state_dict = self._safe_load(pretrained_checkpoint) 40 | 41 | # parse finetuned_checkpoint 42 | finetuned_state_dict = self._safe_load(finetuned_checkpoint) 43 | 44 | if "model_name" in finetuned_state_dict.keys(): 45 | finetuned_state_dict.pop("model_name") 46 | 47 | # the final task vector is the difference between finetuned and pretrained vectors. 48 | assert ( 49 | pretrained_state_dict.keys() == finetuned_state_dict.keys() 50 | ), f"State dicts have different keys: {symmetric_difference(pretrained_state_dict.keys(), finetuned_state_dict.keys())}." 51 | self.vector = {} 52 | for key in pretrained_state_dict: 53 | if pretrained_state_dict[key].dtype == torch.int64: 54 | continue 55 | if pretrained_state_dict[key].dtype == torch.uint8: 56 | continue 57 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key] 58 | 59 | def _safe_load(self, checkpoint): 60 | if isinstance(checkpoint, str): 61 | return self._load_checkpoint(checkpoint).state_dict() 62 | elif isinstance(checkpoint, dict): 63 | return checkpoint 64 | elif isinstance(checkpoint, torch.nn.Module): 65 | return checkpoint.state_dict() 66 | else: 67 | raise ValueError(f"Invalid type for checkpoint: {type(checkpoint)}") 68 | 69 | @abc.abstractmethod 70 | def _load_checkpoint(self, checkpoint) -> torch.nn.Module: 71 | """Load a checkpoint into a model.""" 72 | raise NotImplementedError 73 | 74 | def __add__(self, other): 75 | """Add two task vectors together.""" 76 | with torch.no_grad(): 77 | new_vector = {} 78 | for key in self.vector: 79 | if key not in other.vector: 80 | print(f"Warning, key {key} is not present in both task vectors.") 81 | continue 82 | new_vector[key] = self.vector[key] + other.vector[key] 83 | return self.__class__(vector=new_vector) 84 | 85 | def __sub__(self, other): 86 | """Subtract two task vectors.""" 87 | return self.__add__(-other) 88 | 89 | def __radd__(self, other): 90 | if other is None or isinstance(other, int): 91 | return self 92 | return self.__add__(other) 93 | 94 | def __neg__(self): 95 | """Negate a task vector.""" 96 | with torch.no_grad(): 97 | new_vector = {} 98 | for key in self.vector: 99 | new_vector[key] = -self.vector[key] 100 | return self.__class__(vector=new_vector) 101 | 102 | def __pow__(self, power): 103 | """Power of a task vector.""" 104 | with torch.no_grad(): 105 | new_vector = {} 106 | for key in self.vector: 107 | new_vector[key] = self.vector[key] ** power 108 | return self.__class__(vector=new_vector) 109 | 110 | def __mul__(self, other): 111 | """Multiply a task vector by a scalar.""" 112 | with torch.no_grad(): 113 | new_vector = {} 114 | for key in self.vector: 115 | new_vector[key] = other * self.vector[key] 116 | return self.__class__(vector=new_vector) 117 | 118 | def dot(self, other): 119 | """Dot product of two task vectors.""" 120 | with torch.no_grad(): 121 | dot_product = 0.0 122 | for key in self.vector: 123 | if key not in other.vector: 124 | print(f"Warning, key {key} is not 
present in both task vectors.") 125 | continue 126 | dot_product += torch.sum(self.vector[key] * other.vector[key]) 127 | return dot_product 128 | 129 | def norm(self): 130 | """Norm of a task vector.""" 131 | return torch.sqrt(self.dot(self)) 132 | 133 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): 134 | """Apply a task vector to a pretrained model.""" 135 | with torch.no_grad(): 136 | pretrained_model = self._load_checkpoint(pretrained_checkpoint) 137 | new_state_dict = {} 138 | pretrained_state_dict = pretrained_model.state_dict() 139 | for key in pretrained_state_dict: 140 | if key not in self.vector: 141 | print( 142 | f"Warning: key {key} is present in the pretrained state dict but not in the task vector" # noqa: E501 143 | ) 144 | continue 145 | new_state_dict[key] = ( 146 | pretrained_state_dict[key].to(self.vector[key].device) + scaling_coef * self.vector[key] 147 | ) 148 | pretrained_model.load_state_dict(new_state_dict) 149 | return pretrained_model 150 | 151 | 152 | class NonLinearTaskVector(_TaskVector): 153 | """A task vector for nonlinear models.""" 154 | 155 | def _load_checkpoint(self, checkpoint): 156 | """Load a checkpoint into a model.""" 157 | return ImageEncoder.load(self.model_name, checkpoint) 158 | 159 | def apply_to_nonlinear(self, pretrained_nonlinear_checkpoint, scaling_coef=1.0): 160 | """Apply a task vector to a nonlinear pretrained model.""" 161 | return self.apply_to(pretrained_nonlinear_checkpoint, scaling_coef) 162 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .args import parse_arguments 2 | from .logging import initialize_wandb, wandb_log 3 | from .utils import find_optimal_coef 4 | 5 | __all__ = ["parse_arguments", "initialize_wandb", "wandb_log", "find_optimal_coef"] 6 | -------------------------------------------------------------------------------- /src/utils/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--data-dir", 11 | type=str, 12 | default=os.path.expanduser("~/data"), 13 | help="The root directory for the datasets.", 14 | ) 15 | parser.add_argument( 16 | "--eval-datasets", 17 | default=None, 18 | type=lambda x: x.split(","), 19 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ", 20 | ) 21 | parser.add_argument( 22 | "--train-dataset", 23 | default=None, 24 | type=lambda x: x.split(","), 25 | help="Which dataset(s) to patch on.", 26 | ) 27 | parser.add_argument( 28 | "--exp_name", 29 | type=str, 30 | default=None, 31 | help="Name of the experiment, for organization purposes only.", 32 | ) 33 | parser.add_argument( 34 | "--results-db", 35 | type=str, 36 | default=None, 37 | help="Where to store the results, else does not store", 38 | ) 39 | parser.add_argument( 40 | "--model", 41 | type=str, 42 | default="ViT-B-32", 43 | help="The type of model (e.g. 
RN50, ViT-B-32).", 44 | ) 45 | parser.add_argument("--batch-size", type=int, default=128) 46 | parser.add_argument( 47 | "--num-grad-accumulation", 48 | type=int, 49 | default=1, 50 | help="Number of gradient accumulation steps.", 51 | ) 52 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.") 53 | parser.add_argument("--wd", type=float, default=0.1, help="Weight decay") 54 | parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.") 55 | parser.add_argument("--warmup_length", type=int, default=500) 56 | parser.add_argument("--epochs", type=int, default=10) 57 | parser.add_argument( 58 | "--load", 59 | type=lambda x: x.split(","), 60 | default=None, 61 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.", 62 | ) 63 | parser.add_argument( 64 | "--cache-dir", 65 | type=str, 66 | default=None, 67 | help="Directory for caching features and encoder", 68 | ) 69 | parser.add_argument( 70 | "--openclip-cachedir", 71 | type=str, 72 | default=os.path.expanduser("~/openclip-cachedir/open_clip"), 73 | help="Directory for caching models from OpenCLIP", 74 | ) 75 | parser.add_argument( 76 | "--world-size", 77 | type=int, 78 | default=1, 79 | help="Number of processes for distributed training.", 80 | ) 81 | parser.add_argument( 82 | "--checkpoint-every", 83 | type=int, 84 | default=-1, 85 | help="How often to checkpoint the model.", 86 | ) 87 | parser.add_argument("--port", type=int, default=12355, help="Port for distributed training.") 88 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 89 | parser.add_argument( 90 | "--n-eval-points", 91 | type=int, 92 | default=21, 93 | help="Number of evaluation points used to find optimal coefficient in task arithmetic.", 94 | ) 95 | parser.add_argument( 96 | "--finetuning-mode", 97 | type=str, 98 | default="standard", 99 | help="Finetuned mode; standard for nonlinear finetune; none for zeroshot", 100 | ) 101 | parser.add_argument( 102 | "--model-location", 103 | type=str, 104 | default="/mnt/lts4/scratch/checkpoints/tall_mask_checkpoints/mount/model_checkpoints", 105 | help="Directory for model location", 106 | ) 107 | parser.add_argument( 108 | "--data-location", 109 | type=str, 110 | default="/mnt/lts4/scratch/data", 111 | help="Directory for data location", 112 | ) 113 | 114 | parsed_args = parser.parse_args() 115 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" 116 | 117 | if parsed_args.load is not None and len(parsed_args.load) == 1: 118 | parsed_args.load = parsed_args.load[0] 119 | return parsed_args 120 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | def setup_ddp(rank, world_size, port=12357): 7 | """ 8 | Set up the Distributed Data Parallel (DDP) environment for multi-GPU training. 9 | 10 | Args: 11 | rank (int): The rank of the current process. 12 | world_size (int): The total number of processes. 13 | port (int, optional): The port number for communication. Defaults to 12357. 
14 | """ 15 | os.environ["MASTER_ADDR"] = "localhost" 16 | os.environ["MASTER_PORT"] = str(port) 17 | 18 | # initialize the process group 19 | torch.distributed.init_process_group( 20 | "nccl", 21 | rank=rank, 22 | world_size=world_size, 23 | ) 24 | torch.cuda.set_device(rank) 25 | torch.distributed.barrier() 26 | 27 | 28 | def cleanup_ddp(): 29 | """ 30 | Cleans up the distributed data parallel (DDP) process group. 31 | 32 | This function is responsible for cleaning up the DDP process group after training or inference is complete. 33 | It ensures that all resources used by the DDP process group""" 34 | torch.distributed.destroy_process_group() 35 | 36 | 37 | def is_main_process(): 38 | return torch.distributed.get_rank() == 0 39 | 40 | 41 | def distribute_loader(loader): 42 | return torch.utils.data.DataLoader( 43 | loader.dataset, 44 | batch_size=loader.batch_size // torch.distributed.get_world_size(), 45 | sampler=torch.utils.data.distributed.DistributedSampler( 46 | loader.dataset, 47 | num_replicas=torch.distributed.get_world_size(), 48 | rank=torch.distributed.get_rank(), 49 | ), 50 | num_workers=loader.num_workers, 51 | pin_memory=loader.pin_memory, 52 | ) 53 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import hydra 4 | from argparse import Namespace 5 | from pprint import pprint 6 | 7 | import wandb 8 | 9 | from src.utils.distributed import is_main_process 10 | 11 | 12 | def initialize_wandb(args, disabled=True): 13 | if disabled: 14 | # for debugging 15 | wandb.init(config=args, mode="disabled") 16 | else: 17 | wandb.init(config=args) 18 | 19 | if wandb.run is not None: 20 | INVALID_PATHS = [ 21 | "__old__", 22 | "checkpoints", 23 | "logs", 24 | "outputs", 25 | "results", 26 | "wandb", 27 | ] 28 | wandb.run.log_code( 29 | exclude_fn=lambda path: any( 30 | [path.startswith(os.path.expanduser(os.getcwd() + "/" + i)) for i in INVALID_PATHS] 31 | ) 32 | ) 33 | return wandb 34 | 35 | 36 | def wandb_log(dictionary: dict): 37 | if is_main_process(): 38 | wandb.log(dictionary) 39 | 40 | 41 | def log_results(final_results, args): 42 | if args.method.name == "tall_mask": 43 | mask_suffix = f"tall_mask_ties" if args.method.use_ties else f"tall_mask_ta" 44 | elif args.method.name == "mag_masking": 45 | mask_suffix = "mag_mask" 46 | elif args.method.name == "consensus": 47 | mask_suffix = ( 48 | f"k_{args.method.prun_thre_k}_ties" if args.method.use_ties else f"k_{args.method.prun_thre_k}_ta" 49 | ) 50 | else: 51 | mask_suffix = "" 52 | 53 | if "ties" in args.method.full_name: 54 | try: 55 | method_name = args.method.full_name + "_" + method.agg 56 | except: 57 | method_name = args.method.full_name 58 | else: 59 | method_name = args.method.full_name 60 | 61 | lines_suffix = "_lines" if args.method.apply_lines else "_" 62 | method_dir = "editing_single_task" if args.method.name == "single_task" else "merging_multi_task" 63 | 64 | if args.method.name == "single_task": 65 | save_file = f"results/{method_dir}/{args.model}_{args.num_tasks}tasks_{method_name}_finetuned_on_{args.method.task_index}task_{mask_suffix}{lines_suffix}.json" 66 | else: 67 | save_file = f"results/{method_dir}/{args.model}_{args.num_tasks}tasks_{method_name}_merged_{mask_suffix}{lines_suffix}.json" 68 | 69 | with open(save_file, "w") as f: 70 | json.dump(final_results, f, indent=4) 71 | hydra_dir = 
hydra.core.hydra_config.HydraConfig.get().runtime.output_dir 72 | hydra_save_file = f"{args.method.full_name}_nonlinear_additions.json" 73 | hydra_save_file = os.path.join(hydra_dir, hydra_save_file) 74 | json.dump(final_results, open(hydra_save_file, "w"), indent=4) 75 | 76 | print("saved results to: ", save_file) 77 | print("saved results to: ", hydra_save_file) 78 | artifact = wandb.Artifact(name="final_results", type="results") 79 | artifact.add_file(save_file) 80 | wandb.log_artifact(artifact) 81 | -------------------------------------------------------------------------------- /src/utils/tallmask_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | 9 | from .utils import state_dict_to_vector, vector_to_state_dict 10 | from .variables_and_paths import ALL_DATASETS 11 | 12 | 13 | def log_wandb_mask_sparsity(final_mask: torch.Tensor): 14 | """ 15 | Logs the mask sparsity for each dataset to Weights & Biases (wandb). 16 | 17 | Args: 18 | final_mask (torch.Tensor): The final mask tensor. 19 | """ 20 | dataset_sparsities = final_mask.float().mean(1) 21 | 22 | for i in range(len(final_mask)): 23 | dataset = ALL_DATASETS[i] 24 | wandb.log({f"mask_sparsity_{dataset}": dataset_sparsities[i]}) 25 | 26 | 27 | def generate_task_masks( 28 | tv_flat_checks: torch.Tensor, 29 | flat_ft: torch.Tensor, 30 | flat_ptm: torch.Tensor, 31 | tv: Optional[torch.Tensor] = None, 32 | tall_mask_lambda: float = 1.0, 33 | ) -> torch.Tensor: 34 | """ 35 | Generate task-specific TALL masks 36 | TALL masks are generated as: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda 37 | 38 | Args: 39 | tv_flat_checks: individual task vectors 40 | flat_ft: individual theta_t (fine-tuned weights) 41 | flat_ptm: theta_0 (pre-trained weight) 42 | tv: multi-task vector 43 | tall_mask_lambda: hyper-parameter lambda for generating TALL masks 44 | Returns: 45 | final_mask: generated TALL masks with the given lambda, in shape (n_task, n_parameter) 46 | """ 47 | 48 | print(f"Generating TALL masks.") 49 | 50 | if tv is None: 51 | tv = tv_flat_checks.sum(0) 52 | 53 | flat_multi = flat_ptm + tv 54 | 55 | original_shape = flat_ft.shape 56 | 57 | # generate masks by comparing the l1 distance between |theta_0 - theta_t| and |theta_mt - theta_t| 58 | diff_pt_ft = (flat_ptm - flat_ft).abs() 59 | diff_multi_ft = (flat_multi - flat_ft).abs() 60 | # compare the l1 distance, scaled with hyper-parameter lambda 61 | mask = diff_pt_ft > diff_multi_ft * tall_mask_lambda 62 | 63 | final_mask = mask.squeeze() if original_shape == tv_flat_checks.squeeze().shape else mask 64 | 65 | print( 66 | f"Average sparsity for the mask with tall_mask_lambda of {tall_mask_lambda}: {final_mask.float().mean():.4f}" 67 | ) 68 | log_wandb_mask_sparsity(final_mask) 69 | 70 | return final_mask 71 | 72 | 73 | def construct_tall_mask( 74 | tv_flat_checks: torch.Tensor, 75 | flat_ft: torch.Tensor, 76 | flat_ptm: torch.Tensor, 77 | merged_tv: torch.Tensor, 78 | ptm_check: torch.Tensor, 79 | remove_keys: List[str], 80 | config, 81 | ): 82 | """ 83 | Construct TALL masks for all tasks for each lambda, and store in dictionary 84 | 85 | Args: 86 | tv_flat_checks: individual task vectors 87 | flat_ft: individual theta_t (fine-tuned weights) 88 | flat_ptm: theta_0 (pre-trained weight) 89 | merged_tv: multi-task vector 90 | ptm_check: pre-trained weight as state dictionary 91 | remove_keys: the keys to be 
removed when converting between dictionary and vector 92 | Returns: 93 | tall_masks: constructed TALL masks in dictionary format of {lambda: {task: mask}} 94 | """ 95 | tall_masks = {} 96 | for tall_mask_lambda in [0.2, 0.3, 0.4, 0.5, 0.6]: 97 | # generate tall masks for each lambda 98 | masks_at_scale = generate_task_masks( 99 | tv_flat_checks, 100 | flat_ft, 101 | flat_ptm, 102 | tall_mask_lambda=tall_mask_lambda, 103 | tv=merged_tv, 104 | ) 105 | # convert vectors to dictionary 106 | masks_at_scale = [vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys) for mask in masks_at_scale] 107 | # store the masks with {dataset: mask} 108 | tall_masks[tall_mask_lambda] = {key: value for key, value in zip(config.DATASETS, masks_at_scale)} 109 | return tall_masks 110 | 111 | 112 | def find_optimal_mask(val_metrics, eval_masks, args, save_masks=True): 113 | """ 114 | Respectively finds the optimal mask for each data task based on the validation accuracy 115 | 116 | Args: 117 | val_metrics: validation metrics for each lambda 118 | eval_masks: all generated masks 119 | 120 | Returns: 121 | best_masks_for_test: the best masks for each task, selected based on validation accuracy from each task 122 | best_val_metrics: best validation metrics for each task 123 | """ 124 | # transpose the dict from lambda-task to task-lambda 125 | transposed_dict = {} 126 | for key, inner_dict in val_metrics.items(): 127 | for inner_key, value in inner_dict.items(): 128 | if inner_key not in transposed_dict: 129 | transposed_dict[inner_key] = {} 130 | transposed_dict[inner_key][key] = value 131 | 132 | # for each task, find the best lambda 133 | max_subkeys = {key: max(inner_dict, key=inner_dict.get) for key, inner_dict in transposed_dict.items()} 134 | 135 | # select the best mask for each task, which will be used for testing later 136 | best_masks_for_test = {} 137 | best_masks_for_test_vector = {} # the selected masks as vectors 138 | best_val_metrics = {} 139 | # respectively for each task: 140 | for ds in args.DATASETS: 141 | # select the lambda which achieves the best valdiation accuracy 142 | best_lambda = float(max_subkeys[ds + "Val:top1"]) 143 | # select the mask based on the selected lambda, save as dictionaries 144 | best_masks_for_test[ds] = eval_masks[best_lambda][ds] 145 | # select the mask based on the selected lambda, save as vectors 146 | best_masks_for_test_vector[ds] = state_dict_to_vector(eval_masks[best_lambda][ds], remove_keys=[]) 147 | print(f"Best lambda for {ds} is {best_lambda}") 148 | # save the best validation metric based on the selected lambda 149 | best_val_metrics[ds + "Val:top1"] = val_metrics[best_lambda][ds + "Val:top1"] 150 | 151 | # save the best masks in disk 152 | if save_masks and not args.method.load_mask: 153 | # convert to numpy to save with np.packbits for saving storage 154 | best_masks_for_test_vector = {k: np.packbits(v) for k, v in best_masks_for_test_vector.items()} 155 | mask_save_dir = args.model_location.replace("model_checkpoints", "tall_masks") 156 | mask_name = ( 157 | f"TALL_mask_{args.num_tasks}task.npy" 158 | if not args.method.use_ties 159 | else f"TALL_mask_{args.num_tasks}task_use_ties_{args.method.ties_agg}.npy" 160 | ) 161 | np.save( 162 | os.path.join(mask_save_dir, args.model, mask_name), 163 | best_masks_for_test_vector, 164 | ) 165 | del best_masks_for_test_vector 166 | 167 | return best_masks_for_test, best_val_metrics 168 | 169 | 170 | def load_tall_mask(remove_keys, ptm_check, config): 171 | """Loads TALL masks from disk, unpack and transform 
to state dictionaries.""" 172 | mask_location = config.model_location.replace("model_checkpoints", "tall_masks") 173 | print(f"Loading TALL masks from {mask_location}") 174 | try: 175 | if config.method.use_ties: 176 | print("==== Loading TALL Masks built with TIES ====") 177 | tall_masks = np.load( 178 | os.path.join( 179 | mask_location, 180 | config.model, 181 | f"TALL_mask_{config.num_tasks}task_use_ties.npy", 182 | ), 183 | allow_pickle=True, 184 | ).item() 185 | else: 186 | print("==== Loading TALL Masks built with Task Arithmetic ====") 187 | tall_masks = np.load( 188 | os.path.join(mask_location, config.model, f"TALL_mask_{config.num_tasks}task.npy"), 189 | allow_pickle=True, 190 | ).item() 191 | # tall_masks = torch.load(os.path.join(mask_location, config.model, f"TALL_mask_{config.num_tasks}task.npy")) 192 | except: 193 | raise Exception("TALL Masks are not constructed yet.") 194 | 195 | # unpack masks and convert back to torch tensors 196 | tall_masks = {k: torch.from_numpy(np.unpackbits(v)) for k, v in tall_masks.items()} 197 | 198 | # convert vectors to dictionaries 199 | tall_masks = { 200 | dataset: vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys) for dataset, mask in tall_masks.items() 201 | } 202 | 203 | return tall_masks 204 | 205 | 206 | def construct_consensus_mask(ptm_check, prun_thre_k, config, remove_keys=[]): 207 | """ 208 | Generate consensus mask by filtering out least-used parameters 209 | 210 | Args: 211 | ptm_check: pretrained_checkpoint as state dictionary 212 | prun_thre_k: weight-pruning threhold, stands for the least number of activated tasks for a parameter to be preserved from pruning 213 | if prun_thre_k is set to 2: remove both catastrophic and selfish weights; 214 | if prun_thre_k is set to 1: remove only catastrophic weights; 215 | if prun_thre_k is set to 0: remove no weights -> reduce to TA or TIES 216 | if prun_thre_k is set to > num_tasks: remove all weights -> reduce to zero-shot 217 | Returns: 218 | consensus_mask_vector: constructed consensus mask as vector (boolean in shape (n_parameter, )) 219 | """ 220 | 221 | print("==== Generating Consensus Mask ====") 222 | # load TALL masks (in shape (n_task, n_parameter)) 223 | tall_masks = load_tall_mask(remove_keys, ptm_check, config) 224 | tall_masks = list(tall_masks.values()) 225 | 226 | # generate consensus masks 227 | consensus_mask = copy.deepcopy(tall_masks[0]) 228 | for key, value in consensus_mask.items(): 229 | consensus_mask[key] = torch.zeros_like(value) 230 | # count for each parameter, the tasks it has been activated for 231 | for mask in tall_masks: 232 | consensus_mask[key] = consensus_mask[key] + mask[key].float() 233 | # filter out the least-activated parameters based on given threshold 234 | consensus_mask[key] = consensus_mask[key].float() >= prun_thre_k 235 | 236 | torch.save(consensus_mask, f"consensus_mask_{config.num_tasks}task_107.pt") 237 | 238 | consensus_mask_vector = state_dict_to_vector(consensus_mask, remove_keys=remove_keys) 239 | 240 | return consensus_mask_vector 241 | -------------------------------------------------------------------------------- /src/utils/ties_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .utils import topk_values_mask 4 | 5 | 6 | ## TIES MERGING UTILS 7 | def resolve_zero_signs(sign_to_mult, method="majority"): 8 | majority_sign = torch.sign(sign_to_mult.sum()) 9 | 10 | if method == "majority": 11 | sign_to_mult[sign_to_mult == 0] = majority_sign 12 | elif 
method == "minority": 13 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign 14 | return sign_to_mult 15 | 16 | 17 | def resolve_sign(tensor: torch.Tensor): 18 | sign_to_mult = torch.sign(tensor.sum(dim=0)) 19 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority") 20 | return sign_to_mult 21 | 22 | 23 | def disjoint_merge(tensor, merge_func, sign_to_mult): 24 | merge_func = merge_func.split("-")[-1] 25 | 26 | # If sign is provided then we select the corresponding entries and aggregate. 27 | if sign_to_mult is not None: 28 | rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, tensor > 0, tensor < 0) 29 | selected_entries = tensor * rows_to_keep 30 | # Else we select all non-zero entries and aggregate. 31 | else: 32 | rows_to_keep = tensor != 0 33 | selected_entries = tensor * rows_to_keep 34 | 35 | if merge_func == "mean": 36 | non_zero_counts = (selected_entries != 0).sum(dim=0).float() 37 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1) 38 | elif merge_func == "sum": 39 | disjoint_aggs = torch.sum(selected_entries, dim=0) 40 | elif merge_func == "max": 41 | disjoint_aggs = selected_entries.abs().max(dim=0)[0] 42 | disjoint_aggs *= sign_to_mult 43 | else: 44 | raise ValueError(f"Merge method {merge_func} is not defined.") 45 | 46 | return disjoint_aggs 47 | 48 | 49 | def ties_merging(flat_task_checks, reset_thresh=None, merge_func=""): 50 | all_checks = flat_task_checks.clone() 51 | updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False) 52 | print(f"RESOLVING SIGN") 53 | final_signs = resolve_sign(updated_checks) 54 | assert final_signs is not None 55 | 56 | print(f"Disjoint AGGREGATION: {merge_func}") 57 | merged_tv = disjoint_merge(updated_checks, merge_func, final_signs) 58 | 59 | return merged_tv 60 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import pickle 4 | from collections import OrderedDict 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def compute_l1_norm(model1: nn.Module, model2: nn.Module) -> Tuple[torch.Tensor, Dict[str, float]]: 13 | """ 14 | Computes the L1 norm between the parameters of two models. 15 | 16 | Args: 17 | model1 (nn.Module): The first model. 18 | model2 (nn.Module): The second model. 19 | 20 | Returns: 21 | Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total L1 norm and a dictionary 22 | with the L1 norm for each layer. 
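    Note: the implementation zips model1.named_parameters() with model2.parameters(), so both
    models are assumed to expose their parameters in the same order (e.g. two checkpoints of the
    same architecture). Illustrative usage with hypothetical variable names:
        total, per_layer = compute_l1_norm(finetuned_encoder, pretrained_encoder)
        # per_layer maps each parameter name to its L1 distance; total is their sum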
23 | 24 | """ 25 | norms = dict() 26 | l1_norm = 0.0 27 | for (n, p1), p2 in zip(model1.named_parameters(), model2.parameters()): 28 | layer_l1_norm = torch.norm(p1 - p2, 1) 29 | l1_norm += layer_l1_norm 30 | norms[n] = layer_l1_norm.item() 31 | 32 | return l1_norm, norms 33 | 34 | 35 | def assign_learning_rate(param_group, new_lr): 36 | param_group["lr"] = new_lr 37 | 38 | 39 | def _warmup_lr(base_lr, warmup_length, step): 40 | return base_lr * (step + 1) / warmup_length 41 | 42 | 43 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 44 | if not isinstance(base_lrs, list): 45 | base_lrs = [base_lrs for _ in optimizer.param_groups] 46 | assert len(base_lrs) == len(optimizer.param_groups) 47 | 48 | def _lr_adjuster(step): 49 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs): 50 | if step < warmup_length: 51 | lr = _warmup_lr(base_lr, warmup_length, step) 52 | else: 53 | e = step - warmup_length 54 | es = steps - warmup_length 55 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 56 | assign_learning_rate(param_group, lr) 57 | 58 | return _lr_adjuster 59 | 60 | 61 | def accuracy(output: torch.Tensor, target: torch.Tensor, topk: List[int] = (1,)): 62 | pred = output.topk(max(topk), 1, True, True)[1].t() 63 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 65 | 66 | 67 | def torch_load_old(save_path: str, device=None): 68 | with open(save_path, "rb") as f: 69 | classifier = pickle.load(f) 70 | if device is not None: 71 | classifier = classifier.to(device) 72 | return classifier 73 | 74 | 75 | def torch_save(model, save_path, save_state_dict=True): 76 | # TODO: hacky way to save state dict 77 | if save_state_dict and isinstance(model, torch.nn.Module): 78 | model = model.state_dict() 79 | if os.path.dirname(save_path) != "": 80 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 81 | torch.save(model, save_path) 82 | 83 | 84 | def torch_load(save_path, device=None): 85 | model = torch.load(save_path, map_location="cpu", weights_only=False) 86 | if device is not None: 87 | model = model.to(device) 88 | return model 89 | 90 | 91 | def get_logits(inputs, classifier): 92 | assert callable(classifier) 93 | if hasattr(classifier, "to"): 94 | classifier = classifier.to(inputs.device) 95 | return classifier(inputs) 96 | 97 | 98 | def get_probs(inputs, classifier): 99 | if hasattr(classifier, "predict_proba"): 100 | probs = classifier.predict_proba(inputs.detach().cpu().numpy()) 101 | return torch.from_numpy(probs) 102 | logits = get_logits(inputs, classifier) 103 | return logits.softmax(dim=1) 104 | 105 | 106 | class LabelSmoothing(torch.nn.Module): 107 | def __init__(self, smoothing=0.0): 108 | super(LabelSmoothing, self).__init__() 109 | self.confidence = 1.0 - smoothing 110 | self.smoothing = smoothing 111 | 112 | def forward(self, x, target): 113 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 114 | 115 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 116 | nll_loss = nll_loss.squeeze(1) 117 | smooth_loss = -logprobs.mean(dim=-1) 118 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 119 | return loss.mean() 120 | 121 | 122 | class DotDict(dict): 123 | """dot.notation access to dictionary attributes""" 124 | 125 | __getattr__ = dict.get 126 | __setattr__ = dict.__setitem__ 127 | __delattr__ = dict.__delitem__ 128 | 129 | 130 | def find_optimal_coef( 131 | results: Dict[str, Any], 132 | metric: str = 
"avg_normalized_top1", 133 | minimize: bool = False, 134 | control_metric: Optional[str] = None, 135 | control_metric_threshold: float = 0.0, 136 | ) -> float: 137 | """ 138 | Finds the optimal coefficient based on the given results and metric. 139 | 140 | Args: 141 | results (Dict[str, Any]): A dictionary containing the results for different scaling coefficients. 142 | metric (str, optional): The metric to optimize. Defaults to "avg_normalized_top1". 143 | minimize (bool, optional): Whether to minimize the metric. Defaults to False. 144 | control_metric (str, optional): The control metric to check against. Defaults to None. 145 | control_metric_threshold (float, optional): The threshold value for the control metric. Defaults to 0.0. 146 | 147 | Returns: 148 | The optimal coefficient based on the given results and metric. 149 | """ 150 | best_coef = None 151 | if minimize: 152 | best_metric = 1 153 | else: 154 | best_metric = 0 155 | for scaling_coef in results.keys(): 156 | if control_metric is not None: 157 | if results[scaling_coef][control_metric] < control_metric_threshold: 158 | print(f"Control metric fell below {control_metric_threshold} threshold") 159 | continue 160 | if minimize: 161 | if results[scaling_coef][metric] < best_metric: 162 | best_metric = results[scaling_coef][metric] 163 | best_coef = scaling_coef 164 | else: 165 | if results[scaling_coef][metric] > best_metric: 166 | best_metric = results[scaling_coef][metric] 167 | best_coef = scaling_coef 168 | return best_coef 169 | 170 | 171 | def find_optimal_coef_tradeoff( 172 | results: Dict[str, Any], 173 | tradeoff_target_weight: float = 5.0, 174 | minimize: bool = False, 175 | control_metric: Optional[str] = None, 176 | control_metric_threshold: float = 0.0, 177 | ) -> float: 178 | best_coef = None 179 | if minimize: 180 | best_metric = 1 181 | else: 182 | best_metric = 0 183 | for scaling_coef in results.keys(): 184 | if minimize: 185 | if ( 186 | tradeoff_target_weight * results[scaling_coef]["target_normalized_accuracy"] 187 | + results[scaling_coef]["control_normalized_accuracy"] 188 | ) < best_metric: 189 | best_metric = ( 190 | tradeoff_target_weight * results[scaling_coef]["target_normalized_accuracy"] 191 | + results[scaling_coef]["control_normalized_accuracy"] 192 | ) 193 | best_coef = scaling_coef 194 | else: 195 | if ( 196 | tradeoff_target_weight * results[scaling_coef]["target_normalized_accuracy"] 197 | + results[scaling_coef]["control_normalized_accuracy"] 198 | ) > best_metric: 199 | best_metric = ( 200 | tradeoff_target_weight * results[scaling_coef]["target_normalized_accuracy"] 201 | + results[scaling_coef]["control_normalized_accuracy"] 202 | ) 203 | best_coef = scaling_coef 204 | return best_coef 205 | 206 | 207 | def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes): 208 | """Computes the normalized non-linear advantage of a finetuned model. 209 | 210 | The nonlinear_advantage is defined as: 211 | error_rate(linear_model) - error_rate(nonlinear_model) / (1 - 1 / num_classes) 212 | and takes values between [-1, 1]. A value of 0 indicates that the nonlinear 213 | model is no better than the linear one. Meanwhile, a value of 1 indicates 214 | that the nonlinear model is perfect and the linear trivial, and a value of 215 | -1 indicates the opposite. 
216 | """ 217 | return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes) 218 | 219 | 220 | def to_cuda(input_dict): 221 | cuda_dict = {} 222 | for key, value in input_dict.items(): 223 | cuda_dict[key] = value.to("cuda") 224 | return cuda_dict 225 | 226 | 227 | def state_dict_to_vector(state_dict, remove_keys=[]): 228 | shared_state_dict = copy.deepcopy(state_dict) 229 | for key in remove_keys: 230 | if key in shared_state_dict: 231 | del shared_state_dict[key] 232 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items())) 233 | return torch.nn.utils.parameters_to_vector([value.reshape(-1) for key, value in sorted_shared_state_dict.items()]) 234 | 235 | 236 | def vector_to_state_dict(vector, state_dict, remove_keys=[]): 237 | # create a reference dict to define the order of the vector 238 | reference_dict = copy.deepcopy(state_dict) 239 | for key in remove_keys: 240 | if key in reference_dict: 241 | del reference_dict[key] 242 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items())) 243 | 244 | # create a shared state dict using the refence dict 245 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values()) 246 | 247 | # add back the encoder and decoder embedding weights. 248 | if "transformer.shared.weight" in sorted_reference_dict: 249 | for key in remove_keys: 250 | sorted_reference_dict[key] = sorted_reference_dict["transformer.shared.weight"] 251 | return sorted_reference_dict 252 | 253 | 254 | def add_ptm_to_tv(tv_dict, ptm_dict): 255 | assert set(tv_dict.keys()) == set(ptm_dict.keys()), "Differing parameter names in models." 256 | final_dict = copy.deepcopy(tv_dict) 257 | for k, v in ptm_dict.items(): 258 | final_dict[k] = tv_dict[k] + v 259 | return final_dict 260 | 261 | 262 | def check_parameterNamesMatch(checkpoints): 263 | parameter_names = set(checkpoints[0].keys()) 264 | 265 | if len(checkpoints) >= 2: 266 | # raise ValueError("Number of models is less than 2.") 267 | for checkpoint in checkpoints[1:]: 268 | current_parameterNames = set(checkpoint.keys()) 269 | if current_parameterNames != parameter_names: 270 | raise ValueError( 271 | "Differing parameter names in models. 
" 272 | f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}" 273 | ) 274 | 275 | 276 | def check_state_dicts_equal(state_dict1, state_dict2): 277 | if set(state_dict1.keys()) != set(state_dict2.keys()): 278 | return False 279 | 280 | for key in state_dict1.keys(): 281 | if not torch.equal(state_dict1[key], state_dict2[key]): 282 | return False 283 | 284 | return True 285 | 286 | 287 | def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False): 288 | if K == 100: 289 | # print("Not applying mask") 290 | if return_mask: 291 | return M, torch.ones_like(M), None 292 | else: 293 | return M, torch.ones_like(M) 294 | 295 | if K >= 1: 296 | K /= 100 297 | 298 | original_shape = M.shape 299 | if M.dim() == 1: 300 | M = M.unsqueeze(0) 301 | 302 | n, d = M.shape 303 | k = int(d * K) 304 | k = d - k # Keep top k elements instead of bottom k elements 305 | 306 | # Find the k-th smallest element by magnitude for each row 307 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 308 | # Create a mask tensor with True for the top k elements in each row 309 | mask = M.abs() >= kth_values 310 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 311 | 312 | if reshape_mask: 313 | final_mask = final_mask.reshape(M.shape) 314 | 315 | if return_mask: 316 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 317 | else: 318 | return M * final_mask, final_mask.float().mean(dim=1) 319 | 320 | 321 | def cleanup_linear(state_dict): 322 | # The linear model also has keys for the reference point $\theta_0$ in the state dict with the prefix `params0`. 323 | state_dict = {k: v for k, v in state_dict.items() if "params." in k} 324 | return state_dict 325 | 326 | 327 | def get_ptm_linear(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: 328 | # rename keys so that they match afterwards 329 | state_dict_new = {k.replace("params0", "params"): v for k, v in state_dict.items() if "params0." in k} 330 | state_dict_remaining = {k: v for k, v in state_dict.items() if "params." 
not in k} 331 | 332 | return state_dict_new, state_dict_remaining 333 | -------------------------------------------------------------------------------- /src/utils/variables_and_paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal 3 | import os 4 | 5 | TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" 6 | MODELS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"] 7 | OPENCLIP_CACHEDIR = Path(Path.home(), "openclip-cachedir", "open_clip").as_posix() 8 | CACHEDIR = None 9 | DATA_DIR = os.environ.get("DATA_DIR", "/mnt/lts4/scratch/data") 10 | 11 | ALL_DATASETS = [ 12 | "EuroSAT", 13 | "DTD", 14 | "SUN397", 15 | "MNIST", 16 | "RESISC45", 17 | "GTSRB", 18 | "Cars", 19 | "SVHN", 20 | "STL10", 21 | "OxfordIIITPet", 22 | "Flowers102", 23 | "CIFAR100", 24 | "PCAM", 25 | "FER2013", 26 | "CIFAR10", 27 | "Food101", 28 | "FashionMNIST", 29 | "RenderedSST2", 30 | "EMNIST", 31 | "KMNIST", 32 | ] 33 | 34 | DATASETS_8 = ALL_DATASETS[:8] 35 | DATASETS_14 = ALL_DATASETS[:14] 36 | DATASETS_20 = ALL_DATASETS[:20] 37 | 38 | 39 | def cleanup_dataset_name(dataset_name: str): 40 | return dataset_name.replace("Val", "") + "Val" 41 | 42 | 43 | def get_zeroshot_path(root, dataset, model): 44 | return Path(root, model, cleanup_dataset_name(dataset), f"nonlinear_zeroshot.pt").as_posix() 45 | 46 | 47 | # def get_finetuned_path(root, dataset, model): 48 | # return Path(root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt").as_posix() 49 | 50 | 51 | def get_finetuned_path(root, dataset, model): 52 | return Path(root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt").as_posix() 53 | # return f"/mnt/lts4/scratch/home/ndimitri/dev/tall_masks/new_checkpoints/{model}/exponential/{cleanup_dataset_name(dataset)}/a=0.4/{cleanup_dataset_name(dataset)}.pt" 54 | 55 | 56 | def get_single_task_accuracies_path(model): 57 | return Path("results/single_task", model, f"nonlinear_ft_accuracies.json").as_posix() 58 | 59 | 60 | def get_zero_shot_accuracies_path(model): 61 | return Path(f"results/zero_shot/{model}_20tasks_zeroshot.json").as_posix() 62 | --------------------------------------------------------------------------------
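A minimal sketch of the TALL mask criterion from src/utils/tallmask_utils.py (generate_task_masks), replayed on toy tensors. The tensor sizes, random values and the single lambda value below are assumptions made purely for illustration and are not taken from the repository; the criterion itself follows the docstring above: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda.

import torch

# Toy setup: 3 "tasks" and 5 parameters. theta_0 stands in for the flattened
# pre-trained weights; task_vectors[t] = theta_t - theta_0 for each fine-tuned model.
torch.manual_seed(0)
theta_0 = torch.randn(5)
task_vectors = 0.1 * torch.randn(3, 5)
theta_t = theta_0 + task_vectors          # individual fine-tuned weights (flat_ft)
theta_mt = theta_0 + task_vectors.sum(0)  # multi-task weights (flat_ptm + merged task vector)

tall_mask_lambda = 0.4
# TALL criterion: keep parameter i for task t when
# |theta_0[i] - theta_t[t, i]| > |theta_mt[i] - theta_t[t, i]| * lambda
mask = (theta_0 - theta_t).abs() > (theta_mt - theta_t).abs() * tall_mask_lambda
print(mask.shape, mask.float().mean(1))   # per-task fraction of kept parameters, as in log_wandb_mask_sparsity

In the repository, construct_tall_mask sweeps tall_mask_lambda over [0.2, 0.3, 0.4, 0.5, 0.6] and find_optimal_mask then selects, per task, the lambda with the best validation accuracy; the single lambda used here is only for the sketch.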