├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── figures └── tangent.jpg ├── ruff.toml └── src ├── args.py ├── datasets ├── cars.py ├── cifar10.py ├── cifar100.py ├── common.py ├── dtd.py ├── eurosat.py ├── gtsrb.py ├── imagenet.py ├── mnist.py ├── registry.py ├── resisc45.py ├── stl10.py ├── sun397.py ├── svhn.py └── templates.py ├── distributed.py ├── eval.py ├── eval_single_task.py ├── eval_task_addition.py ├── eval_task_negation.py ├── finetune.py ├── heads.py ├── linearize.py ├── modeling.py ├── task_vectors.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints*/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 gortizji 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Task Arithmetic in the Tangent Space 2 | 3 | This is the source code to reproduce the experiments of the paper "[Task arithmetic in the tangent space: Improved editing of pre-trained models](https://arxiv.org/abs/2305.12827)" by Guillermo Ortiz-Jimenez*, Alessandro Favero* and Pascal Frossard. 4 | 5 | ![](figures/tangent.jpg) 6 | 7 | ## Dependencies 8 | 9 | To run the code, please install all its dependencies: 10 | ```sh 11 | conda env create 12 | conda activate tangent-arithmetic 13 | ``` 14 | and add the `src` directory to the `PYTHONPATH`: 15 | ```sh 16 | cd tangent_task_arithmetic 17 | export PYTHONPATH="$PYTHONPATH:$PWD" 18 | ``` 19 | 20 | ## Repository content 21 | 22 | This repository is heavily based on the code from [Ilharco et al. (2022)](https://github.com/mlfoundations/task_vectors) and follows the same structure. 
23 | 24 | ### Task vectors 25 | 26 | The task vector logic in [src/task_vectors.py](src/task_vectors.py) has been extended to distinguish between `NonLinearTaskVector`s and `LinearizedTaskVector`s, which can be applied to non-linear `ImageEncoder`s and `LinearizedImageEncoder`s, respectively. Given a pre-trained checkpoint and a fine-tuned checkpoint, you can create a linearized/standard task vector as: 27 | 28 | ```python 29 | from src.task_vectors import NonLinearTaskVector, LinearizedTaskVector 30 | 31 | # Non-linear task vector. 32 | zeroshot_checkpoint = ... # Pre-trained non-linear image encoder. 33 | finetuned_checkpoint = ... # Non-linearly fine-tuned checkpoint. 34 | 35 | nonlinear_task_vector = NonLinearTaskVector(zeroshot_checkpoint, finetuned_checkpoint) 36 | 37 | # Tangent task vector. 38 | linear_zeroshot_checkpoint = ... # Pre-trained linearized image encoder. 39 | linear_finetuned_checkpoint = ... # Linearly fine-tuned checkpoint. 40 | 41 | linear_task_vector = LinearizedTaskVector(linear_zeroshot_checkpoint, linear_finetuned_checkpoint) 42 | ``` 43 | 44 | Once created, we can modify and combine the task vectors through arithmetic operations in Python, e.g., 45 | ```python 46 | negated_task_vector = -task_vector # Negating a task vector. 47 | multi_task_vector = 0.5 * task_vector_1 + 0.7 * task_vector_2 # Linearly combining two task vectors. 48 | ``` 49 | and apply them to a pre-trained encoder as: 50 | ```python 51 | edited_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.8) 52 | ``` 53 | 54 | Sometimes, we may want to apply a non-linear task vector to a `LinearizedImageEncoder` (to obtain post-hoc linearized models, for example), or vice versa. Both `NonLinearTaskVector` and `LinearizedTaskVector` can be cast and applied to encoders from the complementary class as 55 | ```python 56 | linear_edited_encoder = nonlinear_task_vector.apply_to_linear(linear_pretrained_encoder, scaling_coef=0.8) 57 | ``` 58 | 59 | ### Linearized Models 60 | 61 | The module [src/linearize.py](src/linearize.py) provides tools to linearize any PyTorch `nn.Module`. 62 | 63 | To linearize any `model` object of the class `nn.Module`, one can simply do: 64 | ```python 65 | from src.linearize import LinearizedModel 66 | 67 | model = ... # An object of the class `nn.Module`. 68 | linear_model = LinearizedModel(model) # This object can be treated as any other `nn.Module`. 69 | ``` 70 | Specifically, for `ImageEncoder`s, the class `LinearizedImageEncoder` provides a simple way to linearize a CLIP image encoder while retaining the same API as the original object from the `ImageEncoder` class. We can therefore create a linearized CLIP model as: 71 | ```python 72 | from src.linearize import LinearizedImageEncoder 73 | from src.heads import get_classification_head 74 | from src.modeling import ImageClassifier 75 | 76 | args = ... # Arguments used to define an `ImageEncoder`. 77 | linear_encoder = LinearizedImageEncoder(args, keep_lang=False) # This object can be treated as any other `ImageEncoder`. 78 | 79 | classification_head = get_classification_head(args, train_dataset) # `train_dataset` is the name of the training dataset. 80 | 81 | linear_clip = ImageClassifier(linear_encoder, classification_head) 82 | ``` 
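For intuition, a linearized model computes the first-order Taylor expansion of the network around its initialization, `f_lin(x; θ) = f(x; θ₀) + ∇_θ f(x; θ₀)ᵀ (θ − θ₀)`, and only the parameters `θ` of the linear term are trained while `θ₀` stays frozen. The snippet below is a conceptual sketch of this computation, not the implementation used in this repository (see [src/linearize.py](src/linearize.py) for that): it is written against the `torch.func` API of recent PyTorch versions (the pinned environment, PyTorch 1.13, exposes the equivalent transforms through `functorch`), and the class name is hypothetical.

```python
# Conceptual sketch of a linearized model -- NOT the exact implementation in
# src/linearize.py. Assumes a recent PyTorch with torch.func; PyTorch 1.13
# (the pinned environment) provides the equivalent transforms via functorch.
import copy

import torch
from torch.func import functional_call, jvp


class TaylorLinearizedModel(torch.nn.Module):  # hypothetical name
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model  # holds the trainable parameters theta
        self.model0 = copy.deepcopy(model)  # frozen initialization theta0
        for p in self.model0.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        theta = dict(self.model.named_parameters())
        theta0 = dict(self.model0.named_parameters())
        dtheta = {name: theta[name] - theta0[name] for name in theta}
        # jvp returns f(x; theta0) and the Jacobian-vector product with
        # (theta - theta0): the two terms of the first-order Taylor expansion.
        f0, df = jvp(
            lambda params: functional_call(self.model0, params, (x,)),
            (theta0,),
            (dtheta,),
        )
        return f0 + df
```

Gradients reach `θ` only through the Jacobian-vector product, which is what constrains fine-tuning to the tangent space around `θ₀`.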
83 | ### Training 84 | 85 | The script `src/finetune.py` can be used to reproduce the training protocol we used to fine-tune our models on all our downstream tasks (both linearly and non-linearly). 86 | ```sh 87 | python src/finetune.py --finetuning-mode=standard --model=ViT-B-32 --world-size=2 # Fine-tune non-linearly on 2 GPUs 88 | python src/finetune.py --finetuning-mode=linear --model=ViT-B-32 --world-size=2 # Fine-tune linearly on 2 GPUs 89 | ``` 90 | 91 | ### Evaluation 92 | 93 | We provide several scripts to evaluate the task vectors obtained with the previous training protocol. 94 | 95 | #### Single-task accuracy 96 | Having run `src/finetune.py` for a given model, you can evaluate the performance of the fine-tuned weights on each single task by running 97 | ```sh 98 | # Evaluate pre-trained models. 99 | python src/eval_single_task.py --model=ViT-B-32 --finetuning-mode=none 100 | 101 | # Evaluate non-linearly fine-tuned models. 102 | python src/eval_single_task.py --model=ViT-B-32 --finetuning-mode=standard 103 | 104 | # Evaluate linearly fine-tuned models. 105 | python src/eval_single_task.py --model=ViT-B-32 --finetuning-mode=linear 106 | 107 | # Evaluate post-hoc linearized models. Requires having run finetune.py with --finetuning-mode=standard. 108 | python src/eval_single_task.py --model=ViT-B-32 --finetuning-mode=posthoc 109 | ``` 110 | 111 | #### Task addition 112 | Once the single-task evaluations are done, we can evaluate the task arithmetic performance of the different strategies on the addition benchmark. 113 | ```sh 114 | # Evaluate non-linearly fine-tuned models. 115 | python src/eval_task_addition.py --model=ViT-B-32 --finetuning-mode=standard 116 | 117 | # Evaluate linearly fine-tuned models. 118 | python src/eval_task_addition.py --model=ViT-B-32 --finetuning-mode=linear 119 | 120 | # Evaluate post-hoc linearized models. 121 | python src/eval_task_addition.py --model=ViT-B-32 --finetuning-mode=posthoc 122 | ``` 123 | 124 | #### Task negation 125 | Similarly, we can evaluate the task arithmetic performance of the different strategies on the negation benchmark. 126 | ```sh 127 | # Evaluate non-linearly fine-tuned models. 128 | python src/eval_task_negation.py --model=ViT-B-32 --finetuning-mode=standard 129 | 130 | # Evaluate linearly fine-tuned models. 131 | python src/eval_task_negation.py --model=ViT-B-32 --finetuning-mode=linear 132 | 133 | # Evaluate post-hoc linearized models. 134 | python src/eval_task_negation.py --model=ViT-B-32 --finetuning-mode=posthoc 135 | ``` 136 | 
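#### Putting it together

Beyond the evaluation scripts above, the pieces compose directly in Python. Below is a minimal sketch of the task-addition workflow; the checkpoint paths are illustrative placeholders (they depend on where `finetune.py` stored your checkpoints), and the scaling coefficient is arbitrary (the evaluation scripts search over `--n-eval-points` values to pick the best one).

```python
# Minimal sketch of task addition with the API from src/task_vectors.py.
# NOTE: the checkpoint paths are assumed placeholders; point them to wherever
# finetune.py saved your pre-trained and fine-tuned weights.
from src.task_vectors import NonLinearTaskVector

pretrained_checkpoint = "checkpoints/ViT-B-32/zeroshot.pt"  # assumed path
mnist_checkpoint = "checkpoints/ViT-B-32/MNIST/finetuned.pt"  # assumed path
eurosat_checkpoint = "checkpoints/ViT-B-32/EuroSAT/finetuned.pt"  # assumed path

# One task vector per task: the difference between the fine-tuned and the
# pre-trained weights.
mnist_vector = NonLinearTaskVector(pretrained_checkpoint, mnist_checkpoint)
eurosat_vector = NonLinearTaskVector(pretrained_checkpoint, eurosat_checkpoint)

# Task addition: sum the task vectors and apply them to the pre-trained
# encoder with a single scaling coefficient.
multi_task_vector = mnist_vector + eurosat_vector
edited_encoder = multi_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.3)
```

The same pattern works in the tangent space by substituting `LinearizedTaskVector` and the linearly fine-tuned checkpoints.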
137 | ## Datasets 138 | To download and prepare the datasets, please follow the instructions in [this issue](https://github.com/mlfoundations/task_vectors/issues/1). 139 | 140 | ## Reference 141 | If you find this code useful, please cite the following paper: 142 | ```bibtex 143 | @article{ortizjimenez2023tangent, 144 | title = {Task Arithmetic in the Tangent Space: Improved Editing of Pre-Trained 145 | Models}, 146 | author = {Guillermo Ortiz{-}Jim{\'{e}}nez and 147 | Alessandro Favero and 148 | Pascal Frossard}, 149 | journal = {arXiv:2305.12827}, 150 | year = {2023}, 151 | note = {\url{https://arxiv.org/abs/2305.12827}}, 152 | } 153 | ``` 154 | 155 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tangent-arithmetic 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=5.1 9 | - blas=1.0 10 | - brotlipy=0.7.0 11 | - bzip2=1.0.8 12 | - ca-certificates=2023.05.30 13 | - certifi=2023.5.7 14 | - cffi=1.15.1 15 | - charset-normalizer=2.0.4 16 | - cryptography=39.0.1 17 | - cuda=11.6.1 18 | - cuda-cccl=11.6.55 19 | - cuda-command-line-tools=11.6.2 20 | - cuda-compiler=11.6.2 21 | - cuda-cudart=11.6.55 22 | - cuda-cudart-dev=11.6.55 23 | - cuda-cuobjdump=11.6.124 24 | - cuda-cupti=11.6.124 25 | - cuda-cuxxfilt=11.6.124 26 | - cuda-driver-dev=11.6.55 27 | - cuda-gdb=12.1.105 28 | - cuda-libraries=11.6.1 29 | - cuda-libraries-dev=11.6.1 30 | - cuda-memcheck=11.8.86 31 | - cuda-nsight=12.1.105 32 | - cuda-nsight-compute=12.1.1 33 | - cuda-nvcc=11.6.124 34 | - cuda-nvdisasm=12.1.105 35 | - cuda-nvml-dev=11.6.55 36 | - cuda-nvprof=12.1.105 37 | - cuda-nvprune=11.6.124 38 | - cuda-nvrtc=11.6.124 39 | - cuda-nvrtc-dev=11.6.124 40 | - cuda-nvtx=11.6.124 41 | - cuda-nvvp=12.1.105 42 | - cuda-runtime=11.6.1 43 | - cuda-samples=11.6.101 44 | - cuda-sanitizer-api=12.1.105 45 | - cuda-toolkit=11.6.1 46 | - cuda-tools=11.6.1 47 | - cuda-visual-tools=11.6.1 48 | - ffmpeg=4.3 49 | - freetype=2.12.1 50 | - gds-tools=1.6.1.9 51 | - giflib=5.2.1 52 | - gmp=6.2.1 53 | - gnutls=3.6.15 54 | - idna=3.4 55 | - intel-openmp=2023.1.0 56 | - jpeg=9e 57 | - lame=3.100 58 | - lcms2=2.12 59 | - ld_impl_linux-64=2.38 60 | - lerc=3.0 61 | - libcublas=11.9.2.110 62 | - libcublas-dev=11.9.2.110 63 | - libcufft=10.7.1.112 64 | - libcufft-dev=10.7.1.112 65 | - libcufile=1.6.1.9 66 | - libcufile-dev=1.6.1.9 67 | - libcurand=10.3.2.106 68 | - libcurand-dev=10.3.2.106 69 | - libcusolver=11.3.4.124 70 | - libcusparse=11.7.2.124 71 | - libcusparse-dev=11.7.2.124 72 | - libdeflate=1.17 73 | - libffi=3.4.4 74 | - libgcc-ng=11.2.0 75 | - libgomp=11.2.0 76 | - libiconv=1.16 77 | - libidn2=2.3.4 78 | - libnpp=11.6.3.124 79 | - libnpp-dev=11.6.3.124 80 | - libnvjpeg=11.6.2.124 81 | - libnvjpeg-dev=11.6.2.124 82 | - libpng=1.6.39 83 | - libstdcxx-ng=11.2.0 84 | - libtasn1=4.19.0 85 | - libtiff=4.5.0 86 | - libunistring=0.9.10 87 | - libuuid=1.41.5 88 | - libwebp=1.2.4 89 | - libwebp-base=1.2.4 90 | - lz4-c=1.9.4 91 | - mkl=2023.1.0 92 | - mkl-service=2.4.0 93 | - mkl_fft=1.3.6 94 | - mkl_random=1.2.2 95 | - ncurses=6.4 96 | - nettle=3.7.3 97 | - nsight-compute=2023.1.1.4 98 | - numpy=1.24.3 99 | - numpy-base=1.24.3 100 | - openh264=2.1.1 101 | - openssl=1.1.1t 102 | - pillow=9.4.0 103 | - pip=23.0.1 104 | - pycparser=2.21 105 | - pyopenssl=23.0.0 106 | - pysocks=1.7.1 107 | - python=3.10.11 108 | - pytorch=1.13.1 109 | - pytorch-cuda=11.6 110 | - pytorch-mutex=1.0 111 | - readline=8.2 112 | - requests=2.29.0 113 | - setuptools=67.8.0 114 | - sqlite=3.41.2 115 | - tbb=2021.8.0 116 | - tk=8.6.12 117 | - 
torchaudio=0.13.1 118 | - torchvision=0.14.1 119 | - typing_extensions=4.5.0 120 | - tzdata=2023c 121 | - urllib3=1.26.16 122 | - wheel=0.38.4 123 | - xz=5.4.2 124 | - zlib=1.2.13 125 | - zstd=1.5.5 126 | - pip: 127 | - filelock==3.12.0 128 | - fsspec==2023.5.0 129 | - ftfy==6.1.1 130 | - huggingface-hub==0.15.1 131 | - open-clip-torch==2.10.1 132 | - packaging==23.1 133 | - protobuf==3.20.3 134 | - pyyaml==6.0 135 | - regex==2023.6.3 136 | - safetensors==0.3.1 137 | - scipy==1.10.1 138 | - sentencepiece==0.1.99 139 | - timm==0.9.2 140 | - wcwidth==0.2.6 141 | -------------------------------------------------------------------------------- /figures/tangent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gortizji/tangent_task_arithmetic/0aa0e1869d8b4b111fa92f2217b8ae863c084fc6/figures/tangent.jpg -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 100 2 | select = ["B","E","F","W","A","I"] 3 | ignore = ["B905"] 4 | fix = true 5 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--data-location", 11 | type=str, 12 | default=os.path.expanduser("~/data"), 13 | help="The root directory for the datasets.", 14 | ) 15 | parser.add_argument( 16 | "--eval-datasets", 17 | default=None, 18 | type=lambda x: x.split(","), 19 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT.", 20 | ) 21 | parser.add_argument( 22 | "--train-dataset", 23 | default=None, 24 | type=lambda x: x.split(","), 25 | help="Which dataset(s) to patch on.", 26 | ) 27 | parser.add_argument( 28 | "--exp_name", 29 | type=str, 30 | default=None, 31 | help="Name of the experiment, for organization purposes only.", 32 | ) 33 | parser.add_argument( 34 | "--results-db", 35 | type=str, 36 | default=None, 37 | help="Where to store the results; if not provided, results are not stored.", 38 | ) 39 | parser.add_argument( 40 | "--model", 41 | type=str, 42 | default="ViT-B-32", 43 | help="The type of model (e.g. RN50, ViT-B-32).", 44 | ) 45 | parser.add_argument( 46 | "--batch-size", 47 | type=int, 48 | default=128, 49 | ) 50 | parser.add_argument( 51 | "--num-grad-accumulation", 52 | type=int, 53 | default=1, 54 | help="Number of gradient accumulation steps.", 55 | ) 56 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.") 57 | parser.add_argument("--wd", type=float, default=0.1, help="Weight decay.") 58 | parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.") 59 | parser.add_argument( 60 | "--warmup_length", 61 | type=int, 62 | default=500, 63 | ) 64 | parser.add_argument( 65 | "--epochs", 66 | type=int, 67 | default=10, 68 | ) 69 | parser.add_argument( 70 | "--load", 71 | type=lambda x: x.split(","), 72 | default=None, 73 | help="Optionally load _classifiers_, e.g. a zero-shot classifier, a probe, or an ensemble of both.", # noqa: E501 74 | ) 75 | parser.add_argument( 76 | "--save", 77 | type=str, 78 | default=None, 79 | help="Optionally save a _classifier_, e.g. 
a zero shot classifier or probe.", 80 | ) 81 | parser.add_argument( 82 | "--cache-dir", 83 | type=str, 84 | default=None, 85 | help="Directory for caching features and encoder", 86 | ) 87 | parser.add_argument( 88 | "--openclip-cachedir", 89 | type=str, 90 | default=os.path.expanduser("~/openclip-cachedir/open_clip"), 91 | help="Directory for caching models from OpenCLIP", 92 | ) 93 | parser.add_argument( 94 | "--world-size", 95 | type=int, 96 | default=1, 97 | help="Number of processes for distributed training.", 98 | ) 99 | parser.add_argument( 100 | "--checkpoint-every", 101 | type=int, 102 | default=-1, 103 | help="How often to checkpoint the model.", 104 | ) 105 | parser.add_argument( 106 | "--port", 107 | type=int, 108 | default=12355, 109 | help="Port for distributed training.", 110 | ) 111 | parser.add_argument( 112 | "--seed", 113 | type=int, 114 | default=None, 115 | help="Random seed.", 116 | ) 117 | parser.add_argument( 118 | "--finetuning-mode", 119 | choices=["standard", "linear", "posthoc", "none"], 120 | help="Whether to use linearized models or not.", 121 | ) 122 | parser.add_argument( 123 | "--n-eval-points", 124 | type=int, 125 | default=21, 126 | help="Number of evaluation points used to find optimal coefficient in task arithmetic.", 127 | ) 128 | parsed_args = parser.parse_args() 129 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" 130 | 131 | if parsed_args.load is not None and len(parsed_args.load) == 1: 132 | parsed_args.load = parsed_args.load[0] 133 | return parsed_args 134 | -------------------------------------------------------------------------------- /src/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | import pathlib 7 | from typing import Callable, Optional, Any, Tuple 8 | 9 | from PIL import Image 10 | 11 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class PytorchStanfordCars(VisionDataset): 16 | """`Stanford Cars `_ Dataset 17 | 18 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 19 | split into 8,144 training images and 8,041 testing images, where each class 20 | has been split roughly in a 50-50 split 21 | 22 | .. note:: 23 | 24 | This class needs `scipy `_ to load target files from `.mat` format. 25 | 26 | Args: 27 | root (string): Root directory of dataset 28 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 29 | transform (callable, optional): A function/transform that takes in an PIL image 30 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 31 | target_transform (callable, optional): A function/transform that takes in the 32 | target and transforms it. 33 | download (bool, optional): If True, downloads the dataset from the internet and 34 | puts it in root directory. If dataset is already downloaded, it is not 35 | downloaded again.""" 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | split: str = "train", 41 | transform: Optional[Callable] = None, 42 | target_transform: Optional[Callable] = None, 43 | download: bool = False, 44 | ) -> None: 45 | 46 | try: 47 | import scipy.io as sio 48 | except ImportError: 49 | raise RuntimeError("Scipy is not found. 
This dataset needs to have scipy installed: pip install scipy") 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path(root) / "stanford_cars" 55 | devkit = self._base_folder / "devkit" 56 | 57 | if self._split == "train": 58 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 59 | self._images_base_path = self._base_folder / "cars_train" 60 | else: 61 | self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" 62 | self._images_base_path = self._base_folder / "cars_test" 63 | 64 | if download: 65 | self.download() 66 | 67 | if not self._check_exists(): 68 | raise RuntimeError("Dataset not found. You can use download=True to download it") 69 | 70 | self._samples = [ 71 | ( 72 | str(self._images_base_path / annotation["fname"]), 73 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 74 | ) 75 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 76 | ] 77 | 78 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 79 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 80 | 81 | def __len__(self) -> int: 82 | return len(self._samples) 83 | 84 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 85 | """Returns pil_image and class_id for given index""" 86 | image_path, target = self._samples[idx] 87 | pil_image = Image.open(image_path).convert("RGB") 88 | 89 | if self.transform is not None: 90 | pil_image = self.transform(pil_image) 91 | if self.target_transform is not None: 92 | target = self.target_transform(target) 93 | return pil_image, target 94 | 95 | 96 | def download(self) -> None: 97 | if self._check_exists(): 98 | return 99 | 100 | download_and_extract_archive( 101 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 102 | download_root=str(self._base_folder), 103 | md5="c3b158d763b6e2245038c8ad08e45376", 104 | ) 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 108 | download_root=str(self._base_folder), 109 | md5="065e5b463ae28d29e77c1b4b166cfe61", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 114 | download_root=str(self._base_folder), 115 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 116 | ) 117 | download_url( 118 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 119 | root=str(self._base_folder), 120 | md5="b0a2b23655a3edd16d84508592a98d10", 121 | ) 122 | 123 | def _check_exists(self) -> bool: 124 | if not (self._base_folder / "devkit").is_dir(): 125 | return False 126 | 127 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 128 | 129 | 130 | class Cars: 131 | def __init__(self, 132 | preprocess, 133 | location=os.path.expanduser('~/data'), 134 | batch_size=32, 135 | num_workers=16): 136 | # Data loading code 137 | 138 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=True) 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | shuffle=True, 142 | batch_size=batch_size, 143 | num_workers=num_workers, 144 | ) 145 | 146 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=True) 147 | self.test_loader = torch.utils.data.DataLoader( 148 | self.test_dataset, 149 | 
batch_size=batch_size, 150 | num_workers=num_workers 151 | ) 152 | idx_to_class = dict((v, k) 153 | for k, v in self.train_dataset.class_to_idx.items()) 154 | self.classnames = [idx_to_class[i].replace( 155 | '_', ' ') for i in range(len(idx_to_class))] 156 | -------------------------------------------------------------------------------- /src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 8 | from torchvision.datasets import VisionDataset 9 | 10 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 11 | 12 | class CIFAR10: 13 | def __init__(self, preprocess, 14 | location=os.path.expanduser('~/data'), 15 | batch_size=128, 16 | num_workers=16): 17 | 18 | 19 | self.train_dataset = PyTorchCIFAR10( 20 | root=location, download=True, train=True, transform=preprocess 21 | ) 22 | 23 | self.train_loader = torch.utils.data.DataLoader( 24 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = PyTorchCIFAR10( 28 | root=location, download=True, train=False, transform=preprocess 29 | ) 30 | 31 | self.test_loader = torch.utils.data.DataLoader( 32 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 33 | ) 34 | 35 | self.classnames = self.test_dataset.classes 36 | 37 | def convert(x): 38 | if isinstance(x, np.ndarray): 39 | return torchvision.transforms.functional.to_pil_image(x) 40 | return x 41 | 42 | class BasicVisionDataset(VisionDataset): 43 | def __init__(self, images, targets, transform=None, target_transform=None): 44 | if transform is not None: 45 | transform.transforms.insert(0, convert) 46 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) 47 | assert len(images) == len(targets) 48 | 49 | self.images = images 50 | self.targets = targets 51 | 52 | def __getitem__(self, index): 53 | return self.transform(self.images[index]), self.targets[index] 54 | 55 | def __len__(self): 56 | return len(self.targets) 57 | -------------------------------------------------------------------------------- /src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 4 | 5 | class CIFAR100: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | self.train_dataset = PyTorchCIFAR100( 13 | root=location, download=True, train=True, transform=preprocess 14 | ) 15 | 16 | self.train_loader = torch.utils.data.DataLoader( 17 | self.train_dataset, batch_size=batch_size, num_workers=num_workers 18 | ) 19 | 20 | self.test_dataset = PyTorchCIFAR100( 21 | root=location, download=True, train=False, transform=preprocess 22 | ) 23 | 24 | self.test_loader = torch.utils.data.DataLoader( 25 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 26 | ) 27 | 28 | self.classnames = self.test_dataset.classes 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 
import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform, flip_label_prob=0.0): 28 | super().__init__(path, transform) 29 | self.flip_label_prob = flip_label_prob 30 | if self.flip_label_prob > 0: 31 | print(f'Flipping labels with probability {self.flip_label_prob}') 32 | num_classes = len(self.classes) 33 | for i in range(len(self.samples)): 34 | if random.random() < self.flip_label_prob: 35 | new_label = random.randint(0, num_classes-1) 36 | self.samples[i] = ( 37 | self.samples[i][0], 38 | new_label 39 | ) 40 | 41 | def __getitem__(self, index): 42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 43 | return { 44 | 'images': image, 45 | 'labels': label, 46 | 'image_paths': self.samples[index][0] 47 | } 48 | 49 | 50 | def maybe_dictionarize(batch): 51 | if isinstance(batch, dict): 52 | return batch 53 | 54 | if len(batch) == 2: 55 | batch = {'images': batch[0], 'labels': batch[1]} 56 | elif len(batch) == 3: 57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]} 58 | else: 59 | raise ValueError(f'Unexpected number of elements: {len(batch)}') 60 | 61 | return batch 62 | 63 | 64 | def get_features_helper(image_encoder, dataloader, device): 65 | all_data = collections.defaultdict(list) 66 | 67 | image_encoder = image_encoder.to(device) 68 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())]) 69 | image_encoder.eval() 70 | 71 | with torch.no_grad(): 72 | for batch in tqdm(dataloader): 73 | batch = maybe_dictionarize(batch) 74 | features = image_encoder(batch['images'].cuda()) 75 | 76 | all_data['features'].append(features.cpu()) 77 | 78 | for key, val in batch.items(): 79 | if key == 'images': 80 | continue 81 | if hasattr(val, 'cpu'): 82 | val = val.cpu() 83 | all_data[key].append(val) 84 | else: 85 | all_data[key].extend(val) 86 | 87 | for key, val in all_data.items(): 88 | if torch.is_tensor(val[0]): 89 | all_data[key] = torch.cat(val).numpy() 90 | 91 | return all_data 92 | 93 | 94 | def get_features(is_train, image_encoder, dataset, device): 95 | split = 'train' if is_train else 'val' 96 | dname = type(dataset).__name__ 97 | if image_encoder.cache_dir is not None: 98 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' 99 | cached_files = glob.glob(f'{cache_dir}/*') 100 | if image_encoder.cache_dir is not None and len(cached_files) > 0: 101 | print(f'Getting features from {cache_dir}') 102 | data = {} 103 | for cached_file in cached_files: 104 | name = os.path.splitext(os.path.basename(cached_file))[0] 105 | data[name] = torch.load(cached_file) 106 | else: 107 | print(f'Did not find cached features at {cache_dir}. 
Building from scratch.') 108 | loader = dataset.train_loader if is_train else dataset.test_loader 109 | data = get_features_helper(image_encoder, loader, device) 110 | if image_encoder.cache_dir is None: 111 | print('Not caching because no cache directory was passed.') 112 | else: 113 | os.makedirs(cache_dir, exist_ok=True) 114 | print(f'Caching data at {cache_dir}') 115 | for name, val in data.items(): 116 | torch.save(val, f'{cache_dir}/{name}.pt') 117 | return data 118 | 119 | 120 | class FeatureDataset(Dataset): 121 | def __init__(self, is_train, image_encoder, dataset, device): 122 | self.data = get_features(is_train, image_encoder, dataset, device) 123 | 124 | def __len__(self): 125 | return len(self.data['features']) 126 | 127 | def __getitem__(self, idx): 128 | data = {k: v[idx] for k, v in self.data.items()} 129 | data['features'] = torch.from_numpy(data['features']).float() 130 | return data 131 | 132 | 133 | def get_dataloader(dataset, is_train, args, image_encoder=None): 134 | if image_encoder is not None: 135 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) 136 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 137 | else: 138 | dataloader = dataset.train_loader if is_train else dataset.test_loader 139 | return dataloader -------------------------------------------------------------------------------- /src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class DTD: 7 | def __init__(self, 8 | preprocess, 9 | location=os.path.expanduser('~/data'), 10 | batch_size=32, 11 | num_workers=16): 12 | # Data loading code 13 | traindir = os.path.join(location, 'dtd', 'train') 14 | valdir = os.path.join(location, 'dtd', 'val') 15 | 16 | self.train_dataset = datasets.ImageFolder( 17 | traindir, transform=preprocess) 18 | self.train_loader = torch.utils.data.DataLoader( 19 | self.train_dataset, 20 | shuffle=True, 21 | batch_size=batch_size, 22 | num_workers=num_workers, 23 | ) 24 | 25 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 26 | self.test_loader = torch.utils.data.DataLoader( 27 | self.test_dataset, 28 | batch_size=batch_size, 29 | num_workers=num_workers 30 | ) 31 | idx_to_class = dict((v, k) 32 | for k, v in self.train_dataset.class_to_idx.items()) 33 | self.classnames = [idx_to_class[i].replace( 34 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | 6 | def pretify_classname(classname): 7 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) 8 | l = [i.lower() for i in l] 9 | out = ' '.join(l) 10 | if out.endswith('al'): 11 | return out + ' area' 12 | return out 13 | 14 | class EuroSATBase: 15 | def __init__(self, 16 | preprocess, 17 | test_split, 18 | location='~/datasets', 19 | batch_size=32, 20 | num_workers=16): 21 | # Data loading code 22 | traindir = os.path.join(location, 'EuroSAT_splits', 'train') 23 | testdir = os.path.join(location, 'EuroSAT_splits', test_split) 24 | 25 | 26 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 27 | self.train_loader = torch.utils.data.DataLoader( 28 | self.train_dataset, 29 | shuffle=True, 30 | 
batch_size=batch_size, 31 | num_workers=num_workers, 32 | ) 33 | 34 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) 35 | self.test_loader = torch.utils.data.DataLoader( 36 | self.test_dataset, 37 | batch_size=batch_size, 38 | num_workers=num_workers 39 | ) 40 | idx_to_class = dict((v, k) 41 | for k, v in self.train_dataset.class_to_idx.items()) 42 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 43 | self.classnames = [pretify_classname(c) for c in self.classnames] 44 | ours_to_open_ai = { 45 | 'annual crop': 'annual crop land', 46 | 'forest': 'forest', 47 | 'herbaceous vegetation': 'brushland or shrubland', 48 | 'highway': 'highway or road', 49 | 'industrial area': 'industrial buildings or commercial buildings', 50 | 'pasture': 'pasture land', 51 | 'permanent crop': 'permanent crop land', 52 | 'residential area': 'residential buildings or homes or apartments', 53 | 'river': 'river', 54 | 'sea lake': 'lake or sea', 55 | } 56 | for i in range(len(self.classnames)): 57 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 58 | 59 | 60 | class EuroSAT(EuroSATBase): 61 | def __init__(self, 62 | preprocess, 63 | location='~/datasets', 64 | batch_size=32, 65 | num_workers=16): 66 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 67 | 68 | 69 | class EuroSATVal(EuroSATBase): 70 | def __init__(self, 71 | preprocess, 72 | location='~/datasets', 73 | batch_size=32, 74 | num_workers=16): 75 | super().__init__(preprocess, 'val', location, batch_size, num_workers) 76 | -------------------------------------------------------------------------------- /src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import (download_and_extract_archive, 11 | verify_str_arg) 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 15 | """Finds the class folders in a dataset. 16 | 17 | See :class:`DatasetFolder` for details. 18 | """ 19 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 20 | if not classes: 21 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 22 | 23 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 24 | return classes, class_to_idx 25 | 26 | class PyTorchGTSRB(VisionDataset): 27 | """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset. 28 | 29 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 30 | 31 | Args: 32 | root (string): Root directory of the dataset. 33 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 34 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 35 | version. E.g., ``transforms.RandomCrop``. 36 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 37 | download (bool, optional): If True, downloads the dataset from the internet and 38 | puts it in root directory. If dataset is already downloaded, it is not 39 | downloaded again. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | split: str = "train", 46 | transform: Optional[Callable] = None, 47 | target_transform: Optional[Callable] = None, 48 | download: bool = False, 49 | ) -> None: 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path(root) / "gtsrb" 55 | self._target_folder = ( 56 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 57 | ) 58 | 59 | if download: 60 | self.download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. You can use download=True to download it") 64 | 65 | if self._split == "train": 66 | _, class_to_idx = find_classes(str(self._target_folder)) 67 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) 68 | else: 69 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 70 | samples = [ 71 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 72 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 73 | ] 74 | 75 | self._samples = samples 76 | self.transform = transform 77 | self.target_transform = target_transform 78 | 79 | def __len__(self) -> int: 80 | return len(self._samples) 81 | 82 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 83 | 84 | path, target = self._samples[index] 85 | sample = PIL.Image.open(path).convert("RGB") 86 | 87 | if self.transform is not None: 88 | sample = self.transform(sample) 89 | 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | return sample, target 94 | 95 | 96 | def _check_exists(self) -> bool: 97 | return self._target_folder.is_dir() 98 | 99 | def download(self) -> None: 100 | if self._check_exists(): 101 | return 102 | 103 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 104 | 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | f"{base_url}GTSRB-Training_fixed.zip", 108 | download_root=str(self._base_folder), 109 | md5="513f3c79a4c5141765e10e952eaa2478", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | f"{base_url}GTSRB_Final_Test_Images.zip", 114 | download_root=str(self._base_folder), 115 | md5="c7e4e6327067d32654124b0fe9e82185", 116 | ) 117 | download_and_extract_archive( 118 | f"{base_url}GTSRB_Final_Test_GT.zip", 119 | download_root=str(self._base_folder), 120 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 121 | ) 122 | 123 | 124 | class GTSRB: 125 | def __init__(self, 126 | preprocess, 127 | location=os.path.expanduser('~/data'), 128 | batch_size=128, 129 | num_workers=16): 130 | 131 | # to fit with repo conventions for location 132 | self.train_dataset = PyTorchGTSRB( 133 | root=location, 134 | download=True, 135 | split='train', 136 | transform=preprocess 137 | ) 138 | 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | batch_size=batch_size, 142 | shuffle=True, 143 | num_workers=num_workers 144 | ) 145 | 146 | self.test_dataset = PyTorchGTSRB( 147 | root=location, 148 | download=True, 149 | split='test', 150 | transform=preprocess 151 | ) 152 | 153 | self.test_loader = torch.utils.data.DataLoader( 154 | self.test_dataset, 155 | batch_size=batch_size, 156 | shuffle=False, 157 | num_workers=num_workers 158 | ) 159 | 160 | # from 
https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 161 | self.classnames = [ 162 | 'red and white circle 20 kph speed limit', 163 | 'red and white circle 30 kph speed limit', 164 | 'red and white circle 50 kph speed limit', 165 | 'red and white circle 60 kph speed limit', 166 | 'red and white circle 70 kph speed limit', 167 | 'red and white circle 80 kph speed limit', 168 | 'end / de-restriction of 80 kph speed limit', 169 | 'red and white circle 100 kph speed limit', 170 | 'red and white circle 120 kph speed limit', 171 | 'red and white circle red car and black car no passing', 172 | 'red and white circle red truck and black car no passing', 173 | 'red and white triangle road intersection warning', 174 | 'white and yellow diamond priority road', 175 | 'red and white upside down triangle yield right-of-way', 176 | 'stop', 177 | 'empty red and white circle', 178 | 'red and white circle no truck entry', 179 | 'red circle with white horizonal stripe no entry', 180 | 'red and white triangle with exclamation mark warning', 181 | 'red and white triangle with black left curve approaching warning', 182 | 'red and white triangle with black right curve approaching warning', 183 | 'red and white triangle with black double curve approaching warning', 184 | 'red and white triangle rough / bumpy road warning', 185 | 'red and white triangle car skidding / slipping warning', 186 | 'red and white triangle with merging / narrow lanes warning', 187 | 'red and white triangle with person digging / construction / road work warning', 188 | 'red and white triangle with traffic light approaching warning', 189 | 'red and white triangle with person walking warning', 190 | 'red and white triangle with child and person walking warning', 191 | 'red and white triangle with bicyle warning', 192 | 'red and white triangle with snowflake / ice warning', 193 | 'red and white triangle with deer warning', 194 | 'white circle with gray strike bar no speed limit', 195 | 'blue circle with white right turn arrow mandatory', 196 | 'blue circle with white left turn arrow mandatory', 197 | 'blue circle with white forward arrow mandatory', 198 | 'blue circle with white forward or right turn arrow mandatory', 199 | 'blue circle with white forward or left turn arrow mandatory', 200 | 'blue circle with white keep right arrow mandatory', 201 | 'blue circle with white keep left arrow mandatory', 202 | 'blue circle with white arrows indicating a traffic circle', 203 | 'white circle with gray strike bar indicating no passing for cars has ended', 204 | 'white circle with gray strike bar indicating no passing for trucks has ended', 205 | ] 206 | -------------------------------------------------------------------------------- /src/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .common import ImageFolderWithPaths, SubsetSampler 5 | import numpy as np 6 | 7 | 8 | imagenet_classnames = [ 9 | "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 10 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 11 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 12 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 13 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 14 | "tailed frog", "loggerhead sea turtle", 
"leatherback sea turtle", "mud turtle", "terrapin", 15 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 16 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 17 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 18 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 19 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 20 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 21 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 22 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 23 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 24 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 25 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 26 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 27 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 28 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 29 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 30 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 31 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 32 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 33 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 34 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 35 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 36 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 37 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 38 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 39 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 40 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 41 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 42 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 43 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 44 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 45 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 46 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 47 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 48 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 49 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 50 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 51 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 52 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland 
Sheepdog", "collie", 53 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 54 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 55 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 56 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 57 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 58 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 59 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 60 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 61 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 62 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 63 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 64 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 65 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 66 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 67 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 68 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 69 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 70 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 71 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 72 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 73 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 74 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 75 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 76 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 77 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 78 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 79 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 80 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 81 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 82 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 83 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 84 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 85 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 86 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 87 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 88 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 89 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 90 | "bookcase", "bookstore", "bottle 
cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 91 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 92 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 93 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 94 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 95 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 96 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 97 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 98 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 99 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 100 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 101 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 102 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 103 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 104 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 105 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 106 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 107 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 108 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 109 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 110 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 111 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 112 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 113 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 114 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 115 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 116 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 117 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 118 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 119 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 120 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 121 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 122 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 123 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 124 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 125 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 126 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 127 | "oxygen mask", "product 
packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 128 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 129 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 130 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 131 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 132 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 133 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 134 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 135 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 136 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 137 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 138 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 139 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 140 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 141 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 142 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 143 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 144 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 145 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 146 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 147 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 148 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 149 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 150 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 151 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 152 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 153 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 154 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 155 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 156 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 157 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 158 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 159 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 160 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 161 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 162 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 163 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 164 | 
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 165 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 166 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 167 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 168 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 169 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 170 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 171 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 172 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 173 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" 174 | ] 175 | 176 | class ImageNet: 177 | def __init__(self, 178 | preprocess, 179 | location=os.path.expanduser('~/data'), 180 | batch_size=32, 181 | num_workers=32): 182 | self.preprocess = preprocess 183 | self.location = location 184 | self.batch_size = batch_size 185 | self.num_workers = num_workers 186 | self.classnames = imagenet_classnames 187 | 188 | self.populate_train() 189 | self.populate_test() 190 | 191 | def populate_train(self): 192 | traindir = os.path.join(self.location, self.name(), 'train') 193 | self.train_dataset = ImageFolderWithPaths( 194 | traindir, 195 | transform=self.preprocess) 196 | sampler = self.get_train_sampler() 197 | kwargs = {'shuffle' : True} if sampler is None else {} 198 | self.train_loader = torch.utils.data.DataLoader( 199 | self.train_dataset, 200 | sampler=sampler, 201 | batch_size=self.batch_size, 202 | num_workers=self.num_workers, 203 | **kwargs, 204 | ) 205 | 206 | def populate_test(self): 207 | self.test_dataset = self.get_test_dataset() 208 | self.test_loader = torch.utils.data.DataLoader( 209 | self.test_dataset, 210 | batch_size=self.batch_size, 211 | num_workers=self.num_workers, 212 | sampler=self.get_test_sampler() 213 | ) 214 | 215 | def get_test_path(self): 216 | test_path = os.path.join(self.location, self.name(), 'val_in_folder') 217 | if not os.path.exists(test_path): 218 | test_path = os.path.join(self.location, self.name(), 'val') 219 | return test_path 220 | 221 | def get_train_sampler(self): 222 | return None 223 | 224 | def get_test_sampler(self): 225 | return None 226 | 227 | def get_test_dataset(self): 228 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 229 | 230 | def name(self): 231 | return 'imagenet' 232 | 233 | class ImageNetTrain(ImageNet): 234 | 235 | def get_test_dataset(self): 236 | pass 237 | 238 | class ImageNetK(ImageNet): 239 | 240 | def get_train_sampler(self): 241 | idxs = np.zeros(len(self.train_dataset.targets)) 242 | target_array = np.array(self.train_dataset.targets) 243 | for c in range(1000): 244 | m = target_array == c 245 | n = len(idxs[m]) 246 | arr = np.zeros(n) 247 | arr[:self.k()] = 1 248 | np.random.shuffle(arr) 249 | idxs[m] = arr 250 | 251 | idxs = idxs.astype('int') 252 | sampler = SubsetSampler(np.where(idxs)[0]) 253 | return sampler -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 
class MNIST: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | 13 | self.train_dataset = datasets.MNIST( 14 | root=location, 15 | download=True, 16 | train=True, 17 | transform=preprocess 18 | ) 19 | 20 | self.train_loader = torch.utils.data.DataLoader( 21 | self.train_dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = datasets.MNIST( 28 | root=location, 29 | download=True, 30 | train=False, 31 | transform=preprocess 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers 39 | ) 40 | 41 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -------------------------------------------------------------------------------- /src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import random 4 | import torch 5 | import copy 6 | 7 | from torch.utils.data.dataset import random_split 8 | 9 | from src.datasets.cars import Cars 10 | from src.datasets.cifar10 import CIFAR10 11 | from src.datasets.cifar100 import CIFAR100 12 | from src.datasets.dtd import DTD 13 | from src.datasets.eurosat import EuroSAT, EuroSATVal 14 | from src.datasets.gtsrb import GTSRB 15 | from src.datasets.imagenet import ImageNet 16 | from src.datasets.mnist import MNIST 17 | from src.datasets.resisc45 import RESISC45 18 | from src.datasets.stl10 import STL10 19 | from src.datasets.svhn import SVHN 20 | from src.datasets.sun397 import SUN397 21 | 22 | registry = { 23 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) 24 | } 25 | 26 | 27 | class GenericDataset(object): 28 | def __init__(self): 29 | self.train_dataset = None 30 | self.train_loader = None 31 | self.test_dataset = None 32 | self.test_loader = None 33 | self.classnames = None 34 | 35 | 36 | def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0): 37 | assert val_fraction > 0. and val_fraction < 1. 
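# Worked example of the split arithmetic below (assuming get_dataset's defaults: val_fraction=0.1, max_val_samples=5000):
# MNIST has 60,000 training images, so val_size = min(6000, 5000) = 5000 and train_size = 55,000.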
38 | total_size = len(dataset.train_dataset) 39 | val_size = int(total_size * val_fraction) 40 | if max_val_samples is not None: 41 | val_size = min(val_size, max_val_samples) 42 | train_size = total_size - val_size 43 | 44 | assert val_size > 0 45 | assert train_size > 0 46 | 47 | lengths = [train_size, val_size] 48 | 49 | trainset, valset = random_split( 50 | dataset.train_dataset, 51 | lengths, 52 | generator=torch.Generator().manual_seed(seed) 53 | ) 54 | if new_dataset_class_name == 'MNISTVal': 55 | assert trainset.indices[0] == 36044 # sanity check: the seed-0 MNIST split is deterministic 56 | 57 | 58 | # Dynamically create a GenericDataset subclass named after the requested split. 59 | 60 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) 61 | new_dataset = new_dataset_class() 62 | 63 | new_dataset.train_dataset = trainset 64 | new_dataset.train_loader = torch.utils.data.DataLoader( 65 | new_dataset.train_dataset, 66 | shuffle=True, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | ) 70 | 71 | new_dataset.test_dataset = valset 72 | new_dataset.test_loader = torch.utils.data.DataLoader( 73 | new_dataset.test_dataset, 74 | batch_size=batch_size, 75 | num_workers=num_workers 76 | ) 77 | 78 | new_dataset.classnames = copy.copy(dataset.classnames) 79 | 80 | return new_dataset 81 | 82 | 83 | def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000): 84 | if dataset_name.endswith('Val'): 85 | # Handle val splits 86 | if dataset_name in registry: 87 | dataset_class = registry[dataset_name] 88 | else: 89 | base_dataset_name = dataset_name.split('Val')[0] 90 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) 91 | dataset = split_train_into_train_val( 92 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples) 93 | return dataset 94 | else: 95 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}' 96 | dataset_class = registry[dataset_name] 97 | dataset = dataset_class( 98 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 99 | ) 100 | return dataset 101 | -------------------------------------------------------------------------------- /src/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import abc 5 | import os 6 | from typing import Any, Callable, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | from torchvision.datasets import ImageFolder 13 | from torchvision.datasets.folder import default_loader as pil_loader 14 | 15 | 16 | # modified from: https://github.com/microsoft/torchgeo 17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): 18 | """Abstract base class for datasets lacking geospatial information. 19 | This base class is designed for datasets with pre-defined image chips. 20 | """ 21 | 22 | @abc.abstractmethod 23 | def __getitem__(self, index: int) -> Dict[str, Any]: 24 | """Return an item from the dataset. 25 | Args: 26 | index: index to return 27 | Returns: 28 | data and labels at that index 29 | Raises: 30 | IndexError: if index is out of range of the dataset 31 | """ 32 | 33 | @abc.abstractmethod 34 | def __len__(self) -> int: 35 | """Return the length of the dataset. 36 | Returns: 37 | length of the dataset 38 | """ 39 | 40 | def __str__(self) -> str: 41 | """Return the informal string representation of the object.
42 | Returns: 43 | informal string representation 44 | """ 45 | return f"""\ 46 | {self.__class__.__name__} Dataset 47 | type: VisionDataset 48 | size: {len(self)}""" 49 | 50 | 51 | class VisionClassificationDataset(VisionDataset, ImageFolder): 52 | """Abstract base class for classification datasets lacking geospatial information. 53 | This base class is designed for datasets with pre-defined image chips which 54 | are separated into separate folders per class. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | root: str, 60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 61 | loader: Optional[Callable[[str], Any]] = pil_loader, 62 | is_valid_file: Optional[Callable[[str], bool]] = None, 63 | ) -> None: 64 | """Initialize a new VisionClassificationDataset instance. 65 | Args: 66 | root: root directory where dataset can be found 67 | transforms: a function/transform that takes input sample and its target as 68 | entry and returns a transformed version 69 | loader: a callable function which takes as input a path to an image and 70 | returns a PIL Image or numpy array 71 | is_valid_file: A function that takes the path of an Image file and checks if 72 | the file is a valid file 73 | """ 74 | # When transform & target_transform are None, ImageFolder.__getitem__(index) 75 | # returns a PIL.Image and int for image and label, respectively 76 | super().__init__( 77 | root=root, 78 | transform=None, 79 | target_transform=None, 80 | loader=loader, 81 | is_valid_file=is_valid_file, 82 | ) 83 | 84 | # Must be set after calling super().__init__() 85 | self.transforms = transforms 86 | 87 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 88 | """Return an item from the dataset. 89 | Args: 90 | index: index to return 91 | Returns: 92 | data and label at that index 93 | """ 94 | image, label = self._load_image(index) 95 | 96 | if self.transforms is not None: 97 | return self.transforms(image), label 98 | 99 | return image, label 100 | 101 | def __len__(self) -> int: 102 | """Return the number of data points in the dataset. 103 | Returns: 104 | length of the dataset 105 | """ 106 | return len(self.imgs) 107 | 108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: 109 | """Load a single image and its class label. 110 | Args: 111 | index: index to return 112 | Returns: 113 | the image 114 | the image class label 115 | """ 116 | img, label = ImageFolder.__getitem__(self, index) 117 | label = torch.tensor(label) 118 | return img, label 119 | 120 | 121 | class RESISC45Dataset(VisionClassificationDataset): 122 | """RESISC45 dataset. 123 | The RESISC45 124 | dataset is a benchmark for remote sensing image scene classification. 125 | Dataset features: 126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) 127 | * three spectral bands - RGB 128 | * 45 scene classes, 700 images per class 129 | * images extracted from Google Earth from over 100 countries 130 | * image conditions with high variability (resolution, weather, illumination) 131 | Dataset format: 132 | * images are three-channel jpgs 133 | Dataset classes: 134 | 0. airplane 135 | 1. airport 136 | 2. baseball_diamond 137 | 3. basketball_court 138 | 4. beach 139 | 5. bridge 140 | 6. chaparral 141 | 7. church 142 | 8. circular_farmland 143 | 9. cloud 144 | 10. commercial_area 145 | 11. dense_residential 146 | 12. desert 147 | 13. forest 148 | 14. freeway 149 | 15. golf_course 150 | 16. ground_track_field 151 | 17. harbor 152 | 18. industrial_area 153 | 19. intersection 154 | 20.
island 155 | 21. lake 156 | 22. meadow 157 | 23. medium_residential 158 | 24. mobile_home_park 159 | 25. mountain 160 | 26. overpass 161 | 27. palace 162 | 28. parking_lot 163 | 29. railway 164 | 30. railway_station 165 | 31. rectangular_farmland 166 | 32. river 167 | 33. roundabout 168 | 34. runway 169 | 35. sea_ice 170 | 36. ship 171 | 37. snowberg 172 | 38. sparse_residential 173 | 39. stadium 174 | 40. storage_tank 175 | 41. tennis_court 176 | 42. terrace 177 | 43. thermal_power_station 178 | 44. wetland 179 | This dataset uses the train/val/test splits defined in the "In-domain representation 180 | learning for remote sensing" paper: 181 | * https://arxiv.org/abs/1911.06721 182 | If you use this dataset in your research, please cite the following paper: 183 | * https://doi.org/10.1109/jproc.2017.2675998 184 | """ 185 | 186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" 187 | # md5 = "d824acb73957502b00efd559fc6cfbbb" 188 | # filename = "NWPU-RESISC45.rar" 189 | directory = "resisc45/NWPU-RESISC45" 190 | 191 | splits = ["train", "val", "test"] 192 | split_urls = { 193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 196 | } 197 | split_md5s = { 198 | "train": "b5a4c05a37de15e4ca886696a85c403e", 199 | "val": "a0770cee4c5ca20b8c32bbd61e114805", 200 | "test": "3dda9e4988b47eb1de9f07993653eb08", 201 | } 202 | classes = [ 203 | "airplane", 204 | "airport", 205 | "baseball_diamond", 206 | "basketball_court", 207 | "beach", 208 | "bridge", 209 | "chaparral", 210 | "church", 211 | "circular_farmland", 212 | "cloud", 213 | "commercial_area", 214 | "dense_residential", 215 | "desert", 216 | "forest", 217 | "freeway", 218 | "golf_course", 219 | "ground_track_field", 220 | "harbor", 221 | "industrial_area", 222 | "intersection", 223 | "island", 224 | "lake", 225 | "meadow", 226 | "medium_residential", 227 | "mobile_home_park", 228 | "mountain", 229 | "overpass", 230 | "palace", 231 | "parking_lot", 232 | "railway", 233 | "railway_station", 234 | "rectangular_farmland", 235 | "river", 236 | "roundabout", 237 | "runway", 238 | "sea_ice", 239 | "ship", 240 | "snowberg", 241 | "sparse_residential", 242 | "stadium", 243 | "storage_tank", 244 | "tennis_court", 245 | "terrace", 246 | "thermal_power_station", 247 | "wetland", 248 | ] 249 | 250 | def __init__( 251 | self, 252 | root: str = "data", 253 | split: str = "train", 254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 255 | ) -> None: 256 | """Initialize a new RESISC45 dataset instance. 
257 | Args: 258 | root: root directory where dataset can be found 259 | split: one of "train", "val", or "test" 260 | transforms: a function/transform that takes input sample and its target as 261 | entry and returns a transformed version 262 | """ 263 | assert split in self.splits 264 | self.root = root 265 | 266 | valid_fns = set() 267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: 268 | for fn in f: 269 | valid_fns.add(fn.strip()) 270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename( 271 | x) in valid_fns 272 | 273 | super().__init__( 274 | root=os.path.join(root, self.directory), 275 | transforms=transforms, 276 | is_valid_file=is_in_split, 277 | ) 278 | 279 | 280 | 281 | class RESISC45: 282 | def __init__(self, 283 | preprocess, 284 | location=os.path.expanduser('~/data'), 285 | batch_size=32, 286 | num_workers=16): 287 | 288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) 289 | self.train_loader = torch.utils.data.DataLoader( 290 | self.train_dataset, 291 | shuffle=True, 292 | batch_size=batch_size, 293 | num_workers=num_workers, 294 | ) 295 | 296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) 297 | self.test_loader = torch.utils.data.DataLoader( 298 | self.test_dataset, 299 | batch_size=batch_size, 300 | num_workers=num_workers 301 | ) 302 | 303 | # class names have _ so split on this for better zero-shot head 304 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] 305 | -------------------------------------------------------------------------------- /src/datasets/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class STL10: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | location = os.path.join(location, 'stl10') 13 | self.train_dataset = datasets.STL10( 14 | root=location, 15 | download=True, 16 | split='train', 17 | transform=preprocess 18 | ) 19 | 20 | self.train_loader = torch.utils.data.DataLoader( 21 | self.train_dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = datasets.STL10( 28 | root=location, 29 | download=True, 30 | split='test', 31 | transform=preprocess 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers 39 | ) 40 | 41 | self.classnames = self.train_dataset.classes -------------------------------------------------------------------------------- /src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class SUN397: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=32, 10 | num_workers=16): 11 | # Data loading code 12 | traindir = os.path.join(location, 'sun397', 'train') 13 | valdir = os.path.join(location, 'sun397', 'val') 14 | 15 | 16 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 17 | self.train_loader = torch.utils.data.DataLoader( 18 | self.train_dataset, 19 | shuffle=True, 20 | batch_size=batch_size, 21 | num_workers=num_workers, 22 | ) 23 | 24 | self.test_dataset = datasets.ImageFolder(valdir, 
transform=preprocess) 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 27 | batch_size=batch_size, 28 | num_workers=num_workers 29 | ) 30 | idx_to_class = dict((v, k) 31 | for k, v in self.train_dataset.class_to_idx.items()) 32 | self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] 33 | -------------------------------------------------------------------------------- /src/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | 6 | 7 | class SVHN: 8 | def __init__(self, 9 | preprocess, 10 | location=os.path.expanduser('~/data'), 11 | batch_size=128, 12 | num_workers=16): 13 | 14 | # to fit with repo conventions for location 15 | modified_location = os.path.join(location, 'svhn') 16 | 17 | self.train_dataset = PyTorchSVHN( 18 | root=modified_location, 19 | download=True, 20 | split='train', 21 | transform=preprocess 22 | ) 23 | 24 | self.train_loader = torch.utils.data.DataLoader( 25 | self.train_dataset, 26 | batch_size=batch_size, 27 | shuffle=True, 28 | num_workers=num_workers 29 | ) 30 | 31 | self.test_dataset = PyTorchSVHN( 32 | root=modified_location, 33 | download=True, 34 | split='test', 35 | transform=preprocess 36 | ) 37 | 38 | self.test_loader = torch.utils.data.DataLoader( 39 | self.test_dataset, 40 | batch_size=batch_size, 41 | shuffle=False, 42 | num_workers=num_workers 43 | ) 44 | 45 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 46 | -------------------------------------------------------------------------------- /src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f'a photo of a {c}.', 3 | lambda c: f'a photo of the {c}.', 4 | lambda c: f'a photo of my {c}.', 5 | lambda c: f'i love my {c}!', 6 | lambda c: f'a photo of my dirty {c}.', 7 | lambda c: f'a photo of my clean {c}.', 8 | lambda c: f'a photo of my new {c}.', 9 | lambda c: f'a photo of my old {c}.', 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f'a photo of a {c}.', 14 | lambda c: f'a blurry photo of a {c}.', 15 | lambda c: f'a black and white photo of a {c}.', 16 | lambda c: f'a low contrast photo of a {c}.', 17 | lambda c: f'a high contrast photo of a {c}.', 18 | lambda c: f'a bad photo of a {c}.', 19 | lambda c: f'a good photo of a {c}.', 20 | lambda c: f'a photo of a small {c}.', 21 | lambda c: f'a photo of a big {c}.', 22 | lambda c: f'a photo of the {c}.', 23 | lambda c: f'a blurry photo of the {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a low contrast photo of the {c}.', 26 | lambda c: f'a high contrast photo of the {c}.', 27 | lambda c: f'a bad photo of the {c}.', 28 | lambda c: f'a good photo of the {c}.', 29 | lambda c: f'a photo of the small {c}.', 30 | lambda c: f'a photo of the big {c}.', 31 | ] 32 | 33 | cifar100_template = [ 34 | lambda c: f'a photo of a {c}.', 35 | lambda c: f'a blurry photo of a {c}.', 36 | lambda c: f'a black and white photo of a {c}.', 37 | lambda c: f'a low contrast photo of a {c}.', 38 | lambda c: f'a high contrast photo of a {c}.', 39 | lambda c: f'a bad photo of a {c}.', 40 | lambda c: f'a good photo of a {c}.', 41 | lambda c: f'a photo of a small {c}.', 42 | lambda c: f'a photo of a big {c}.', 43 | lambda c: f'a photo of the {c}.', 44 | lambda c: f'a blurry photo of the {c}.', 45 | lambda c: f'a black 
and white photo of the {c}.', 46 | lambda c: f'a low contrast photo of the {c}.', 47 | lambda c: f'a high contrast photo of the {c}.', 48 | lambda c: f'a bad photo of the {c}.', 49 | lambda c: f'a good photo of the {c}.', 50 | lambda c: f'a photo of the small {c}.', 51 | lambda c: f'a photo of the big {c}.', 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f'a photo of a {c} texture.', 56 | lambda c: f'a photo of a {c} pattern.', 57 | lambda c: f'a photo of a {c} thing.', 58 | lambda c: f'a photo of a {c} object.', 59 | lambda c: f'a photo of the {c} texture.', 60 | lambda c: f'a photo of the {c} pattern.', 61 | lambda c: f'a photo of the {c} thing.', 62 | lambda c: f'a photo of the {c} object.', 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f'a centered satellite photo of {c}.', 67 | lambda c: f'a centered satellite photo of a {c}.', 68 | lambda c: f'a centered satellite photo of the {c}.', 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f'a photo of {c}, a type of food.', 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f'a bad photo of a {c}.', 87 | lambda c: f'a photo of many {c}.', 88 | lambda c: f'a sculpture of a {c}.', 89 | lambda c: f'a photo of the hard to see {c}.', 90 | lambda c: f'a low resolution photo of the {c}.', 91 | lambda c: f'a rendering of a {c}.', 92 | lambda c: f'graffiti of a {c}.', 93 | lambda c: f'a bad photo of the {c}.', 94 | lambda c: f'a cropped photo of the {c}.', 95 | lambda c: f'a tattoo of a {c}.', 96 | lambda c: f'the embroidered {c}.', 97 | lambda c: f'a photo of a hard to see {c}.', 98 | lambda c: f'a bright photo of a {c}.', 99 | lambda c: f'a photo of a clean {c}.', 100 | lambda c: f'a photo of a dirty {c}.', 101 | lambda c: f'a dark photo of the {c}.', 102 | lambda c: f'a drawing of a {c}.', 103 | lambda c: f'a photo of my {c}.', 104 | lambda c: f'the plastic {c}.', 105 | lambda c: f'a photo of the cool {c}.', 106 | lambda c: f'a close-up photo of a {c}.', 107 | lambda c: f'a black and white photo of the {c}.', 108 | lambda c: f'a painting of the {c}.', 109 | lambda c: f'a painting of a {c}.', 110 | lambda c: f'a pixelated photo of the {c}.', 111 | lambda c: f'a sculpture of the {c}.', 112 | lambda c: f'a bright photo of the {c}.', 113 | lambda c: f'a cropped photo of a {c}.', 114 | lambda c: f'a plastic {c}.', 115 | lambda c: f'a photo of the dirty {c}.', 116 | lambda c: f'a jpeg corrupted photo of a {c}.', 117 | lambda c: f'a blurry photo of the {c}.', 118 | lambda c: f'a photo of the {c}.', 119 | lambda c: f'a good photo of the {c}.', 120 | lambda c: f'a rendering of the {c}.', 121 | lambda c: f'a {c} in a video game.', 122 | lambda c: f'a photo of one {c}.', 123 | lambda c: f'a doodle of a {c}.', 124 | lambda c: f'a close-up photo of the {c}.', 125 | lambda c: f'a photo of a {c}.', 126 | lambda c: f'the origami {c}.', 127 | lambda c: f'the {c} in a video game.', 128 | lambda c: f'a sketch of a {c}.', 129 | lambda c: f'a doodle of the {c}.', 130 | lambda c: f'a origami {c}.', 131 | lambda c: f'a low resolution photo of a {c}.', 132 | lambda c: f'the toy {c}.', 133 | lambda c: f'a rendition of the {c}.', 134 | lambda c: f'a photo of the clean {c}.', 135 | lambda c: f'a photo of a large {c}.', 136 | lambda c: f'a rendition of a 
{c}.', 137 | lambda c: f'a photo of a nice {c}.', 138 | lambda c: f'a photo of a weird {c}.', 139 | lambda c: f'a blurry photo of a {c}.', 140 | lambda c: f'a cartoon {c}.', 141 | lambda c: f'art of a {c}.', 142 | lambda c: f'a sketch of the {c}.', 143 | lambda c: f'a embroidered {c}.', 144 | lambda c: f'a pixelated photo of a {c}.', 145 | lambda c: f'itap of the {c}.', 146 | lambda c: f'a jpeg corrupted photo of the {c}.', 147 | lambda c: f'a good photo of a {c}.', 148 | lambda c: f'a plushie {c}.', 149 | lambda c: f'a photo of the nice {c}.', 150 | lambda c: f'a photo of the small {c}.', 151 | lambda c: f'a photo of the weird {c}.', 152 | lambda c: f'the cartoon {c}.', 153 | lambda c: f'art of the {c}.', 154 | lambda c: f'a drawing of the {c}.', 155 | lambda c: f'a photo of the large {c}.', 156 | lambda c: f'a black and white photo of a {c}.', 157 | lambda c: f'the plushie {c}.', 158 | lambda c: f'a dark photo of a {c}.', 159 | lambda c: f'itap of a {c}.', 160 | lambda c: f'graffiti of the {c}.', 161 | lambda c: f'a toy {c}.', 162 | lambda c: f'itap of my {c}.', 163 | lambda c: f'a photo of a cool {c}.', 164 | lambda c: f'a photo of a small {c}.', 165 | lambda c: f'a tattoo of the {c}.', 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f'satellite imagery of {c}.', 170 | lambda c: f'aerial imagery of {c}.', 171 | lambda c: f'satellite photo of {c}.', 172 | lambda c: f'aerial photo of {c}.', 173 | lambda c: f'satellite view of {c}.', 174 | lambda c: f'aerial view of {c}.', 175 | lambda c: f'satellite imagery of a {c}.', 176 | lambda c: f'aerial imagery of a {c}.', 177 | lambda c: f'satellite photo of a {c}.', 178 | lambda c: f'aerial photo of a {c}.', 179 | lambda c: f'satellite view of a {c}.', 180 | lambda c: f'aerial view of a {c}.', 181 | lambda c: f'satellite imagery of the {c}.', 182 | lambda c: f'aerial imagery of the {c}.', 183 | lambda c: f'satellite photo of the {c}.', 184 | lambda c: f'aerial photo of the {c}.', 185 | lambda c: f'satellite view of the {c}.', 186 | lambda c: f'aerial view of the {c}.', 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f'a photo of a {c}.', 191 | lambda c: f'a photo of the {c}.', 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f'a photo of a {c}.', 196 | lambda c: f'a photo of the {c}.', 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | 204 | dataset_to_template = { 205 | 'Cars': cars_template, 206 | 'CIFAR10': cifar10_template, 207 | 'CIFAR100': cifar100_template, 208 | 'DTD': dtd_template, 209 | 'EuroSAT': eurosat_template, 210 | 'Food101': food101_template, 211 | 'GTSRB': gtsrb_template, 212 | 'MNIST': mnist_template, 213 | 'ImageNet': imagenet_template, 214 | 'RESISC45': resisc45_template, 215 | 'STL10': stl10_template, 216 | 'SUN397': sun397_template, 217 | 'SVHN': svhn_template, 218 | } 219 | 220 | 221 | def get_templates(dataset_name): 222 | if dataset_name.endswith('Val'): 223 | return get_templates(dataset_name.replace('Val', '')) 224 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' 225 | return dataset_to_template[dataset_name] -------------------------------------------------------------------------------- /src/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | def setup_ddp(rank, world_size, port=12357): 7 | os.environ["MASTER_ADDR"] = "localhost" 8 | os.environ["MASTER_PORT"] = str(port) 9 | 10 | # initialize the process group 
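# (NCCL backend, one process per GPU: every rank must reach this call, and the
# MASTER_ADDR/MASTER_PORT pair set above serves as the rendezvous point)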
11 | torch.distributed.init_process_group( 12 | "nccl", 13 | rank=rank, 14 | world_size=world_size, 15 | ) 16 | torch.cuda.set_device(rank) 17 | torch.distributed.barrier() 18 | 19 | 20 | def cleanup_ddp(): 21 | torch.distributed.destroy_process_group() 22 | 23 | 24 | def is_main_process(): 25 | return torch.distributed.get_rank() == 0 26 | 27 | 28 | def distribute_loader(loader): 29 | return torch.utils.data.DataLoader( 30 | loader.dataset, 31 | batch_size=loader.batch_size // torch.distributed.get_world_size(), 32 | sampler=torch.utils.data.distributed.DistributedSampler( 33 | loader.dataset, 34 | num_replicas=torch.distributed.get_world_size(), 35 | rank=torch.distributed.get_rank(), 36 | ), 37 | num_workers=loader.num_workers, 38 | pin_memory=loader.pin_memory, 39 | ) 40 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | 5 | from src import utils 6 | from src.datasets.common import get_dataloader, maybe_dictionarize 7 | from src.datasets.registry import get_dataset 8 | from src.heads import get_classification_head 9 | from src.linearize import LinearizedImageEncoder 10 | from src.modeling import ImageClassifier 11 | 12 | 13 | def eval_single_dataset(image_encoder, dataset_name, args): 14 | classification_head = get_classification_head(args, dataset_name) 15 | model = ImageClassifier(image_encoder, classification_head) 16 | 17 | model.eval() 18 | 19 | dataset = get_dataset( 20 | dataset_name, 21 | model.val_preprocess, 22 | location=args.data_location, 23 | batch_size=args.batch_size, 24 | ) 25 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 26 | device = args.device 27 | 28 | with torch.no_grad(): 29 | top1, correct, n = 0.0, 0.0, 0.0 30 | for _, data in enumerate(tqdm.tqdm(dataloader)): 31 | data = maybe_dictionarize(data) 32 | x = data["images"].to(device) 33 | y = data["labels"].to(device) 34 | 35 | logits = utils.get_logits(x, model) 36 | 37 | pred = logits.argmax(dim=1, keepdim=True).to(device) 38 | 39 | correct += pred.eq(y.view_as(pred)).sum().item() 40 | 41 | n += y.size(0) 42 | 43 | top1 = correct / n 44 | 45 | metrics = {"top1": top1} 46 | print(f"Done evaluating on {dataset_name}. 
Accuracy: {100*top1:.2f}%") 47 | 48 | return metrics 49 | 50 | 51 | def evaluate(image_encoder, args): 52 | if args.eval_datasets is None: 53 | return 54 | per_dataset_results = {} 55 | eval_datasets = ( 56 | args.eval_datasets 57 | if args.control_dataset is None 58 | else args.eval_datasets + [args.control_dataset] 59 | ) 60 | for dataset_name in eval_datasets: 61 | print("Evaluating on", dataset_name) 62 | 63 | results = eval_single_dataset(image_encoder, dataset_name, args) 64 | 65 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}") 66 | per_dataset_results[dataset_name + ":top1"] = results["top1"] 67 | 68 | return per_dataset_results 69 | 70 | 71 | def evaluate_task_vector_at_coef( 72 | task_vector, pretrained_checkpoint, args, scaling_coef, posthoc_linearization=False 73 | ): 74 | image_encoder = task_vector.apply_to( 75 | pretrained_checkpoint, scaling_coef=scaling_coef 76 | ) 77 | if posthoc_linearization: 78 | pretrained_encoder = task_vector.apply_to( 79 | pretrained_checkpoint, scaling_coef=0.0 80 | ) 81 | image_encoder = LinearizedImageEncoder( 82 | init_encoder=pretrained_encoder, image_encoder=image_encoder, args=args 83 | ) 84 | coef_info = evaluate(image_encoder, args) 85 | 86 | coef_info = add_normalized_accuracy(coef_info, args) 87 | coef_info["avg_normalized_top1"] = np.mean( 88 | [coef_info[dataset + ":normalized_top1"] for dataset in args.eval_datasets] 89 | ) 90 | coef_info["avg_top1"] = np.mean( 91 | [coef_info[dataset + ":top1"] for dataset in args.eval_datasets] 92 | ) 93 | 94 | return coef_info 95 | 96 | 97 | def evaluate_task_vector( 98 | task_vector, pretrained_checkpoint, args, posthoc_linearization=False 99 | ): 100 | info = {} 101 | for scaling_coef in np.linspace(0.0, 1.0, args.n_eval_points): 102 | print(f"Evaluating for scaling coefficient {scaling_coef:.2f}") 103 | info[scaling_coef] = evaluate_task_vector_at_coef( 104 | task_vector, 105 | pretrained_checkpoint, 106 | args, 107 | scaling_coef, 108 | posthoc_linearization, 109 | ) 110 | 111 | return info 112 | 113 | 114 | def add_normalized_accuracy(results, args): 115 | for dataset_name in args.eval_datasets: 116 | results[dataset_name + ":normalized_top1"] = ( 117 | results[dataset_name + ":top1"] / args.finetuning_accuracies[dataset_name] 118 | ) 119 | 120 | return results 121 | 122 | 123 | def nonlinear_advantage(acc_linear, acc_nonlinear, num_classes): 124 | err_linear = 1 - acc_linear 125 | err_nonlinear = 1 - acc_nonlinear 126 | return (err_linear - err_nonlinear) * num_classes / (num_classes - 1) 127 | -------------------------------------------------------------------------------- /src/eval_single_task.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from src.args import parse_arguments 4 | from src.eval import eval_single_dataset 5 | from src.linearize import LinearizedImageEncoder 6 | from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector 7 | 8 | args = parse_arguments() 9 | if args.seed is not None: 10 | args.save = f"checkpoints_{args.seed}/{args.model}" 11 | else: 12 | args.save = f"checkpoints/{args.model}" 13 | 14 | accuracies = {} 15 | 16 | 17 | print("*" * 100) 18 | if args.finetuning_mode == "none": 19 | print("Evaluating pretrained models.") 20 | elif args.finetuning_mode == "standard": 21 | print("Evaluating non-linear FT models.") 22 | elif args.finetuning_mode == "linear": 23 | print("Evaluating linear FT models.") 24 | elif args.finetuning_mode == "posthoc": 25 | print("Evaluating post-hoc 
linearized models.") 26 | 27 | for dataset in [ 28 | "Cars", 29 | "DTD", 30 | "EuroSAT", 31 | "GTSRB", 32 | "MNIST", 33 | "RESISC45", 34 | "SUN397", 35 | "SVHN", 36 | ]: 37 | print("*" * 100) 38 | print(f"Evaluating on {dataset}") 39 | 40 | pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" 41 | 42 | finetuned_checkpoint = ( 43 | f"{args.save}/{dataset}Val/linear_finetuned.pt" 44 | if args.finetuning_mode == "linear" 45 | else f"{args.save}/{dataset}Val/finetuned.pt" 46 | ) 47 | 48 | try: 49 | task_vector = ( 50 | LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) 51 | if args.finetuning_mode == "linear" 52 | else NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) 53 | ) 54 | except FileNotFoundError: 55 | print(f"Error: Could not find {finetuned_checkpoint}.") 56 | continue 57 | 58 | if args.finetuning_mode == "none": 59 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.0) 60 | elif args.finetuning_mode == "standard" or args.finetuning_mode == "linear": 61 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 62 | elif args.finetuning_mode == "posthoc": 63 | zs_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.0) 64 | ft_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 65 | image_encoder = LinearizedImageEncoder( 66 | init_encoder=zs_encoder, image_encoder=ft_encoder, args=args 67 | ) 68 | 69 | for split in ["test", "val"]: 70 | # Evaluate 71 | print("=" * 100) 72 | print(f"Evaluating on {split} split.") 73 | eval_dataset = dataset if split == "test" else f"{dataset}Val" 74 | 75 | accuracies[eval_dataset] = eval_single_dataset( 76 | image_encoder, eval_dataset, args 77 | )["top1"] 78 | 79 | 80 | if args.finetuning_mode == "none": 81 | # Evaluate zero-shot accuracy on ImageNet 82 | for split in ["ImageNetVal", "ImageNet"]: 83 | accuracies[split] = eval_single_dataset(image_encoder, split, args)["top1"] 84 | 85 | # Save results 86 | if args.finetuning_mode == "none": 87 | save_path = f"{args.save}/zeroshot_accuracies.json" 88 | elif args.finetuning_mode == "standard": 89 | save_path = f"{args.save}/ft_accuracies.json" 90 | elif args.finetuning_mode == "linear": 91 | save_path = f"{args.save}/linear_ft_accuracies.json" 92 | elif args.finetuning_mode == "posthoc": 93 | save_path = f"{args.save}/posthoc_ft_accuracies.json" 94 | 95 | with open(save_path, "w") as f: 96 | json.dump(accuracies, f) 97 | -------------------------------------------------------------------------------- /src/eval_task_addition.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from utils import find_optimal_coef 5 | 6 | from src.args import parse_arguments 7 | from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef 8 | from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector 9 | 10 | args = parse_arguments() 11 | 12 | if args.seed is not None: 13 | args.save = f"checkpoints_{args.seed}/{args.model}" 14 | else: 15 | args.save = f"checkpoints/{args.model}" 16 | 17 | 18 | print("*" * 100) 19 | if args.finetuning_mode == "standard": 20 | print("Evaluating non-linear FT models.") 21 | ft_accuracies_path = os.path.join(args.save, "ft_accuracies.json") 22 | elif args.finetuning_mode == "linear": 23 | print("Evaluating linear FT models.") 24 | ft_accuracies_path = os.path.join(args.save, "linear_ft_accuracies.json") 25 | elif args.finetuning_mode == "posthoc": 26 | print("Evaluating post-hoc 
linearized models.") 27 | ft_accuracies_path = os.path.join(args.save, "posthoc_ft_accuracies.json") 28 | else: 29 | raise ValueError(f"Invalid finetuning mode: {args.finetuning_mode}") 30 | print("*" * 100) 31 | 32 | with open(ft_accuracies_path) as f: 33 | args.finetuning_accuracies = json.load(f) 34 | 35 | with open(os.path.join(args.save, "zeroshot_accuracies.json")) as f: 36 | pretrained_accuracies = json.load(f) 37 | 38 | eval_datasets = [ 39 | "Cars", 40 | "DTD", 41 | "EuroSAT", 42 | "GTSRB", 43 | "MNIST", 44 | "RESISC45", 45 | "SVHN", 46 | "SUN397", 47 | ] 48 | 49 | task_vectors = [] 50 | 51 | for dataset in eval_datasets: 52 | if args.finetuning_mode == "linear": 53 | pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt" 54 | finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt" 55 | task_vectors.append( 56 | LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) 57 | ) 58 | else: 59 | pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" 60 | finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt" 61 | task_vectors.append( 62 | NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) 63 | ) 64 | 65 | task_vector = sum(task_vectors) 66 | 67 | args.eval_datasets = [dataset + "Val" for dataset in eval_datasets] 68 | args.control_dataset = None 69 | 70 | # We use the validation set to choose the optimal coefficient. 71 | val_metrics = evaluate_task_vector( 72 | task_vector, 73 | pretrained_checkpoint, 74 | args, 75 | posthoc_linearization=args.finetuning_mode == "posthoc", 76 | ) 77 | 78 | optimal_coef = find_optimal_coef( 79 | val_metrics, 80 | metric="avg_normalized_top1", 81 | minimize=False, 82 | ) 83 | 84 | # Evaluate on the test set with the optimal coefficient. 85 | args.eval_datasets = [dataset for dataset in eval_datasets] 86 | test_metrics = evaluate_task_vector_at_coef( 87 | task_vector, 88 | pretrained_checkpoint, 89 | args, 90 | float(optimal_coef), 91 | posthoc_linearization=args.finetuning_mode == "posthoc", 92 | ) 93 | 94 | print("=" * 100) 95 | print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']}") 96 | print(f"Test absolute accuracy: {test_metrics['avg_top1']}") 97 | additive_accuracies = {"test": test_metrics, "val": val_metrics} 98 | 99 | if args.finetuning_mode == "standard": 100 | save_file = f"{args.save}/additions.json" 101 | elif args.finetuning_mode == "linear": 102 | save_file = f"{args.save}/linear_additions.json" 103 | elif args.finetuning_mode == "posthoc": 104 | save_file = f"{args.save}/posthoc_additions.json" 105 | with open(save_file, "w") as f: 106 | json.dump(additive_accuracies, f, indent=4) 107 | -------------------------------------------------------------------------------- /src/eval_task_negation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from utils import find_optimal_coef 5 | 6 | from src.args import parse_arguments 7 | from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef 8 | from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector 9 | 10 | args = parse_arguments() 11 | 12 | 13 | if args.seed is not None: 14 | args.save = f"checkpoints_{args.seed}/{args.model}" 15 | else: 16 | args.save = f"checkpoints/{args.model}" 17 | 18 | with open(os.path.join(args.save, "zeroshot_accuracies.json")) as f: 19 | pretrained_accuracies = json.load(f) 20 | 21 | eval_datasets = [ 22 | "Cars", 23 | "DTD", 24 | "EuroSAT", 25 | "GTSRB", 26 | "MNIST", 27 | 
"RESISC45", 28 | "SUN397", 29 | "SVHN", 30 | ] 31 | 32 | print("*" * 100) 33 | if args.finetuning_mode == "standard": 34 | print("Evaluating non-linear FT models.") 35 | ft_accuracies_path = os.path.join(args.save, "ft_accuracies.json") 36 | elif args.finetuning_mode == "linear": 37 | print("Evaluating linear FT models.") 38 | ft_accuracies_path = os.path.join(args.save, "linear_ft_accuracies.json") 39 | elif args.finetuning_mode == "posthoc": 40 | print("Evaluating post-hoc linearized models.") 41 | ft_accuracies_path = os.path.join(args.save, "posthoc_ft_accuracies.json") 42 | else: 43 | raise ValueError(f"Invalid finetuning mode: {args.finetuning_mode}") 44 | print("*" * 100) 45 | 46 | with open(ft_accuracies_path) as f: 47 | args.finetuning_accuracies = json.load(f) 48 | 49 | control_dataset = "ImageNet" 50 | negation_accuracies = {} 51 | 52 | for dataset in eval_datasets: 53 | if args.finetuning_mode == "linear": 54 | pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt" 55 | finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt" 56 | task_vector = -LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) 57 | else: 58 | pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" 59 | finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt" 60 | task_vector = -NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) 61 | 62 | # We use the validation set to choose the optimal coefficient. 63 | args.eval_datasets = [dataset + "Val"] 64 | args.control_dataset = control_dataset + "Val" 65 | val_metrics = evaluate_task_vector( 66 | task_vector, 67 | pretrained_checkpoint, 68 | args, 69 | posthoc_linearization=args.finetuning_mode == "posthoc", 70 | ) 71 | 72 | optimal_coef = find_optimal_coef( 73 | val_metrics, 74 | metric=f"{dataset}Val:top1", 75 | minimize=True, 76 | control_metric=f"{control_dataset}Val:top1", 77 | control_metric_threshold=args.control_threshold 78 | * pretrained_accuracies[control_dataset + "Val"], 79 | ) 80 | 81 | # Evaluate on the test set with the optimal coefficient. 
82 | args.eval_datasets = [dataset] 83 | args.control_dataset = control_dataset 84 | test_metrics = evaluate_task_vector_at_coef( 85 | task_vector, 86 | pretrained_checkpoint, 87 | args, 88 | optimal_coef, 89 | posthoc_linearization=args.finetuning_mode == "posthoc", 90 | ) 91 | 92 | print("=" * 100) 93 | print(f"Test accuracy: {test_metrics[f'{dataset}:top1']}") 94 | 95 | negation_accuracies[dataset] = { 96 | "test": test_metrics[f"{dataset}:top1"], 97 | "test_control": test_metrics[f"{control_dataset}:top1"], 98 | "val": val_metrics, 99 | } 100 | 101 | if args.finetuning_mode == "standard": 102 | save_file = f"{args.save}/negations.json" 103 | elif args.finetuning_mode == "linear": 104 | save_file = f"{args.save}/linear_negations.json" 105 | elif args.finetuning_mode == "posthoc": 106 | save_file = f"{args.save}/posthoc_negations.json" 107 | 108 | with open(save_file, "w") as f: 109 | json.dump(negation_accuracies, f, indent=4) 110 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | 6 | from src.args import parse_arguments 7 | from src.datasets.common import get_dataloader, maybe_dictionarize 8 | from src.datasets.registry import get_dataset 9 | from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp 10 | from src.eval import eval_single_dataset 11 | from src.heads import get_classification_head 12 | from src.linearize import LinearizedImageEncoder 13 | from src.modeling import ImageClassifier, ImageEncoder 14 | from src.utils import LabelSmoothing, cosine_lr 15 | 16 | 17 | def finetune(rank, args): 18 | setup_ddp(rank, args.world_size, port=args.port) 19 | 20 | train_dataset = args.train_dataset 21 | ckpdir = os.path.join(args.save, train_dataset) 22 | 23 | assert args.finetuning_mode in [ 24 | "linear", 25 | "standard", 26 | ], "Only linear and standard fine-tuning are supported." 27 | 28 | linearized_finetuning = args.finetuning_mode == "linear" 29 | if linearized_finetuning: 30 | print("Using linearized fine-tuning.") 31 | 32 | # Check if checkpoints already exist 33 | ft_path = ( 34 | os.path.join(args.save, train_dataset, "linear_finetuned.pt") 35 | if linearized_finetuning 36 | else os.path.join(args.save, train_dataset, "finetuned.pt") 37 | ) 38 | zs_path = ( 39 | os.path.join(args.save, train_dataset, "linear_zeroshot.pt") 40 | if linearized_finetuning 41 | else os.path.join(args.save, train_dataset, "zeroshot.pt") 42 | ) 43 | if os.path.exists(zs_path) and os.path.exists(ft_path): 44 | print(f"Skipping fine-tuning because {ft_path} exists.") 45 | cleanup_ddp() # tear down the process group before the early return 46 | return zs_path, ft_path 47 | assert train_dataset is not None, "Please provide a training dataset."
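# Note: by convention the training dataset name carries a "Val" suffix (e.g.
# "MNISTVal"), so fine-tuning runs on the train portion of the train/val split
# and the held-out val portion remains available for tuning scaling coefficients later.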
48 | 49 | if args.load is not None and args.load.endswith("pt"): 50 | image_encoder = ( 51 | LinearizedImageEncoder.load(args.load) 52 | if linearized_finetuning 53 | else ImageEncoder.load(args.load) 54 | ) 55 | else: 56 | print("Building image encoder.") 57 | image_encoder = ( 58 | LinearizedImageEncoder(args, keep_lang=False) 59 | if linearized_finetuning 60 | else ImageEncoder(args) 61 | ) 62 | 63 | classification_head = get_classification_head(args, train_dataset) 64 | 65 | model = ImageClassifier(image_encoder, classification_head) 66 | 67 | model.freeze_head() 68 | model = model.cuda() 69 | 70 | preprocess_fn = model.train_preprocess 71 | print_every = 100 72 | 73 | dataset = get_dataset( 74 | train_dataset, 75 | preprocess_fn, 76 | location=args.data_location, 77 | batch_size=args.batch_size, 78 | ) 79 | data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None) 80 | num_batches = len(dataset.train_loader) 81 | 82 | # Distribute the data and model across the GPUs. 83 | ddp_loader = distribute_loader(data_loader) 84 | ddp_model = torch.nn.parallel.DistributedDataParallel( 85 | model, 86 | device_ids=[rank], 87 | find_unused_parameters=True, 88 | output_device=rank, 89 | ) 90 | 91 | if args.ls > 0: 92 | loss_fn = LabelSmoothing(args.ls) 93 | else: 94 | loss_fn = torch.nn.CrossEntropyLoss() 95 | 96 | params = [p for p in ddp_model.parameters() if p.requires_grad] 97 | optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) 98 | 99 | scheduler = cosine_lr( 100 | optimizer, 101 | args.lr, 102 | args.warmup_length, 103 | args.epochs * num_batches // args.num_grad_accumulation, 104 | ) 105 | 106 | # Saving zero-shot model 107 | if args.save is not None and is_main_process(): 108 | os.makedirs(ckpdir, exist_ok=True) 109 | model_path = ( 110 | os.path.join(ckpdir, "linear_zeroshot.pt") 111 | if linearized_finetuning 112 | else os.path.join(ckpdir, "zeroshot.pt") 113 | ) 114 | ddp_model.module.image_encoder.save(model_path) 115 | 116 | for epoch in range(args.epochs): 117 | ddp_model.train() 118 | 119 | for i, batch in enumerate(ddp_loader): 120 | start_time = time.time() 121 | 122 | step = ( 123 | i // args.num_grad_accumulation 124 | + epoch * num_batches // args.num_grad_accumulation 125 | ) 126 | 127 | batch = maybe_dictionarize(batch) 128 | inputs = batch["images"].cuda() 129 | labels = batch["labels"].cuda() 130 | data_time = time.time() - start_time 131 | 132 | logits = ddp_model(inputs) 133 | 134 | loss = loss_fn(logits, labels) 135 | 136 | loss.backward() 137 | 138 | if (i + 1) % args.num_grad_accumulation == 0: 139 | scheduler(step) 140 | 141 | torch.nn.utils.clip_grad_norm_(params, 1.0) 142 | optimizer.step() 143 | optimizer.zero_grad() 144 | 145 | batch_time = time.time() - start_time 146 | 147 | if ( 148 | args.checkpoint_every > 0 149 | and step % args.checkpoint_every == 0 150 | and is_main_process() 151 | ): 152 | print("Saving checkpoint.") 153 | model_path = ( 154 | os.path.join(ckpdir, f"linear_checkpoint_{step}.pt") 155 | if linearized_finetuning 156 | else os.path.join(ckpdir, f"checkpoint_{step}.pt") 157 | ) 158 | ddp_model.module.image_encoder.save(model_path) 159 | 160 | if ( 161 | step % print_every == 0 162 | and ((i + 1) % args.num_grad_accumulation == 0) 163 | and is_main_process() 164 | ): 165 | percent_complete = 100 * i / len(ddp_loader) 166 | print( 167 | f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t" # noqa: E501 168 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) 
{batch_time:.3f}", # noqa: E501 169 | flush=True, 170 | ) 171 | 172 | # FIXME: Make this work with DDP. 173 | if is_main_process(): 174 | # We only need to evaluate the model on the first GPU. 175 | image_encoder = ddp_model.module.image_encoder 176 | eval_single_dataset(image_encoder, train_dataset, args) 177 | 178 | if args.save is not None and is_main_process(): 179 | zs_path = ( 180 | os.path.join(ckpdir, "linear_zeroshot.pt") 181 | if linearized_finetuning 182 | else os.path.join(ckpdir, "zeroshot.pt") 183 | ) 184 | ft_path = ( 185 | os.path.join(ckpdir, "linear_finetuned.pt") 186 | if linearized_finetuning 187 | else os.path.join(ckpdir, "finetuned.pt") 188 | ) 189 | image_encoder.save(ft_path) 190 | return zs_path, ft_path 191 | 192 | cleanup_ddp() 193 | 194 | 195 | if __name__ == "__main__": 196 | train_datasets = [ 197 | "Cars", 198 | "DTD", 199 | "EuroSAT", 200 | "GTSRB", 201 | "MNIST", 202 | "RESISC45", 203 | "SUN397", 204 | "SVHN", 205 | ] 206 | epochs = { 207 | "Cars": 35, 208 | "DTD": 76, 209 | "EuroSAT": 12, 210 | "GTSRB": 11, 211 | "MNIST": 5, 212 | "RESISC45": 15, 213 | "SUN397": 14, 214 | "SVHN": 4, 215 | } 216 | 217 | for dataset in train_datasets: 218 | args = parse_arguments() 219 | 220 | # HACK: Some command line arguments are overwritten by defaults here. 221 | args.lr = 1e-5 222 | args.epochs = epochs[dataset] 223 | args.train_dataset = dataset + "Val" 224 | 225 | # We use gradient accumulation to simulate larger batch sizes if the model does not fit in memory. 226 | args.batch_size = 64 if args.model == "ViT-L-14" else 128 227 | args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1 228 | 229 | if args.seed is not None: 230 | args.save = f"checkpoints_{args.seed}/{args.model}" 231 | else: 232 | args.save = f"checkpoints/{args.model}" 233 | print("=" * 100) 234 | print(f"Finetuning {args.model} on {dataset}") 235 | print("=" * 100) 236 | torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size) 237 | -------------------------------------------------------------------------------- /src/heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import open_clip 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from src.datasets.registry import get_dataset 8 | from src.datasets.templates import get_templates 9 | from src.modeling import ClassificationHead, ImageEncoder 10 | 11 | 12 | def build_classification_head(model, dataset_name, template, data_location, device): 13 | template = get_templates(dataset_name) 14 | 15 | logit_scale = model.logit_scale 16 | dataset = get_dataset(dataset_name, None, location=data_location) 17 | model.eval() 18 | model.to(device) 19 | 20 | print("Building classification head.") 21 | with torch.no_grad(): 22 | zeroshot_weights = [] 23 | for classname in tqdm(dataset.classnames): 24 | texts = [] 25 | for t in template: 26 | texts.append(t(classname)) 27 | texts = open_clip.tokenize(texts).to(device) # tokenize 28 | embeddings = model.encode_text(texts) # embed with text encoder 29 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 30 | 31 | embeddings = embeddings.mean(dim=0, keepdim=True) 32 | embeddings /= embeddings.norm() 33 | 34 | zeroshot_weights.append(embeddings) 35 | 36 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 37 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 38 | 39 | zeroshot_weights *= logit_scale.exp() 40 | 41 | zeroshot_weights = zeroshot_weights.squeeze().float() 42 | zeroshot_weights = 
/src/heads.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import open_clip
4 | import torch
5 | from tqdm import tqdm
6 | 
7 | from src.datasets.registry import get_dataset
8 | from src.datasets.templates import get_templates
9 | from src.modeling import ClassificationHead, ImageEncoder
10 | 
11 | 
12 | def build_classification_head(model, dataset_name, template, data_location, device):
13 |     template = get_templates(dataset_name)  # NOTE: overrides the `template` argument.
14 | 
15 |     logit_scale = model.logit_scale
16 |     dataset = get_dataset(dataset_name, None, location=data_location)
17 |     model.eval()
18 |     model.to(device)
19 | 
20 |     print("Building classification head.")
21 |     with torch.no_grad():
22 |         zeroshot_weights = []
23 |         for classname in tqdm(dataset.classnames):
24 |             texts = []
25 |             for t in template:
26 |                 texts.append(t(classname))
27 |             texts = open_clip.tokenize(texts).to(device)  # tokenize
28 |             embeddings = model.encode_text(texts)  # embed with text encoder
29 |             embeddings /= embeddings.norm(dim=-1, keepdim=True)
30 | 
31 |             embeddings = embeddings.mean(dim=0, keepdim=True)
32 |             embeddings /= embeddings.norm()
33 | 
34 |             zeroshot_weights.append(embeddings)
35 | 
36 |         zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
37 |         zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
38 | 
39 |         zeroshot_weights *= logit_scale.exp()
40 | 
41 |         zeroshot_weights = zeroshot_weights.squeeze().float()
42 |         zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
43 | 
44 |     classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
45 | 
46 |     return classification_head
47 | 
48 | 
49 | def get_classification_head(args, dataset):
50 |     if not dataset.endswith("Val"):
51 |         # Always load the head built for the validation split so it matches the one generated at training time.
52 |         dataset += "Val"
53 | 
54 |     filename = os.path.join(args.save, f"head_{dataset}.pt")
55 |     if os.path.exists(filename):
56 |         print(f"Classification head for {args.model} on {dataset} exists at {filename}")
57 |         return ClassificationHead.load(filename)
58 |     print(
59 |         f"Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch."  # noqa: E501
60 |     )
61 |     model = ImageEncoder(args, keep_lang=True).model
62 |     template = get_templates(dataset)
63 |     classification_head = build_classification_head(
64 |         model, dataset, template, args.data_location, args.device
65 |     )
66 |     os.makedirs(args.save, exist_ok=True)
67 |     classification_head.save(filename)
68 |     return classification_head
69 | 
--------------------------------------------------------------------------------
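`build_classification_head` follows the standard CLIP zero-shot recipe: embed every prompt template for a class, normalize each embedding, average them, and re-normalize the mean; the resulting vector becomes that class's row in the linear head. A minimal sketch with dummy tensors (the 512-dimensional size and the random embeddings are placeholders for real text features):

```python
import torch

# Placeholder for model.encode_text() on 3 tokenized prompt templates.
embeddings = torch.randn(3, 512)
embeddings /= embeddings.norm(dim=-1, keepdim=True)  # normalize each prompt
class_weight = embeddings.mean(dim=0, keepdim=True)  # average over templates
class_weight /= class_weight.norm()                  # re-normalize the mean
# Stacking one such vector per class yields the weight matrix of the head.
```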
/src/linearize.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from functorch import jvp, make_functional_with_buffers
7 | 
8 | from src.modeling import ImageEncoder
9 | from src.utils import DotDict
10 | 
11 | 
12 | class LinearizedModel(nn.Module):
13 |     """Creates a linearized version of a nn.Module.
14 | 
15 |     The linearized version of a model is a proper PyTorch model and can be
16 |     trained as any other nn.Module.
17 | 
18 |     Args:
19 |         model (nn.Module): The model to linearize. The trainable parameters of
20 |             the linearized model will be initialized to the parameters of this
21 |             model.
22 |         init_model (nn.Module): A model of the same type as `model` containing
23 |             the parameters around which the model is initialized. If not
24 |             provided, `model` is used as the initialization model.
25 |     """
26 | 
27 |     def __init__(self, model: nn.Module, init_model: nn.Module = None) -> None:
28 |         """Initializes the linearized model."""
29 |         super().__init__()
30 |         if init_model is None:
31 |             init_model = model
32 | 
33 |         func0, params0, self.buffers0 = make_functional_with_buffers(
34 |             init_model.eval(), disable_autograd_tracking=True
35 |         )
36 |         self.func0 = lambda params, x: func0(params, self.buffers0, x)
37 | 
38 |         _, params, _ = make_functional_with_buffers(
39 |             model, disable_autograd_tracking=True
40 |         )
41 | 
42 |         self.params = nn.ParameterList(params)
43 |         self.params0 = nn.ParameterList(params0)
44 |         self._model_name = model.__class__.__name__
45 | 
46 |         # The initial parameters are not trainable.
47 |         for p in self.params0:
48 |             p.requires_grad = False
49 | 
50 |         # The updated parameters are.
51 |         for p in self.params:
52 |             p.requires_grad = True
53 | 
54 |     def __call__(self, x) -> torch.Tensor:
55 |         """Computes the linearized model output using a first-order Taylor decomposition."""
56 |         dparams = [p - p0 for p, p0 in zip(self.params, self.params0)]
57 |         out, dp = jvp(
58 |             lambda param: self.func0(param, x),
59 |             (tuple(self.params0),),
60 |             (tuple(dparams),),
61 |         )
62 |         return out + dp
63 | 
64 | 
65 | class LinearizedImageEncoder(abc.ABC, nn.Module):
66 |     """Creates a linearized version of an image encoder."""
67 | 
68 |     def __init__(
69 |         self, args=None, keep_lang=False, image_encoder=None, init_encoder=None
70 |     ):
71 |         super().__init__()
72 |         if image_encoder is None:
73 |             image_encoder = ImageEncoder(args, keep_lang)
74 |         if init_encoder is None:
75 |             init_encoder = image_encoder
76 | 
77 |         # Copy the attributes from the image encoder.
78 |         self.train_preprocess = image_encoder.train_preprocess
79 |         self.val_preprocess = image_encoder.val_preprocess
80 |         self.cache_dir = image_encoder.cache_dir
81 | 
82 |         self._model_name = self._get_name(args.model)
83 |         self.model = LinearizedModel(init_model=init_encoder, model=image_encoder)
84 | 
85 |     def _get_name(self, model_name):
86 |         if "__pretrained__" in model_name:
87 |             model_name, _ = model_name.split("__pretrained__")
88 |         return model_name
89 | 
90 |     def forward(self, x):
91 |         # Use the linearized (first-order Taylor) version of the model.
92 |         return self.model(x)
93 | 
94 |     def __call__(self, x):
95 |         return self.forward(x)
96 | 
97 |     def save(self, filename):
98 |         """Saves the linearized image encoder.
99 | 
100 |         We save the model name in the state dict so that we can load the
101 |         correct model when loading the linearized image encoder. Directly using
102 |         torch.save would not work because func0 is not serializable.
103 | 
104 |         Args:
105 |             filename (str): The path to save the taylorized image encoder.
106 |         """
107 |         if os.path.dirname(filename) != "":
108 |             os.makedirs(os.path.dirname(filename), exist_ok=True)
109 | 
110 |         state_dict = self.state_dict()
111 |         state_dict["model_name"] = self._model_name
112 | 
113 |         torch.save(state_dict, filename)
114 | 
115 |     @classmethod
116 |     def load(cls, filename):
117 |         """Loads a linearized image encoder.
118 | 
119 |         It first loads the state dict with the model name and then creates the
120 |         correct model and loads the state dict.
121 | 
122 |         Args:
123 |             filename (str): The path to the taylorized image encoder.
124 | 
125 |         Returns:
126 |             LinearizedImageEncoder: The loaded taylorized image encoder.
127 |         """
128 |         print(f"Loading image encoder from {filename}")
129 |         state_dict = torch.load(filename, map_location="cpu")
130 | 
131 |         # ImageEncoder expects a DotDict
132 |         args = DotDict({"model": state_dict["model_name"]})
133 |         taylorized_encoder = cls(args)
134 | 
135 |         # Remove the model name from the state dict so that we can load the
136 |         # model.
137 |         state_dict.pop("model_name")
138 |         taylorized_encoder.load_state_dict(state_dict)
139 |         return taylorized_encoder
140 | 
--------------------------------------------------------------------------------
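`LinearizedModel.__call__` evaluates the first-order Taylor expansion f(x; θ0) + ∇θ f(x; θ0)·(θ − θ0) in a single forward pass through a Jacobian-vector product. A self-contained sketch of the same identity on a toy module, using the same `functorch` API as the file above (the toy `nn.Linear` is for illustration only; since a linear layer is affine in its parameters, the expansion is exact there):

```python
import torch
from functorch import jvp, make_functional

model = torch.nn.Linear(4, 2)
func, params = make_functional(model)
params0 = tuple(params)

x = torch.randn(1, 4)
dparams = tuple(0.01 * torch.randn_like(p) for p in params0)  # theta - theta0

# One pass: returns f(x; theta0) and the directional derivative along dparams.
out, dp = jvp(lambda p: func(p, x), (params0,), (dparams,))
linearized = out + dp

# nn.Linear is affine in its parameters, so the first-order expansion
# coincides with evaluating the layer at theta0 + dtheta directly.
params1 = tuple(p + d for p, d in zip(params0, dparams))
assert torch.allclose(linearized, func(params1, x), atol=1e-5)
```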
/src/modeling.py:
--------------------------------------------------------------------------------
1 | import open_clip
2 | import torch
3 | 
4 | from src import utils
5 | 
6 | 
7 | class ImageEncoder(torch.nn.Module):
8 |     def __init__(self, args, keep_lang=False):
9 |         super().__init__()
10 | 
11 |         print(f"Loading {args.model} pre-trained weights.")
12 |         if "__pretrained__" in args.model:
13 |             name, pretrained = args.model.split("__pretrained__")
14 |         elif "__init__" in args.model:
15 |             print("Using random initialization.")
16 |             name, pretrained = args.model.split("__init__")[0], None
17 |         else:
18 |             name = args.model
19 |             pretrained = "openai"
20 |         (
21 |             self.model,
22 |             self.train_preprocess,
23 |             self.val_preprocess,
24 |         ) = open_clip.create_model_and_transforms(
25 |             name, pretrained=pretrained, cache_dir=args.openclip_cachedir
26 |         )
27 | 
28 |         self.cache_dir = args.cache_dir
29 | 
30 |         if not keep_lang and hasattr(self.model, "transformer"):
31 |             delattr(self.model, "transformer")
32 | 
33 |     def forward(self, images):
34 |         assert self.model is not None
35 |         return self.model.encode_image(images)
36 | 
37 |     def __call__(self, inputs):
38 |         return self.forward(inputs)
39 | 
40 |     def save(self, filename):
41 |         print(f"Saving image encoder to {filename}")
42 |         utils.torch_save(self, filename)
43 | 
44 |     @classmethod
45 |     def load(cls, model_name, filename):
46 |         print(f"Loading image encoder from {filename}")
47 |         state_dict = torch.load(filename, map_location="cpu")
48 |         # Delegate to load_from_state_dict; calling cls.load here would recurse forever.
49 |         return cls.load_from_state_dict(model_name, state_dict)
50 | 
51 |     @classmethod
52 |     def load_from_state_dict(cls, model_name, state_dict):
53 |         # Rebuild the architecture from the model name, then restore the weights.
54 |         (
55 |             model,
56 |             _,
57 |             _,
58 |         ) = open_clip.create_model_and_transforms(model_name)
59 |         model.load_state_dict(state_dict)
60 |         return model
61 | 
62 | 
63 | class ClassificationHead(torch.nn.Linear):
64 |     def __init__(self, normalize, weights, biases=None):
65 |         output_size, input_size = weights.shape
66 |         super().__init__(input_size, output_size)
67 |         self.normalize = normalize
68 |         if weights is not None:
69 |             self.weight = torch.nn.Parameter(weights.clone())
70 |         if biases is not None:
71 |             self.bias = torch.nn.Parameter(biases.clone())
72 |         else:
73 |             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
74 | 
75 |     def forward(self, inputs):
76 |         if self.normalize:
77 |             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
78 |         return super().forward(inputs)
79 | 
80 |     def __call__(self, inputs):
81 |         return self.forward(inputs)
82 | 
83 |     def save(self, filename):
84 |         print(f"Saving classification head to {filename}")
85 |         utils.torch_save(self, filename)
86 | 
87 |     @classmethod
88 |     def load(cls, filename):
89 |         print(f"Loading classification head from {filename}")
90 |         return utils.torch_load(filename)
91 | 
92 | 
93 | class ImageClassifier(torch.nn.Module):
94 |     def __init__(self, image_encoder, classification_head):
95 |         super().__init__()
96 |         self.image_encoder = image_encoder
97 |         self.classification_head = classification_head
98 |         if self.image_encoder is not None:
99 |             self.train_preprocess = self.image_encoder.train_preprocess
100 |             self.val_preprocess = self.image_encoder.val_preprocess
101 | 
102 |     def freeze_head(self):
103 |         self.classification_head.weight.requires_grad_(False)
104 |         self.classification_head.bias.requires_grad_(False)
105 | 
106 |     def forward(self, inputs):
107 |         features = self.image_encoder(inputs)
108 |         outputs = self.classification_head(features)
109 |         return outputs
110 | 
111 |     def __call__(self, inputs):
112 |         return self.forward(inputs)
113 | 
114 |     def save(self, filename):
115 |         print(f"Saving image classifier to {filename}")
116 |         utils.torch_save(self, filename)
117 | 
118 |     @classmethod
119 |     def load(cls, filename):
120 |         print(f"Loading image classifier from {filename}")
121 |         return utils.torch_load(filename)
122 | 
123 | 
124 | class MultiHeadImageClassifier(torch.nn.Module):
125 |     def __init__(self, image_encoder, classification_heads):
126 |         super().__init__()
127 |         self.image_encoder = image_encoder
128 |         self.classification_heads = torch.nn.ModuleList(classification_heads)
129 |         if self.image_encoder is not None:
130 |             self.train_preprocess = self.image_encoder.train_preprocess
131 |             self.val_preprocess = self.image_encoder.val_preprocess
132 | 
133 |     def freeze_head(self):
134 |         for idx in range(len(self.classification_heads)):
135 |             self.classification_heads[idx].weight.requires_grad_(False)
136 |             self.classification_heads[idx].bias.requires_grad_(False)
137 | 
138 |     def forward(self, inputs, head_idx):
139 |         features = self.image_encoder(inputs)
140 |         outputs = self.classification_heads[head_idx](features)
141 |         return outputs
142 | 
143 |     def __call__(self, inputs, head_idx):
144 |         return self.forward(inputs, head_idx)
145 | 
146 |     def save(self, filename):
147 |         print(f"Saving image classifier to {filename}")
148 |         utils.torch_save(self, filename)
149 | 
150 |     @classmethod
151 |     def load(cls, filename):
152 |         print(f"Loading image classifier from {filename}")
153 |         return utils.torch_load(filename)
154 | 
--------------------------------------------------------------------------------
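`ImageEncoder.__init__` encodes checkpoint provenance directly in the model name: `__pretrained__` separates the architecture from an `open_clip` pretraining tag, `__init__` requests a random initialization, and a bare name falls back to the OpenAI weights. A small sketch of that naming convention (the example tags are illustrative):

```python
def parse_model_name(model_name: str):
    """Mirrors the name convention used by ImageEncoder."""
    if "__pretrained__" in model_name:
        name, pretrained = model_name.split("__pretrained__")
    elif "__init__" in model_name:
        name, pretrained = model_name.split("__init__")[0], None
    else:
        name, pretrained = model_name, "openai"
    return name, pretrained

assert parse_model_name("ViT-B-32") == ("ViT-B-32", "openai")
assert parse_model_name("ViT-B-32__pretrained__laion400m_e32") == ("ViT-B-32", "laion400m_e32")
assert parse_model_name("ViT-B-32__init__") == ("ViT-B-32", None)
```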
/src/task_vectors.py:
--------------------------------------------------------------------------------
1 | import abc
2 | 
3 | import torch
4 | 
5 | from src.linearize import LinearizedImageEncoder
6 | 
7 | 
8 | class _TaskVector(abc.ABC):
9 |     def __init__(
10 |         self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None
11 |     ):
12 |         """Initializes the task vector from a pretrained and a finetuned checkpoint.
13 | 
14 |         This can either be done by passing two state dicts (one corresponding to the
15 |         pretrained model, and another to the finetuned model), or by directly passing in
16 |         the task vector state dict.
17 | """ 18 | if vector is not None: 19 | self.vector = vector 20 | else: 21 | assert ( 22 | pretrained_checkpoint is not None and finetuned_checkpoint is not None 23 | ) 24 | with torch.no_grad(): 25 | pretrained_state_dict = self._load_checkpoint( 26 | pretrained_checkpoint 27 | ).state_dict() 28 | finetuned_state_dict = self._load_checkpoint( 29 | finetuned_checkpoint 30 | ).state_dict() 31 | self.vector = {} 32 | for key in pretrained_state_dict: 33 | if pretrained_state_dict[key].dtype == torch.int64: 34 | continue 35 | if pretrained_state_dict[key].dtype == torch.uint8: 36 | continue 37 | self.vector[key] = ( 38 | finetuned_state_dict[key] - pretrained_state_dict[key] 39 | ) 40 | 41 | @abc.abstractmethod 42 | def _load_checkpoint(self, checkpoint): 43 | """Load a checkpoint into a model.""" 44 | raise NotImplementedError 45 | 46 | @abc.abstractmethod 47 | def _cast_to_same_type(self, other): 48 | raise NotImplementedError 49 | 50 | def __add__(self, other): 51 | """Add two task vectors together.""" 52 | other = self._cast_to_same_type(other) 53 | with torch.no_grad(): 54 | new_vector = {} 55 | for key in self.vector: 56 | if key not in other.vector: 57 | print(f"Warning, key {key} is not present in both task vectors.") 58 | continue 59 | new_vector[key] = self.vector[key] + other.vector[key] 60 | return self.__class__(vector=new_vector) 61 | 62 | def __sub__(self, other): 63 | """Subtract two task vectors.""" 64 | return self.__add__(-other) 65 | 66 | def __radd__(self, other): 67 | if other is None or isinstance(other, int): 68 | return self 69 | return self.__add__(other) 70 | 71 | def __neg__(self): 72 | """Negate a task vector.""" 73 | with torch.no_grad(): 74 | new_vector = {} 75 | for key in self.vector: 76 | new_vector[key] = -self.vector[key] 77 | return self.__class__(vector=new_vector) 78 | 79 | def __pow__(self, power): 80 | """Power of a task vector.""" 81 | with torch.no_grad(): 82 | new_vector = {} 83 | for key in self.vector: 84 | new_vector[key] = self.vector[key] ** power 85 | return self.__class__(vector=new_vector) 86 | 87 | def __mul__(self, other): 88 | """Multiply a task vector by a scalar.""" 89 | with torch.no_grad(): 90 | new_vector = {} 91 | for key in self.vector: 92 | new_vector[key] = other * self.vector[key] 93 | return self.__class__(vector=new_vector) 94 | 95 | def dot(self, other): 96 | """Dot product of two task vectors.""" 97 | other = self._cast_to_same_type(other) 98 | with torch.no_grad(): 99 | dot_product = 0.0 100 | for key in self.vector: 101 | if key not in other.vector: 102 | print(f"Warning, key {key} is not present in both task vectors.") 103 | continue 104 | dot_product += torch.sum(self.vector[key] * other.vector[key]) 105 | return dot_product 106 | 107 | def norm(self): 108 | """Norm of a task vector.""" 109 | return torch.sqrt(self.dot(self)) 110 | 111 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): 112 | """Apply a task vector to a pretrained model.""" 113 | with torch.no_grad(): 114 | pretrained_model = self._load_checkpoint(pretrained_checkpoint) 115 | new_state_dict = {} 116 | pretrained_state_dict = pretrained_model.state_dict() 117 | for key in pretrained_state_dict: 118 | if key not in self.vector: 119 | print( 120 | f"Warning: key {key} is present in the pretrained state dict but not in the task vector" # noqa: E501 121 | ) 122 | continue 123 | new_state_dict[key] = ( 124 | pretrained_state_dict[key] + scaling_coef * self.vector[key] 125 | ) 126 | pretrained_model.load_state_dict(new_state_dict) 127 | return 
pretrained_model
128 | 
129 | 
130 | class NonLinearTaskVector(_TaskVector):
131 |     """A task vector for nonlinear models."""
132 | 
133 |     def _load_checkpoint(self, checkpoint):
134 |         """Load a checkpoint into a model."""
135 |         return torch.load(checkpoint, map_location="cpu")
136 | 
137 |     def apply_to_nonlinear(self, pretrained_nonlinear_checkpoint, scaling_coef=1.0):
138 |         """Apply a task vector to a nonlinear pretrained model."""
139 |         return self.apply_to(pretrained_nonlinear_checkpoint, scaling_coef)
140 | 
141 |     def apply_to_linear(self, pretrained_linear_checkpoint, scaling_coef=1.0):
142 |         """Apply a task vector to a linear pretrained model."""
143 |         return nonlinear_to_linear(self).apply_to(
144 |             pretrained_linear_checkpoint, scaling_coef
145 |         )
146 | 
147 |     def _cast_to_same_type(self, other):
148 |         return linear_to_nonlinear(other, self.vector.keys())
149 | 
150 | 
151 | class LinearizedTaskVector(_TaskVector):
152 |     """A task vector for linearized models."""
153 | 
154 |     def _load_checkpoint(self, checkpoint):
155 |         """Load a checkpoint into a model."""
156 |         return LinearizedImageEncoder.load(checkpoint)
157 | 
158 |     def apply_to_nonlinear(
159 |         self, pretrained_nonlinear_checkpoint, param_names, scaling_coef=1.0
160 |     ):
161 |         """Apply a task vector to a nonlinear pretrained model."""
162 |         return linear_to_nonlinear(self, param_names).apply_to(
163 |             pretrained_nonlinear_checkpoint, scaling_coef
164 |         )
165 | 
166 |     def apply_to_linear(self, pretrained_linear_checkpoint, scaling_coef=1.0):
167 |         """Apply a task vector to a linear pretrained model."""
168 |         return self.apply_to(pretrained_linear_checkpoint, scaling_coef)
169 | 
170 |     def get_named_parameters(self, param_names):
171 |         """Get the named parameters of the task vector."""
172 |         params = {k: v for k, v in self.vector.items() if "model.params0" not in k}
173 |         return {k: v for k, v in zip(param_names, params.values())}
174 | 
175 |     def _cast_to_same_type(self, other):
176 |         return nonlinear_to_linear(other)
177 | 
178 | 
179 | def nonlinear_to_linear(nonlinear_task_vector):
180 |     """Convert a nonlinear task vector to a linear task vector."""
181 |     if isinstance(nonlinear_task_vector, LinearizedTaskVector):
182 |         return nonlinear_task_vector
183 |     else:
184 |         linear_params = {
185 |             f"model.params.{i}": v
186 |             for i, v in enumerate(nonlinear_task_vector.vector.values())
187 |         }
188 |         # The diffs of the init params of the linearized models are all zero.
189 | linear_params |= { 190 | f"model.params0.{i}": torch.zeros_like(v) 191 | for i, v in enumerate(nonlinear_task_vector.vector.values()) 192 | } 193 | return LinearizedTaskVector(vector=linear_params) 194 | 195 | 196 | def linear_to_nonlinear(linear_task_vector, param_names): 197 | """Convert a linear task vector to a nonlinear task vector.""" 198 | if isinstance(linear_task_vector, NonLinearTaskVector): 199 | return linear_task_vector 200 | else: 201 | return NonLinearTaskVector( 202 | vector=linear_task_vector.get_named_parameters(param_names) 203 | ) 204 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def assign_learning_rate(param_group, new_lr): 9 | param_group["lr"] = new_lr 10 | 11 | 12 | def _warmup_lr(base_lr, warmup_length, step): 13 | return base_lr * (step + 1) / warmup_length 14 | 15 | 16 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 17 | if not isinstance(base_lrs, list): 18 | base_lrs = [base_lrs for _ in optimizer.param_groups] 19 | assert len(base_lrs) == len(optimizer.param_groups) 20 | 21 | def _lr_adjuster(step): 22 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs): 23 | if step < warmup_length: 24 | lr = _warmup_lr(base_lr, warmup_length, step) 25 | else: 26 | e = step - warmup_length 27 | es = steps - warmup_length 28 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 29 | assign_learning_rate(param_group, lr) 30 | 31 | return _lr_adjuster 32 | 33 | 34 | def accuracy(output, target, topk=(1,)): 35 | pred = output.topk(max(topk), 1, True, True)[1].t() 36 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 37 | return [ 38 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 39 | for k in topk 40 | ] 41 | 42 | 43 | def torch_load_old(save_path, device=None): 44 | with open(save_path, "rb") as f: 45 | classifier = pickle.load(f) 46 | if device is not None: 47 | classifier = classifier.to(device) 48 | return classifier 49 | 50 | 51 | def torch_save(model, save_path): 52 | if os.path.dirname(save_path) != "": 53 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 54 | torch.save(model, save_path) 55 | 56 | 57 | def torch_load(save_path, device=None): 58 | model = torch.load(save_path, map_location="cpu") 59 | if device is not None: 60 | model = model.to(device) 61 | return model 62 | 63 | 64 | def get_logits(inputs, classifier): 65 | assert callable(classifier) 66 | if hasattr(classifier, "to"): 67 | classifier = classifier.to(inputs.device) 68 | return classifier(inputs) 69 | 70 | 71 | def get_probs(inputs, classifier): 72 | if hasattr(classifier, "predict_proba"): 73 | probs = classifier.predict_proba(inputs.detach().cpu().numpy()) 74 | return torch.from_numpy(probs) 75 | logits = get_logits(inputs, classifier) 76 | return logits.softmax(dim=1) 77 | 78 | 79 | class LabelSmoothing(torch.nn.Module): 80 | def __init__(self, smoothing=0.0): 81 | super(LabelSmoothing, self).__init__() 82 | self.confidence = 1.0 - smoothing 83 | self.smoothing = smoothing 84 | 85 | def forward(self, x, target): 86 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 87 | 88 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 89 | nll_loss = nll_loss.squeeze(1) 90 | smooth_loss = -logprobs.mean(dim=-1) 91 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 92 | return loss.mean() 93 
|
94 | 
95 | class DotDict(dict):
96 |     """dot.notation access to dictionary attributes"""
97 | 
98 |     __getattr__ = dict.get
99 |     __setattr__ = dict.__setitem__
100 |     __delattr__ = dict.__delitem__
101 | 
102 | 
103 | def find_optimal_coef(
104 |     results,
105 |     metric="avg_normalized_top1",
106 |     minimize=False,
107 |     control_metric=None,
108 |     control_metric_threshold=0.0,
109 | ):
110 |     best_coef = None
111 |     if minimize:
112 |         best_metric = float("inf")
113 |     else:
114 |         best_metric = -float("inf")
115 |     for scaling_coef in results.keys():
116 |         if control_metric is not None:
117 |             if results[scaling_coef][control_metric] < control_metric_threshold:
118 |                 print(f"Control metric fell below {control_metric_threshold} threshold")
119 |                 continue
120 |         if minimize:
121 |             if results[scaling_coef][metric] < best_metric:
122 |                 best_metric = results[scaling_coef][metric]
123 |                 best_coef = scaling_coef
124 |         else:
125 |             if results[scaling_coef][metric] > best_metric:
126 |                 best_metric = results[scaling_coef][metric]
127 |                 best_coef = scaling_coef
128 |     return best_coef
129 | 
130 | 
131 | def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
132 |     """Computes the normalized non-linear advantage of a finetuned model.
133 | 
134 |     The nonlinear_advantage is defined as:
135 |     (error_rate(linear_model) - error_rate(nonlinear_model)) / (1 - 1 / num_classes)
136 |     and takes values in [-1, 1]. A value of 0 indicates that the nonlinear
137 |     model is no better than the linear one. Meanwhile, a value of 1 indicates
138 |     that the nonlinear model is perfect and the linear one trivial, and a value
139 |     of -1 indicates the opposite.
140 |     """
141 |     return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
142 | 
--------------------------------------------------------------------------------
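Putting the pieces together, task arithmetic with the classes above reduces to a few lines. A usage sketch, assuming checkpoints produced by `src/finetune.py` exist at the hypothetical paths below (the 0.5 scaling coefficient is illustrative; in practice it is tuned with `find_optimal_coef`):

```python
from src.task_vectors import NonLinearTaskVector

# Hypothetical checkpoint paths written by src/finetune.py.
pretrained = "checkpoints/ViT-B-32/MNISTVal/zeroshot.pt"
mnist = NonLinearTaskVector(pretrained, "checkpoints/ViT-B-32/MNISTVal/finetuned.pt")
gtsrb = NonLinearTaskVector(pretrained, "checkpoints/ViT-B-32/GTSRBVal/finetuned.pt")

# Task addition: edit the pretrained model with the sum of two task vectors.
multitask_encoder = (mnist + gtsrb).apply_to(pretrained, scaling_coef=0.5)

# Task negation: subtract a task vector to forget the corresponding task.
forgetting_encoder = (-mnist).apply_to(pretrained, scaling_coef=0.5)
```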