├── LICENSE ├── README.md ├── task_arithemetic ├── .gitignore ├── README.md ├── environment.yml ├── scripts │ ├── cross-task-linearity.py │ ├── eval.py │ ├── model_addition_stitching.py │ ├── model_ensemble_stitching.py │ └── model_negation_stitching.py └── src │ ├── args.py │ ├── datasets │ ├── cars.py │ ├── cifar10.py │ ├── cifar100.py │ ├── common.py │ ├── dtd.py │ ├── eurosat.py │ ├── gtsrb.py │ ├── imagenet.py │ ├── mnist.py │ ├── registry.py │ ├── resisc45.py │ ├── stl10.py │ ├── sun397.py │ ├── svhn.py │ └── templates.py │ ├── eval.py │ ├── finetune.py │ ├── heads.py │ ├── modeling.py │ ├── task_vectors.py │ └── utils │ ├── .DS_Store │ ├── __init__.py │ ├── avgmeter.py │ ├── dissimilarity.py │ ├── featuremap.py │ ├── load.py │ ├── plot.py │ ├── tools.py │ └── utils.py └── task_arithemetic_t5 ├── README.md ├── avgmeter.py ├── cross-task-linearity.py ├── data_utils.py ├── dissimilarity.py ├── featuremap.py ├── task_vectors.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhanpeng Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross-Task-Linearity 2 | Core release for "[On the Emergence of Cross-Task Linearity in Pretraining-Finetuning Paradigm](https://arxiv.org/abs/2402.03660)" (accepted by ICML 2024). 3 | 4 | ## Abstract 5 | The pretraining-finetuning paradigm has become the prevailing trend in modern deep learning. In this work, we discover an intriguing linear phenomenon in models that are initialized from a common pretrained checkpoint and finetuned on different tasks, termed as *Cross-Task Linearity (CTL)*. Specifically, we show that if we linearly interpolate the weights of two finetuned models, the features in the weight-interpolated model are often approximately equal to the linear interpolation of features in two finetuned models at each layer. We provide comprehensive empirical evidence supporting that CTL consistently occurs for finetuned models that start from the same pretrained checkpoint. We conjecture that in the pretraining-finetuning paradigm, neural networks approximately function as linear maps, mapping from the parameter space to the feature space. 
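In symbols (our paraphrase of the definition above, with $f^{(\ell)}(\theta; x)$ denoting the layer-$\ell$ features of the network with weights $\theta$ on input $x$), CTL states that for two models $\theta_A$ and $\theta_B$ finetuned from the same pretrained checkpoint, and any $\alpha \in [0, 1]$:

$$
f^{(\ell)}\big(\alpha\,\theta_A + (1-\alpha)\,\theta_B;\, x\big) \;\approx\; \alpha\, f^{(\ell)}(\theta_A; x) + (1-\alpha)\, f^{(\ell)}(\theta_B; x).
$$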
Based on this viewpoint, our study unveils novel insights into explaining model merging/editing, particularly by translating operations from the parameter space to the feature space. Furthermore, we delve deeper into the underlying factors for the emergence of CTL, highlighting the role of pretraining. 6 | 7 | ## Code 8 | `task_arithemetic` contains the code for *section 4.3 Insights into Task Arithmetic*, including cross-task-linearity on addition and negation, model stitching on addition and negation for ViT models. `task_arithemetic_t5` contains the code for T5 model in NLP domain. 9 | -------------------------------------------------------------------------------- /task_arithemetic/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | out/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 | outs/
134 |
135 | *.sh
136 |
137 | scripts/img
138 | scripts/plot_hyperbola_imagenet.py
139 | scripts/plot_hyperbola_target.py
140 | visualize/
141 | *.tar.gz
--------------------------------------------------------------------------------
/task_arithemetic/README.md:
--------------------------------------------------------------------------------
1 | # Editing Models with Task Arithmetic
2 |
3 | **We adapt the original source code and add our Cross-Task Linearity implementation in the `scripts` directory.**
4 | `scripts/cross-task-linearity.py` is the main script for Cross-Task Linearity evaluation.
5 | The commands below run the Cross-Task Linearity experiments:
6 | ```bash
7 | # addition
8 | python cross-task-linearity.py \
9 |     --save_root ../outs/vit-b-32/ctl_addition/MNIST \
10 |     --model_root $model_path \
11 |     --data_root $data_root \
12 |     --dataset MNIST \
13 |     --batch_size 64 \
14 |     --sample_num 1000 \
15 |     --modelA "0.8+MNIST" \
16 |     --modelB "0.8+Cars"
17 | # negation (an underscore in the model spec, e.g. "0.4_Cars", subtracts that task vector)
18 | python cross-task-linearity.py \
19 |     --save_root ../outs/vit-b-32/ctl_negation/MNIST \
20 |     --model_root $model_path \
21 |     --data_root $data_root \
22 |     --dataset MNIST \
23 |     --batch_size 64 \
24 |     --sample_num 5000 \
25 |     --modelA "0.4+MNIST" \
26 |     --modelB "0.4_Cars"
27 |
28 | # addition_stitching
29 | python model_addition_stitching.py \
30 |     --save_root ../outs/vit-b-32/model_addition_stitching/MNIST \
31 |     --model_root $model_path \
32 |     --data_root $data_root \
33 |     --dataset MNIST \
34 |     --sample_num 1000 \
35 |     --modelA "0.8+MNIST"
36 |
37 | # negation_stitching
38 | python model_negation_stitching.py \
39 |     --save_root ../outs/vit-b-32/model_negation_stitching/MNIST \
40 |     --model_root $model_path \
41 |     --data_root $data_root \
42 |     --task MNIST \
43 |     --sample_num 1000 \
44 |     --dataset MNIST
45 |
46 | # ensemble_stitching for addition
47 | python model_ensemble_stitching.py \
48 |     --save_root ../outs/vit-b-32/ensemble_stitching/addition/MNIST+Cars \
49 |     --model_root $model_path \
50 |     --data_root $data_root \
51 |     --dataset MNIST \
52 |     --batch_size 64 \
53 |     --sample_num 1000 \
54 |     --modelA "0.8+MNIST" \
55 |     --modelB "0.8+Cars"
56 | # ensemble_stitching for negation
57 | python model_ensemble_stitching.py \
58 |     --save_root ../outs/vit-b-32/model_ensemble_stitching/negation/MNIST \
59 |     --model_root $model_path \
60 |     --data_root $data_root \
61 |     --dataset MNIST \
62 |     --sample_num 5000 \
63 |     --modelA "0.8_MNIST" \
64 |     --modelB "0.8+MNIST"
65 | ```
66 |
67 | This repository contains the original code for the ICLR 2023 paper [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089).
68 |
69 |
70 | ### Abstract
71 | *Changing how pre-trained models behave---e.g., improving their performance on a downstream task or mitigating biases learned during pre-training---is a common practice when developing machine learning systems.
In this work, we propose a new paradigm for steering the behavior of neural networks, centered around task vectors. A task vector specifies a direction in the weight space of a pre-trained model, such that movement in that direction improves performance on the task. We build task vectors by subtracting the weights of a pre-trained model from the weights of the same model after fine-tuning on a task. We show that these task vectors can be modified and combined together through arithmetic operations such as negation and addition, and the behavior of the resulting model is steered accordingly. Negating a task vector decreases performance on the target task, with little change in model behavior on control tasks. Moreover, adding task vectors together can improve performance on multiple tasks at once. Finally, when tasks are linked by an analogy relationship of the form "A is to B as C is to D", combining task vectors from three of the tasks can improve performance on the fourth, even when no data from the fourth task is used for training. Overall, our experiments with several models, modalities and tasks show that task arithmetic is a simple, efficient and effective way of editing models.*
72 |
73 |
74 |
75 | ## Code
76 |
77 | ### Install dependencies
78 |
79 | ```bash
80 | conda env create
81 | conda activate task-vectors
82 | ```
83 |
84 |
85 | ### Add directory to PYTHONPATH:
86 |
87 | ```bash
88 | cd task_arithemetic
89 | export PYTHONPATH="$PYTHONPATH:$PWD"
90 | ```
91 |
92 | ### Using task vectors
93 |
94 | The task vector logic can be found at [src/task_vectors.py](src/task_vectors.py).
95 |
96 | To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:
97 |
98 | ```python
99 | from task_vectors import TaskVector
100 | task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
101 | ```
102 |
103 | Once created, task vectors can be modified and combined through arithmetic operations! For instance, to negate a task vector, simply use the ```-``` operator:
104 |
105 | ```python
106 | # Negating a task vector
107 | new_task_vector = -task_vector
108 | ```
109 |
110 | To add task vectors, you can use the ```+``` operator, or ```sum```:
111 |
112 | ```python
113 | # Adding two task vectors
114 | new_task_vector = task_vector_A + task_vector_B
115 | # Adding multiple task vectors
116 | new_task_vector = sum(list_of_task_vectors)
117 | ```
118 |
119 | Analogies can be done as simply as:
120 |
121 | ```python
122 | # Task analogies
123 | new_task_vector = task_vector_C + task_vector_B - task_vector_A
124 | ```
125 |
126 | ### Checkpoints
127 |
128 | Checkpoints for CLIP ViT-B/32, ViT-B/16 and ViT-L/14 are available at the link below, including fine-tuned checkpoints on eight downstream tasks: Stanford Cars, DTD, EuroSAT, GTSRB, MNIST, RESISC45, SUN397 and SVHN.
129 | 130 | [Download here](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw?usp=share_link) 131 | 132 | ### Examples 133 | 134 | Below is an example of negating a task vector from MNIST, then evaluating on MNIST and on ImageNet: 135 | 136 | ```python 137 | import torch 138 | from task_vectors import TaskVector 139 | from eval import eval_single_dataset 140 | from args import parse_arguments 141 | 142 | # Config 143 | dataset = 'MNIST' 144 | model = 'ViT-L-14' 145 | args = parse_arguments() 146 | args.data_location = '/path/to/data' 147 | args.model = model 148 | args.save = f'checkpoints/{model}' 149 | pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt' 150 | finetuned_checkpoint = f'checkpoints/{model}/{dataset}/finetuned.pt' 151 | 152 | 153 | # Create the task vector 154 | task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint) 155 | # Negate the task vector 156 | neg_task_vector = -task_vector 157 | # Apply the task vector 158 | image_encoder = neg_task_vector.apply_to(pretrained_checkpoint, scaling_coef=0.5) 159 | # Evaluate 160 | eval_single_dataset(image_encoder, dataset, args) 161 | eval_single_dataset(image_encoder, 'ImageNet', args) 162 | ``` 163 | 164 | You can also find an example of adding task vectors together below, using the MNIST and RESISC45 datasets: 165 | 166 | 167 | ```python 168 | import torch 169 | from task_vectors import TaskVector 170 | from eval import eval_single_dataset 171 | from args import parse_arguments 172 | 173 | # Config 174 | datasets = ['MNIST', 'RESISC45'] 175 | model = 'ViT-L-14' 176 | args = parse_arguments() 177 | args.data_location = '/path/to/data' 178 | args.model = model 179 | args.save = f'checkpoints/{model}' 180 | pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt' 181 | 182 | # Create the task vectors 183 | task_vectors = [ 184 | TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt') 185 | for dataset in datasets 186 | ] 187 | # Sum the task vectors 188 | task_vector_sum = sum(task_vectors) 189 | # Apply the resulting task vector 190 | image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=0.8) 191 | # Evaluate 192 | for dataset in datasets: 193 | eval_single_dataset(image_encoder, dataset, args) 194 | ``` 195 | -------------------------------------------------------------------------------- /task_arithemetic/environment.yml: -------------------------------------------------------------------------------- 1 | name: task-vectors 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.12=h7f8727e_0 15 | - pip=23.3.1=py38h06a4308_0 16 | - python=3.8.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py38h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py38h06a4308_0 22 | - xz=5.4.5=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - certifi==2023.11.17 26 | - charset-normalizer==3.3.2 27 | - clip==1.0 28 | - contourpy==1.1.1 29 | - cycler==0.12.1 30 | - filelock==3.13.1 31 | - fonttools==4.46.0 32 | - ftfy==6.1.3 33 | - huggingface-hub==0.10.0 34 | - idna==3.6 35 | - importlib-resources==6.1.1 36 | - kiwisolver==1.4.5 37 | - matplotlib==3.7.4 38 | - numpy==1.24.4 
39 | - open-clip-torch==2.0.2 40 | - packaging==23.2 41 | - pandas==2.0.3 42 | - pillow==10.1.0 43 | - pyparsing==3.1.1 44 | - python-dateutil==2.8.2 45 | - pytz==2023.3.post1 46 | - pyyaml==6.0.1 47 | - regex==2023.10.3 48 | - requests==2.31.0 49 | - torch==1.12.1+cu116 50 | - torchaudio==0.12.1+cu116 51 | - torchvision==0.13.1+cu116 52 | - tqdm==4.66.1 53 | - typing-extensions==4.9.0 54 | - tzdata==2023.3 55 | - wcwidth==0.2.12 56 | - zipp==3.17.0 57 | prefix: /cpfs01/user/zhangbo/anaconda3/envs/task-vectors 58 | -------------------------------------------------------------------------------- /task_arithemetic/scripts/eval.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 3 | # cwd change to current file's dir 4 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 5 | import argparse 6 | import numpy as np 7 | import torch 8 | 9 | # import internal libs 10 | from src.modeling import ImageClassifier 11 | from src.datasets.registry import get_dataset 12 | from src.utils.load import load_image_encoder, decode_model_name 13 | from src.utils.tools import evaluate 14 | from src.utils import get_datetime, set_logger, get_logger, set_seed, set_device, \ 15 | log_settings, save_current_src 16 | 17 | def add_args() -> argparse.Namespace: 18 | parser = argparse.ArgumentParser( 19 | description="simple verification") 20 | ## the basic setting of exp 21 | parser.add_argument('--device', default=0, type=int, 22 | help="set the device.") 23 | parser.add_argument("--seed", default=0, type=int, 24 | help="set the seed.") 25 | parser.add_argument("--save_root", default="../outs/tmp/", type=str, 26 | help='the path of saving results.') 27 | parser.add_argument("--model_root", default=None, type=str, 28 | help='the path of loading models.') 29 | parser.add_argument("--data_root", default=None, type=str, 30 | help="The root directory for the datasets.",) 31 | parser.add_argument("--model", default=None, type=str, 32 | help="specify the model name",) 33 | parser.add_argument("--dataset", default="ImageNet", type=str, 34 | help='the dataset name.') 35 | parser.add_argument("--batch_size", default=256, type=int, 36 | help="set the batch size."), 37 | parser.add_argument("-v", "--verbose", action="store_true", dest="verbose", 38 | help="enable debug info output.") 39 | args = parser.parse_args() 40 | 41 | if not os.path.exists(args.save_root): 42 | os.makedirs(args.save_root) 43 | 44 | # set the save_path 45 | exp_name = "-".join([get_datetime(), 46 | f"seed{args.seed}", 47 | f"{args.model}", 48 | f"{args.dataset}", 49 | f"bs{args.batch_size}",]) 50 | args.save_path = os.path.join(args.save_root, exp_name) 51 | if not os.path.exists(args.save_path): 52 | os.makedirs(args.save_path) 53 | return args 54 | 55 | 56 | def main(): 57 | # get the args. 58 | args = add_args() 59 | # set the logger 60 | set_logger(args.save_path) 61 | # get the logger 62 | logger = get_logger(__name__, args.verbose) 63 | # set the seed 64 | set_seed(args.seed) 65 | # set the device 66 | args.device = set_device(args.device) 67 | # save the current src 68 | # save_current_src(save_path = args.save_path) 69 | 70 | # show the args. 
71 |     logger.info("#########parameters settings....")
72 |     log_settings(args)
73 |
74 |     # prepare the model
75 |     logger.info("#########prepare the model....")
76 |     # load the image encoder
77 |     image_encoder = torch.load(os.path.join(args.model_root, "zeroshot.pt"))
78 |     # load the classification head
79 |     classification_head = torch.load(os.path.join(args.model_root, f"head_{args.dataset}.pt"))
80 |     # construct the model
81 |     model = ImageClassifier(image_encoder, classification_head)
82 |     logger.info(f"model: {model}")
83 |
84 |     # prepare the dataset
85 |     logger.info("#########prepare the dataset....")
86 |     dataset = get_dataset(
87 |         dataset_name = args.dataset,
88 |         preprocess=model.val_preprocess,
89 |         location=args.data_root,
90 |         batch_size=args.batch_size,
91 |     )
92 |     dataloader = dataset.test_loader
93 |
94 |     # eval the model over the dataset
95 |     logger.info("#########eval the model over the dataset....")
96 |     scaling_coef, task_vectors_info = decode_model_name(args.model)  # e.g. "0.8+MNIST": scaling coef 0.8, MNIST task vector added
97 |
98 |     # get the image encoder to be evaluated
99 |     model.image_encoder = load_image_encoder(
100 |         model_root = args.model_root,
101 |         task_vectors_info = task_vectors_info,
102 |         scaling_coef = scaling_coef,
103 |     )
104 |
105 |     avg_acc, avg_loss, predictions = evaluate(args.device, model, dataloader)
106 |     logger.info(f"avg_acc: {avg_acc}")
107 |     logger.info(f"avg_loss: {avg_loss}")
108 |
109 |     # save the predictions
110 |     np.save(os.path.join(args.save_path, "predictions.npy"), predictions)
111 |
112 |
113 | if __name__ == "__main__":
114 |     main()
--------------------------------------------------------------------------------
/task_arithemetic/scripts/model_addition_stitching.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
3 | # cwd change to current file's dir
4 | os.chdir(os.path.dirname(os.path.abspath(__file__)))
5 | import argparse
6 | import torch
7 | import numpy as np
8 | from tqdm import tqdm
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader, Subset
11 | from collections import OrderedDict
12 | from open_clip import ResidualAttentionBlock
13 |
14 | # import internal libs
15 | from src.datasets.common import maybe_dictionarize
16 | from src.modeling import ImageClassifier
17 | from src.datasets.registry import get_dataset
18 | from src.utils import get_datetime, set_logger, get_logger, set_seed, set_device, \
19 |     log_settings, save_current_src, get_logits
20 | from src.utils.avgmeter import MetricTracker
21 | from src.utils.tools import get_module
22 | from src.utils.featuremap import FeatureMap
23 | from src.utils.load import load_image_encoder, decode_model_name
24 |
25 |
26 |
27 | def get_featuremaps(device: torch.device,
28 |                     weights: OrderedDict,
29 |                     model: nn.Module,
30 |                     dataloader: DataLoader,):
31 |     """get the featuremaps of the model with the given image-encoder weights
32 |
33 |     Args:
34 |         device (torch.device): the device to run the model.
35 |         weights (OrderedDict): the image-encoder state_dict, loaded into
36 |             model.image_encoder before extraction.
37 |         model (nn.Module): the model to extract featuremaps from; outputs
38 |             are recorded at each ResidualAttentionBlock.
39 |         dataloader (torch.utils.data.DataLoader): the dataloader supplying
40 |             the evaluation samples.
41 |
42 |     Return:
43 |         featuremap (OrderedDict): the recorded featuremaps, keyed by layer name.
44 |     """
45 |     logger = get_logger(f"{__name__}.get_featuremaps")
46 |
47 |
48 |     # set the layers
49 |     layers = []
50 |     for name, module in model.named_modules():
51 |         if isinstance(module, (ResidualAttentionBlock)):
52 |             layers.append(name)
53 |
54 |     # get the featuremaps of model
55 |     logger.info(f"get the featuremaps of model")
56 |     model.image_encoder.load_state_dict(weights)
57 |
58 |     # get the featuremaps
59 |     fm = FeatureMap(device, model)
60 |     featuremap = fm.get_featuremaps(dataloader, layer_names=layers)
61 |     del fm
62 |     return featuremap
63 |
64 |
65 | def evaluate(device: torch.device,
66 |              model: nn.Module,
67 |              dataloader: DataLoader,) -> tuple:
68 |     """evaluate the model over the dataset
69 |
70 |     Args:
71 |         device (torch.device): the device to run the model.
72 |         model (nn.Module): the model to be evaluated.
73 |         dataloader (Dataloader): usually the test loader.
74 |
75 |     Return:
76 |         (avg_acc, avg_loss, predictions)
77 |     """
78 |     # set the model to eval mode
79 |     model.eval()
80 |     loss_fn = nn.CrossEntropyLoss(reduction="none").to(device)
81 |     # evaluate
82 |     with torch.no_grad():
83 |         predictions, corrects, n = [], 0., 0.
84 |         test_losses = []
85 |         for _, data in enumerate(tqdm(dataloader)):
86 |             # put data to the device
87 |             data = maybe_dictionarize(data)
88 |             x = data['images'].to(device)
89 |             y = data['labels'].to(device)
90 |
91 |             # forward
92 |             logits = get_logits(x, model)
93 |
94 |             # get losses
95 |             losses = loss_fn(logits, y)
96 |
97 |
98 |             test_losses.extend(losses.cpu().detach().numpy())
99 |             # get the preds
100 |             preds = logits.argmax(dim=1).to(device)
101 |
102 |             # update
103 |             predictions.extend(preds.cpu().numpy())
104 |             corrects += preds.eq(y.view_as(preds)).sum().item()
105 |             n += y.size(0)
106 |
107 |     # get the average acc, loss
108 |     avg_acc = corrects / n
109 |     avg_loss = np.mean(test_losses)
110 |
111 |     return avg_acc, avg_loss, np.array(predictions)
112 |
113 | def model_stitching(device: torch.device,
114 |                     save_path: str,
115 |                     model: nn.Module,
116 |                     weights: OrderedDict,
117 |                     featuremaps: OrderedDict,
118 |                     dataloader: torch.utils.data.DataLoader,) -> None:
119 |     """stitch model A's featuremaps into model B's forward pass, layer by layer
120 |
121 |     Args:
122 |         device (torch.device): the device to run the model.
123 |         save_path (str): the directory to save predictions and metrics to.
124 |         model (nn.Module): the classifier whose image encoder is overwritten.
125 |         weights (OrderedDict): the image-encoder state_dict of model B.
126 |         featuremaps (OrderedDict): the featuremaps of model A, keyed by layer.
127 |         dataloader (torch.utils.data.DataLoader): the dataloader.
128 |     """
129 |     logger = get_logger(f"{__name__}.model_stitching")
130 |     if not os.path.exists(save_path):
131 |         os.makedirs(save_path)
132 |
133 |     # initialize the tracker
134 |     tracker = MetricTracker()
135 |     # load the weights
136 |     model.image_encoder.load_state_dict(weights)
137 |     model.to(device)
138 |     model.eval()
139 |     featuremaps.pop("features")  # drop the non-layer entries recorded by FeatureMap
140 |     featuremaps.pop("logits")
141 |     # for each layer, change the intermediate featuremaps
142 |     for layer_name, X in featuremaps.items():
143 |         logger.info(f"predict with Weight_B and Feature_A on {layer_name}")
144 |         # get the module
145 |         module = get_module(model, layer_name)
146 |         # init the curr_idx
147 |         curr_idx = 0
148 |         def hook(module, input, output):
149 |             from open_clip import ResidualAttentionBlock
150 |             if isinstance(module, nn.MultiheadAttention):
151 |                 output = output[0]
152 |             if isinstance(module, (nn.MultiheadAttention, nn.Sequential, ResidualAttentionBlock)):
153 |                 output = output.permute(1, 0, 2)  # LND -> NLD (batch first)
154 |             nonlocal curr_idx
155 |             output.data.copy_(X[curr_idx:curr_idx + 
len(output)]) 156 | curr_idx += output.shape[0] 157 | if isinstance(module, (nn.MultiheadAttention, nn.Sequential, ResidualAttentionBlock)): 158 | output = output.permute(1, 0, 2) 159 | 160 | handle = module.register_forward_hook(hook) 161 | 162 | avg_acc, avg_loss, predictions = evaluate(device, model, dataloader) 163 | if not os.path.exists(os.path.join(save_path, f"layer_{layer_name}")): 164 | os.makedirs(os.path.join(save_path, f"layer_{layer_name}")) 165 | np.save(os.path.join(os.path.join(save_path, f"layer_{layer_name}"), "predictions.npy"), predictions) 166 | handle.remove() 167 | # track 168 | tracker.track({ 169 | "layer": layer_name, 170 | "avg_acc": avg_acc, 171 | "avg_loss": avg_loss, 172 | }) 173 | logger.info(f"layer: {layer_name}, avg_acc: {avg_acc}, avg_loss: {avg_loss}") 174 | # save the metric 175 | tracker.save_to_csv(os.path.join(save_path, "model_stitching.csv")) 176 | 177 | 178 | 179 | def add_args() -> argparse.Namespace: 180 | parser = argparse.ArgumentParser( 181 | description="simple verification") 182 | ## the basic setting of exp 183 | parser.add_argument('--device', default=0, type=int, 184 | help="set the device.") 185 | parser.add_argument("--seed", default=0, type=int, 186 | help="set the seed.") 187 | parser.add_argument("--save_root", default="../outs/model_stitching/", type=str, 188 | help='the path of saving results.') 189 | parser.add_argument("--model_root", default=None, type=str, 190 | help='the path of loading models.') 191 | parser.add_argument("--data_root", default=None, type=str, 192 | help="The root directory for the datasets.",) 193 | parser.add_argument("--modelA", default=None, type=str, 194 | help='set the model A.') 195 | parser.add_argument("--modelB", default=None, type=str, 196 | help='set the model B.') 197 | parser.add_argument("--dataset", default="ImageNet", type=str, 198 | help='the dataset name.') 199 | parser.add_argument("--sample_num", default=10000, type=int, 200 | help="set the sample number.") 201 | parser.add_argument("--batch_size", default=128, type=int, 202 | help="set the batch size."), 203 | parser.add_argument("-v", "--verbose", action="store_true", dest="verbose", 204 | help="enable debug info output.") 205 | args = parser.parse_args() 206 | 207 | if not os.path.exists(args.save_root): 208 | os.makedirs(args.save_root) 209 | 210 | # set the save_path 211 | exp_name = "-".join([get_datetime(), 212 | f"seed{args.seed}", 213 | f"{args.dataset}", 214 | f"modelA_{args.modelA}", 215 | f"modelB_{args.modelB}", 216 | f"sample_num{args.sample_num}", 217 | f"bs{args.batch_size}",]) 218 | args.save_path = os.path.join(args.save_root, exp_name) 219 | if not os.path.exists(args.save_path): 220 | os.makedirs(args.save_path) 221 | return args 222 | 223 | 224 | def main(): 225 | # get the args. 226 | args = add_args() 227 | # set the logger 228 | set_logger(args.save_path) 229 | # get the logger 230 | logger = get_logger(__name__, args.verbose) 231 | # set the seed 232 | set_seed(args.seed) 233 | # set the device 234 | args.device = set_device(args.device) 235 | # save the current src 236 | # save_current_src(save_path = args.save_path) 237 | 238 | # show the args. 
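    # (overview: the rest of main() builds the classifier from the zeroshot
    #  encoder and the dataset head, samples a fixed test subset, extracts
    #  model A's per-block featuremaps, and stitches them into model B's
    #  forward pass via model_stitching above, recording accuracy/loss per layer)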
239 |     logger.info("#########parameters settings....")
240 |     log_settings(args)
241 |
242 |     # prepare the model
243 |     logger.info("#########prepare the model....")
244 |     # find the checkpoint with specification
245 |     image_encoder = torch.load(os.path.join(args.model_root, "zeroshot.pt"))
246 |     # load the classification head
247 |     classification_head = torch.load(os.path.join(args.model_root, f"head_{args.dataset}.pt"))
248 |     # construct the model
249 |     model = ImageClassifier(image_encoder, classification_head)
250 |     logger.info(f"model: {model}")
251 |
252 |     # prepare the dataset
253 |     logger.info("#########prepare the dataset....")
254 |     dataset_wrap = get_dataset(
255 |         dataset_name = args.dataset,
256 |         preprocess=model.val_preprocess,
257 |         location=args.data_root,
258 |         batch_size=args.batch_size,
259 |     )
260 |     dataset = dataset_wrap.test_loader.dataset
261 |     indices = torch.randperm(len(dataset))[:args.sample_num]
262 |     subset = Subset(dataset, indices)
263 |     dataloader = DataLoader(subset, batch_size=args.batch_size, shuffle=False, num_workers=16)
264 |
265 |
266 |
267 |     # get the featuremaps
268 |     logger.info("#########get the featuremaps....")
269 |     weights = dict()
270 |     for key, model_name in [("A", args.modelA), ("B", args.modelB)]:  # a model_name like "0.5+MNIST_CIFAR10" appears to mean: scaling coef 0.5, "+" adds a task vector, "_" subtracts one
271 |         # extract the scaling coefficient and task vectors info
272 |         if model_name is not None and model_name != "":
273 |             scaling_coef, task_vectors_info = decode_model_name(model_name)
274 |             logger.info(f"{key}: scaling_coef: {scaling_coef}, task_vectors_info: {task_vectors_info}")
275 |
276 |             image_encoder_tmp = load_image_encoder(model_root=args.model_root,
277 |                                                    task_vectors_info=task_vectors_info,
278 |                                                    scaling_coef=scaling_coef,)
279 |             weights[key] = image_encoder_tmp.state_dict()
280 |         else:
281 |             pretrain_image_encoder = torch.load(os.path.join(args.model_root, "zeroshot.pt"))
282 |             weights[key] = pretrain_image_encoder.state_dict()
283 |
284 |     featuremaps = get_featuremaps(device=args.device,
285 |                                   weights=weights["A"],
286 |                                   model=model,
287 |                                   dataloader=dataloader,)
288 |
289 |     # evaluate the model stitching
290 |     logger.info("#########evaluate the model stitching....")
291 |     model_stitching(device = args.device,
292 |                     save_path = os.path.join(args.save_path, "model_stitching"),
293 |                     model = model,
294 |                     weights = weights["B"],
295 |                     featuremaps = featuremaps,
296 |                     dataloader=dataloader,)
297 |
298 |
299 | if __name__ == "__main__":
300 |     main()
--------------------------------------------------------------------------------
/task_arithemetic/scripts/model_ensemble_stitching.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
3 | # cwd change to current file's dir
4 | os.chdir(os.path.dirname(os.path.abspath(__file__)))
5 | import argparse
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader, Subset
10 | from collections import OrderedDict
11 |
12 | # import internal libs
13 | from src.modeling import ImageClassifier
14 | from src.datasets.registry import get_dataset
15 | from src.utils import get_datetime, set_logger, get_logger, set_seed, set_device, \
16 |     log_settings, save_current_src
17 | from src.utils.avgmeter import MetricTracker
18 | from src.utils.tools import interpolate_weights, get_module, evaluate
19 | from src.utils.featuremap import FeatureMap
20 | from src.utils.load 
import load_image_encoder, decode_model_name 21 | 22 | def get_featuremaps(device: torch.device, 23 | weights: OrderedDict, 24 | model: nn.Module, 25 | dataloader: DataLoader,): 26 | """get the featuremaps of model A and model B 27 | 28 | Args: 29 | device (torch.device): the device to run the model. 30 | weights (OrderedDict): the weights of the image_encoder. 31 | model (nn.Module): the model to extract featuremaps. 32 | dataloader (torch.utils.data.DataLoader): the dataloader. 33 | 34 | Return: 35 | None 36 | """ 37 | logger = get_logger(f"{__name__}.get_featuremaps") 38 | 39 | # set the layers 40 | layers = [] 41 | for name, module in model.named_modules(): 42 | from open_clip import ResidualAttentionBlock 43 | if isinstance(module, (ResidualAttentionBlock)): 44 | layers.append(name) 45 | 46 | # get the featuremaps of model 47 | logger.info(f"get the featuremaps of model") 48 | model.image_encoder.load_state_dict(weights) 49 | 50 | # get the featuremaps 51 | fm = FeatureMap(device, model) 52 | featuremap = fm.get_featuremaps(dataloader, layer_names=layers) 53 | del fm 54 | return featuremap 55 | 56 | 57 | def ensemble_stitch(device: torch.device, 58 | save_path: str, 59 | model: nn.Module, 60 | weights: dict, 61 | featuremaps: dict, 62 | dataloader: DataLoader, 63 | alpha: float = 0.5, 64 | beta: float = 0.5, 65 | feat: str = "int") -> None: 66 | """ensemble the featuremaps and stitch them 67 | 68 | Args: 69 | device (torch.device): the device to run the model. 70 | save_path (str): the path to save the results. 71 | model (nn.Module): the model to extract featuremaps. 72 | weights (dict): the weights of the image_encoder. 73 | featuremaps (dict): the featuremaps of the model. 74 | dataloader (torch.utils.data.DataLoader): the dataloader. 75 | alpha (float, optional): the alpha to interpolate the weights. Defaults to 0.5. 76 | beta (float, optional): the beta to interpolate the weights. Defaults to 0.5. 77 | feat (str, optional): the featuremap to interpolate. Defaults to "int". 
78 | 79 | Return: 80 | None 81 | """ 82 | logger = get_logger(f"{__name__}.ensemble_stitch") 83 | if not os.path.exists(save_path): 84 | os.makedirs(save_path) 85 | 86 | # init the metric tracker 87 | tracker = MetricTracker() 88 | 89 | # first interpolate the weights 90 | assert list(weights.keys()) == list(featuremaps.keys()) == ["A", "B"], \ 91 | f"""the keys of weights and featuremaps should be the same, 92 | but got {weights.keys()} and {featuremaps.keys()} 93 | """ 94 | 95 | # use the pretrain model to compute 96 | # model.image_encoder.load_state_dict( 97 | # interpolate_weights(weights["A"], weights["B"], alpha=alpha, beta=beta) 98 | # ) 99 | 100 | # set the model to eval mode 101 | model.to(device) 102 | model.eval() 103 | 104 | # for each layer, change the intermediate featuremap 105 | for layer_name in featuremaps["A"].keys(): 106 | if layer_name == "features" or layer_name == "logits": 107 | continue 108 | logger.info(f"current layer {layer_name}") 109 | 110 | # get the H_A and H_B 111 | H_A = featuremaps["A"][layer_name] 112 | H_B = featuremaps["B"][layer_name] 113 | # get the interpolated H 114 | if feat == "int": 115 | H_int = alpha * H_A + beta * H_B 116 | elif feat == "A": 117 | H_int = H_A 118 | elif feat == "B": 119 | H_int = H_B 120 | else: 121 | raise NotImplementedError(f"feat: {feat} is not implemented.") 122 | 123 | # get the curr module 124 | module = get_module(model, layer_name) 125 | 126 | # hook part 127 | curr_idx = 0 128 | def hook(module, input, output): 129 | # preprocess 130 | from open_clip import ResidualAttentionBlock 131 | if isinstance(module, (ResidualAttentionBlock)): 132 | output = output.permute(1, 0, 2) 133 | 134 | # processing 135 | nonlocal curr_idx 136 | output.data.copy_(H_int[curr_idx:curr_idx + len(output)]) 137 | curr_idx += output.shape[0] 138 | 139 | # transform back 140 | if isinstance(module, (ResidualAttentionBlock)): 141 | output = output.permute(1, 0, 2) 142 | 143 | handle = module.register_forward_hook(hook) 144 | 145 | # evaluate 146 | avg_acc, avg_loss, _ = evaluate(device, model, dataloader) 147 | logger.info(f"avg_acc: {avg_acc}, avg_loss: {avg_loss}") 148 | 149 | # remove handle 150 | handle.remove() 151 | 152 | # track 153 | tracker.track({ 154 | "layer": layer_name, 155 | "avg_acc": avg_acc, 156 | "avg_loss": avg_loss, 157 | }) 158 | 159 | # save the results 160 | tracker.save_to_csv(os.path.join(save_path, "ensemble_stitch.csv")) 161 | 162 | 163 | def add_args() -> argparse.Namespace: 164 | parser = argparse.ArgumentParser( 165 | description="simple verification") 166 | ## the basic setting of exp 167 | parser.add_argument('--device', default=0, type=int, 168 | help="set the device.") 169 | parser.add_argument("--seed", default=0, type=int, 170 | help="set the seed.") 171 | parser.add_argument("--save_root", default="../outs/tmp/", type=str, 172 | help='the path of saving results.') 173 | parser.add_argument("--model_root", default=None, type=str, 174 | help='the path of loading models.') 175 | parser.add_argument("--data_root", default=None, type=str, 176 | help="The root directory for the datasets.",) 177 | parser.add_argument("--modelA", default=None, type=str, 178 | help='set the model A.') 179 | parser.add_argument("--modelB", default=None, type=str, 180 | help='set the model B.') 181 | parser.add_argument("--dataset", default="ImageNet", type=str, 182 | help='the dataset name.') 183 | parser.add_argument("--batch_size", default=256, type=int, 184 | help="set the batch size."), 185 | parser.add_argument("--alpha", 
default=0.5, type=float, 186 | help="set the alpha to interpolate the weights.") 187 | parser.add_argument("--beta", default=0.5, type=float, 188 | help="set the beta to interpolate the weights.") 189 | parser.add_argument("--sample_num", default=5000, type=int, 190 | help="the sample num to calculate the dissimilarity.") 191 | parser.add_argument("--metric", default="cosine", type=str, 192 | help="the metric to calculate the dissimilarity.") 193 | parser.add_argument("--feat", default="int", type=str, 194 | help="the featuremap to interpolate.") 195 | parser.add_argument("-v", "--verbose", action="store_true", dest="verbose", 196 | help="enable debug info output.") 197 | args = parser.parse_args() 198 | 199 | if not os.path.exists(args.save_root): 200 | os.makedirs(args.save_root) 201 | 202 | # set the save_path 203 | exp_name = "-".join([get_datetime(), 204 | f"seed{args.seed}", 205 | f"{args.dataset}", 206 | f"modelA_{args.modelA}", 207 | f"modelB_{args.modelB}", 208 | f"alpha{args.alpha}", 209 | f"beta{args.beta}", 210 | f"bs{args.batch_size}", 211 | f"sample_num{args.sample_num}", 212 | f"{args.metric}",]) 213 | args.save_path = os.path.join(args.save_root, exp_name) 214 | if not os.path.exists(args.save_path): 215 | os.makedirs(args.save_path) 216 | return args 217 | 218 | 219 | def main(): 220 | # get the args. 221 | args = add_args() 222 | # set the logger 223 | set_logger(args.save_path) 224 | # get the logger 225 | logger = get_logger(__name__, args.verbose) 226 | # set the seed 227 | set_seed(args.seed) 228 | # set the device 229 | args.device = set_device(args.device) 230 | # save the current src 231 | save_current_src(save_path = args.save_path) 232 | 233 | # show the args. 234 | logger.info("#########parameters settings....") 235 | log_settings(args) 236 | 237 | # prepare the model 238 | logger.info("#########prepare the model....") 239 | # find the checkpoint with specification 240 | image_encoder = torch.load(os.path.join(args.model_root, "zeroshot.pt")) 241 | # load the classification head 242 | classification_head = torch.load(os.path.join(args.model_root, f"head_{args.dataset}.pt")) 243 | # construct the model 244 | model = ImageClassifier(image_encoder, classification_head) 245 | logger.info(f"model: {model}") 246 | 247 | # prepare the dataset 248 | logger.info("#########prepare the dataset....") 249 | dataset_wrap = get_dataset( 250 | dataset_name = args.dataset, 251 | preprocess=model.val_preprocess, 252 | location=args.data_root, 253 | batch_size=args.batch_size, 254 | ) 255 | dataset = dataset_wrap.test_loader.dataset 256 | indices = torch.randperm(len(dataset))[:args.sample_num] 257 | subset = Subset(dataset, indices) 258 | dataloader = DataLoader(subset, batch_size=args.batch_size, shuffle=False, num_workers=16) 259 | 260 | 261 | 262 | 263 | 264 | # get the featuremaps 265 | logger.info("#########get the featuremaps....") 266 | weights, featuremaps = {}, {} 267 | for key, model_name in [("A", args.modelA), ("B", args.modelB)]: # the model_name should be like "0.5+MNIST_CIFAR10" 268 | # extract the scaling coefficient and task vectors info 269 | scaling_coef, task_vectors_info = decode_model_name(model_name) 270 | logger.info(f"{key}: scaling_coef: {scaling_coef}, task_vectors_info: {task_vectors_info}") 271 | 272 | # load the image_encoder 273 | image_encoder_tmp = load_image_encoder(model_root=args.model_root, 274 | task_vectors_info=task_vectors_info, 275 | scaling_coef=scaling_coef,) 276 | weights[key] = image_encoder_tmp.state_dict() 277 | 278 | # get the 
featuremaps 279 | featuremaps[key] = get_featuremaps(device=args.device, 280 | weights=weights[key], 281 | model=model, 282 | dataloader=dataloader,) 283 | 284 | # ensemble and stitch 285 | logger.info("#########ensemble and stitch...") 286 | ensemble_stitch(device=args.device, 287 | save_path=os.path.join(args.save_path, "ensemble_stitch_exp"), 288 | model=model, 289 | weights=weights, 290 | featuremaps=featuremaps, 291 | dataloader=dataloader, 292 | alpha=args.alpha, 293 | beta=args.beta, 294 | feat=args.feat,) 295 | 296 | 297 | if __name__ == "__main__": 298 | main() -------------------------------------------------------------------------------- /task_arithemetic/scripts/model_negation_stitching.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 3 | # cwd change to current file's dir 4 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, Subset 10 | from collections import OrderedDict, defaultdict 11 | 12 | # import internal libs 13 | from src.modeling import ImageClassifier 14 | from src.datasets.registry import get_dataset 15 | from src.utils import get_datetime, set_logger, get_logger, set_seed, set_device, \ 16 | log_settings, save_current_src 17 | from src.utils.avgmeter import MetricTracker 18 | from src.utils.tools import get_module, evaluate 19 | from src.utils.featuremap import FeatureMap 20 | from src.utils.load import load_image_encoder 21 | 22 | def get_featuremaps(device: torch.device, 23 | weights: OrderedDict, 24 | model: nn.Module, 25 | dataloader: DataLoader,): 26 | """get the featuremaps of model A and model B 27 | Args: 28 | device (torch.device): the device to run the model. 29 | weights (OrderedDict): the weights of the image_encoder. 30 | model (nn.Module): the model to extract featuremaps. 31 | dataloader (torch.utils.data.DataLoader): the dataloader. 32 | Return: 33 | None 34 | """ 35 | logger = get_logger(f"{__name__}.get_featuremaps") 36 | 37 | # set the layers 38 | layers = [] 39 | for name, module in model.named_modules(): 40 | from open_clip import ResidualAttentionBlock 41 | if isinstance(module, (ResidualAttentionBlock)): 42 | layers.append(name) 43 | 44 | # get the featuremaps of model 45 | logger.info(f"get the featuremaps of model") 46 | model.image_encoder.load_state_dict(weights) 47 | 48 | # get the featuremaps 49 | fm = FeatureMap(device, model) 50 | featuremap = fm.get_featuremaps(dataloader, layer_names=layers) 51 | del fm 52 | return featuremap 53 | 54 | 55 | def model_stitch(device: torch.device, 56 | model: nn.Module, 57 | weights: dict, 58 | featuremaps: dict, 59 | dataloader: DataLoader,) -> defaultdict: 60 | """ensemble the featuremaps and stitch them 61 | Args: 62 | device (torch.device): the device to run the model. 63 | model (nn.Module): the model to extract featuremaps. 64 | weights (dict): the weights of the image_encoder. 65 | featuremaps (dict): the featuremaps of the model. 66 | dataloader (torch.utils.data.DataLoader): the dataloader. 
67 | 68 | Return: 69 | None 70 | """ 71 | logger = get_logger(f"{__name__}.model_stitch") 72 | # init the metric tracker 73 | tracker = MetricTracker() 74 | 75 | # set the model to eval mode 76 | model.image_encoder.load_state_dict(weights) 77 | model.to(device) 78 | model.eval() 79 | 80 | # for each layer, change the intermediate featuremap 81 | for layer_name in featuremaps.keys(): 82 | if layer_name == "features" or layer_name == "logits": 83 | continue 84 | logger.info(f"current layer {layer_name}") 85 | 86 | # get the feature map in current layer 87 | H = featuremaps[layer_name] 88 | 89 | # get the curr module 90 | module = get_module(model, layer_name) 91 | 92 | # hook part 93 | curr_idx = 0 94 | def hook(module, input, output): 95 | # preprocess 96 | from open_clip import ResidualAttentionBlock 97 | if isinstance(module, (ResidualAttentionBlock)): 98 | output = output.permute(1, 0, 2) 99 | 100 | # processing 101 | nonlocal curr_idx 102 | output.data.copy_(H[curr_idx:curr_idx + len(output)]) 103 | curr_idx += output.shape[0] 104 | 105 | # transform back 106 | if isinstance(module, (ResidualAttentionBlock)): 107 | output = output.permute(1, 0, 2) 108 | 109 | handle = module.register_forward_hook(hook) 110 | 111 | # evaluate 112 | avg_acc, avg_loss, _ = evaluate(device, model, dataloader) 113 | logger.info(f"avg_acc: {avg_acc}, avg_loss: {avg_loss}") 114 | 115 | # remove handle 116 | handle.remove() 117 | 118 | # track 119 | tracker.track({ 120 | "layer": layer_name, 121 | "avg_acc": avg_acc, 122 | "avg_loss": avg_loss, 123 | }) 124 | 125 | # save the results 126 | return tracker.get_metrics() 127 | 128 | 129 | def add_args() -> argparse.Namespace: 130 | parser = argparse.ArgumentParser( 131 | description="simple verification") 132 | ## the basic setting of exp 133 | parser.add_argument('--device', default=0, type=int, 134 | help="set the device.") 135 | parser.add_argument("--seed", default=0, type=int, 136 | help="set the seed.") 137 | parser.add_argument("--save_root", default="../outs/negation_stitching/", type=str, 138 | help='the path of saving results.') 139 | parser.add_argument("--model_root", default=None, type=str, 140 | help='the path of loading models.') 141 | parser.add_argument("--data_root", default=None, type=str, 142 | help="The root directory for the datasets.",) 143 | parser.add_argument("--task", default=None, type=str, 144 | help='set the task name.') 145 | parser.add_argument("--dataset", default="ImageNet", type=str, 146 | help='the dataset name.') 147 | parser.add_argument("--sample_num", default=5000, type=int, 148 | help="the sample num to test") 149 | parser.add_argument("--batch_size", default=128, type=int, 150 | help="set the batch size."), 151 | parser.add_argument("-v", "--verbose", action="store_true", dest="verbose", 152 | help="enable debug info output.") 153 | args = parser.parse_args() 154 | 155 | if not os.path.exists(args.save_root): 156 | os.makedirs(args.save_root) 157 | 158 | # set the save_path 159 | exp_name = "-".join([get_datetime(), 160 | f"seed{args.seed}", 161 | f"{args.dataset}", 162 | f"{args.task}", 163 | f"sample_num{args.sample_num}", 164 | f"bs{args.batch_size}",]) 165 | args.save_path = os.path.join(args.save_root, exp_name) 166 | if not os.path.exists(args.save_path): 167 | os.makedirs(args.save_path) 168 | return args 169 | 170 | 171 | def main(): 172 | # get the args. 
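    # (overview: for each lambda below, we stitch two kinds of featuremaps into
    #  the zeroshot model and compare them layer by layer: the exact featuremaps
    #  of the negated model, f(theta - lambda*tau), and the CTL approximation
    #  2*f(theta) - f(theta + lambda*tau) built from the addition model)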
173 |     args = add_args()
174 |     # set the logger
175 |     set_logger(args.save_path)
176 |     # get the logger
177 |     logger = get_logger(__name__, args.verbose)
178 |     # set the seed
179 |     set_seed(args.seed)
180 |     # set the device
181 |     args.device = set_device(args.device)
182 |     # save the current src
183 |     save_current_src(save_path = args.save_path)
184 |
185 |     # show the args.
186 |     logger.info("#########parameters settings....")
187 |     log_settings(args)
188 |
189 |     # prepare the model
190 |     logger.info("#########prepare the model....")
191 |     # find the checkpoint with specification
192 |     zeroshot_image_encoder = torch.load(os.path.join(args.model_root, "zeroshot.pt"))
193 |     # load the classification head
194 |     classification_head = torch.load(os.path.join(args.model_root, f"head_{args.dataset}.pt"))
195 |     # construct the model
196 |     model = ImageClassifier(zeroshot_image_encoder, classification_head)
197 |     logger.info(f"model: {model}")
198 |
199 |     # prepare the dataset
200 |     logger.info("#########prepare the dataset....")
201 |     dataset_wrap = get_dataset(
202 |         dataset_name = args.dataset,
203 |         preprocess=model.val_preprocess,
204 |         location=args.data_root,
205 |         batch_size=args.batch_size,
206 |     )
207 |     dataset = dataset_wrap.test_loader.dataset
208 |     indices = torch.randperm(len(dataset))[:args.sample_num]
209 |     subset = Subset(dataset, indices)
210 |     dataloader = DataLoader(subset, batch_size=args.batch_size, shuffle=False, num_workers=16)
211 |
212 |     # get the zeroshot featuremaps
213 |     logger.info("#########prepare the zeroshot featuremaps")
214 |     zeroshot_weights = zeroshot_image_encoder.state_dict()
215 |     zeroshot_featuremaps = get_featuremaps(device=args.device,
216 |                                            weights=zeroshot_weights,
217 |                                            model=model,
218 |                                            dataloader=dataloader,)
219 |
220 |     # iterate each lambda and layers
221 |     logger.info("#########negation and stitch...")
222 |     # init the tracker
223 |     tracker = MetricTracker()
224 |
225 |     lambdas = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, \
226 |                0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]
227 |     for lambda_ in lambdas:
228 |         logger.info(f"Evaluating the lambda {lambda_}")
229 |
230 |         # first, the exact featuremaps of the negated model (weights: zeroshot - lambda * task vector)
231 |         minus_image_encoder = load_image_encoder(model_root=args.model_root,
232 |                                                  task_vectors_info={args.task: "minus"},
233 |                                                  scaling_coef=lambda_,)
234 |         minus_origin_featuremaps = get_featuremaps(device=args.device,
235 |                                                    weights=minus_image_encoder.state_dict(),
236 |                                                    model=model,
237 |                                                    dataloader=dataloader,)
238 |         minus_origin_metrics = model_stitch(device=args.device,
239 |                                             model=model,
240 |                                             weights=zeroshot_weights,
241 |                                             featuremaps=minus_origin_featuremaps,
242 |                                             dataloader=dataloader,)
243 |
244 |         # then the approximate minus featuremaps, built from the addition model (zeroshot + lambda * task vector)
245 |         add_image_encoder = load_image_encoder(model_root=args.model_root,
246 |                                                task_vectors_info={args.task: "add"},
247 |                                                scaling_coef=lambda_,)
248 |         add_featuremaps = get_featuremaps(device=args.device,
249 |                                           weights=add_image_encoder.state_dict(),
250 |                                           model=model,
251 |                                           dataloader=dataloader,)
252 |         # create the approximate minus featuremaps: under CTL, f(theta - lambda*tau) ~ 2*f(theta) - f(theta + lambda*tau)
253 |         minus_approx_featuremaps = {}
254 |         for layer_name in minus_origin_featuremaps.keys():
255 |             minus_approx_featuremaps[layer_name] = \
256 |                 2*zeroshot_featuremaps[layer_name] - add_featuremaps[layer_name]
257 |         minus_approx_metrics = model_stitch(device=args.device,
258 |                                             model=model,
259 |                                             weights=zeroshot_weights,
260 |                                             featuremaps=minus_approx_featuremaps,
261 |                                             dataloader=dataloader,)
262 |
263 |         for idx, layer in 
enumerate(minus_origin_metrics["layer"]): 264 | tracker.track({ 265 | "layer": layer, 266 | "lambda": lambda_, 267 | "origin_acc": minus_origin_metrics["avg_acc"][idx], 268 | "origin_loss": minus_origin_metrics["avg_loss"][idx], 269 | "approx_acc": minus_approx_metrics["avg_acc"][idx], 270 | "approx_loss": minus_approx_metrics["avg_loss"][idx], 271 | }) 272 | 273 | tracker.save_to_csv(os.path.join(args.save_path, "negation_stitch.csv")) 274 | 275 | 276 | if __name__ == "__main__": 277 | main() -------------------------------------------------------------------------------- /task_arithemetic/src/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--data-location", 10 | type=str, 11 | default=os.path.expanduser('~/data'), 12 | help="The root directory for the datasets.", 13 | ) 14 | parser.add_argument( 15 | "--eval-datasets", 16 | default=None, 17 | type=lambda x: x.split(","), 18 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. " 19 | ) 20 | parser.add_argument( 21 | "--train-dataset", 22 | default=None, 23 | type=lambda x: x.split(","), 24 | help="Which dataset(s) to patch on.", 25 | ) 26 | parser.add_argument( 27 | "--exp_name", 28 | type=str, 29 | default=None, 30 | help="Name of the experiment, for organization purposes only." 31 | ) 32 | parser.add_argument( 33 | "--results-db", 34 | type=str, 35 | default=None, 36 | help="Where to store the results, else does not store", 37 | ) 38 | parser.add_argument( 39 | "--model", 40 | type=str, 41 | default=None, 42 | help="The type of model (e.g. RN50, ViT-B-32).", 43 | ) 44 | parser.add_argument( 45 | "--batch-size", 46 | type=int, 47 | default=128, 48 | ) 49 | parser.add_argument( 50 | "--lr", 51 | type=float, 52 | default=0.001, 53 | help="Learning rate." 54 | ) 55 | parser.add_argument( 56 | "--wd", 57 | type=float, 58 | default=0.1, 59 | help="Weight decay" 60 | ) 61 | parser.add_argument( 62 | "--ls", 63 | type=float, 64 | default=0.0, 65 | help="Label smoothing." 66 | ) 67 | parser.add_argument( 68 | "--warmup_length", 69 | type=int, 70 | default=500, 71 | ) 72 | parser.add_argument( 73 | "--epochs", 74 | type=int, 75 | default=10, 76 | ) 77 | parser.add_argument( 78 | "--load", 79 | type=lambda x: x.split(","), 80 | default=None, 81 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.", 82 | ) 83 | parser.add_argument( 84 | "--save", 85 | type=str, 86 | default=None, 87 | help="Optionally save a _classifier_, e.g. 
a zero shot classifier or probe.", 88 | ) 89 | parser.add_argument( 90 | "--cache-dir", 91 | type=str, 92 | default=None, 93 | help="Directory for caching features and encoder", 94 | ) 95 | parser.add_argument( 96 | "--openclip-cachedir", 97 | type=str, 98 | default='/gscratch/efml/gamaga/.cache/open_clip', 99 | help='Directory for caching models from OpenCLIP' 100 | ) 101 | parsed_args = parser.parse_args() 102 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" 103 | 104 | if parsed_args.load is not None and len(parsed_args.load) == 1: 105 | parsed_args.load = parsed_args.load[0] 106 | return parsed_args 107 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | import pathlib 7 | from typing import Callable, Optional, Any, Tuple 8 | 9 | from PIL import Image 10 | 11 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class PytorchStanfordCars(VisionDataset): 16 | """`Stanford Cars `_ Dataset 17 | 18 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 19 | split into 8,144 training images and 8,041 testing images, where each class 20 | has been split roughly in a 50-50 split 21 | 22 | .. note:: 23 | 24 | This class needs `scipy `_ to load target files from `.mat` format. 25 | 26 | Args: 27 | root (string): Root directory of dataset 28 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 29 | transform (callable, optional): A function/transform that takes in an PIL image 30 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 31 | target_transform (callable, optional): A function/transform that takes in the 32 | target and transforms it. 33 | download (bool, optional): If True, downloads the dataset from the internet and 34 | puts it in root directory. If dataset is already downloaded, it is not 35 | downloaded again.""" 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | split: str = "train", 41 | transform: Optional[Callable] = None, 42 | target_transform: Optional[Callable] = None, 43 | download: bool = False, 44 | ) -> None: 45 | 46 | try: 47 | import scipy.io as sio 48 | except ImportError: 49 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path(root) / "stanford_cars" 55 | devkit = self._base_folder / "devkit" 56 | 57 | if self._split == "train": 58 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 59 | self._images_base_path = self._base_folder / "cars_train" 60 | else: 61 | self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" 62 | self._images_base_path = self._base_folder / "cars_test" 63 | 64 | if download: 65 | self.download() 66 | 67 | if not self._check_exists(): 68 | raise RuntimeError("Dataset not found. 
You can use download=True to download it") 69 | 70 | self._samples = [ 71 | ( 72 | str(self._images_base_path / annotation["fname"]), 73 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 74 | ) 75 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 76 | ] 77 | 78 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 79 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 80 | 81 | def __len__(self) -> int: 82 | return len(self._samples) 83 | 84 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 85 | """Returns pil_image and class_id for given index""" 86 | image_path, target = self._samples[idx] 87 | pil_image = Image.open(image_path).convert("RGB") 88 | 89 | if self.transform is not None: 90 | pil_image = self.transform(pil_image) 91 | if self.target_transform is not None: 92 | target = self.target_transform(target) 93 | return pil_image, target 94 | 95 | 96 | def download(self) -> None: 97 | if self._check_exists(): 98 | return 99 | 100 | download_and_extract_archive( 101 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 102 | download_root=str(self._base_folder), 103 | md5="c3b158d763b6e2245038c8ad08e45376", 104 | ) 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 108 | download_root=str(self._base_folder), 109 | md5="065e5b463ae28d29e77c1b4b166cfe61", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 114 | download_root=str(self._base_folder), 115 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 116 | ) 117 | download_url( 118 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 119 | root=str(self._base_folder), 120 | md5="b0a2b23655a3edd16d84508592a98d10", 121 | ) 122 | 123 | def _check_exists(self) -> bool: 124 | if not (self._base_folder / "devkit").is_dir(): 125 | return False 126 | 127 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 128 | 129 | 130 | class Cars: 131 | def __init__(self, 132 | preprocess, 133 | location=os.path.expanduser('~/data'), 134 | batch_size=32, 135 | num_workers=16): 136 | # Data loading code 137 | 138 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=True) 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | shuffle=True, 142 | batch_size=batch_size, 143 | num_workers=num_workers, 144 | ) 145 | 146 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=True) 147 | self.test_loader = torch.utils.data.DataLoader( 148 | self.test_dataset, 149 | batch_size=batch_size, 150 | num_workers=num_workers 151 | ) 152 | idx_to_class = dict((v, k) 153 | for k, v in self.train_dataset.class_to_idx.items()) 154 | self.classnames = [idx_to_class[i].replace( 155 | '_', ' ') for i in range(len(idx_to_class))] 156 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 8 | from torchvision.datasets import VisionDataset 9 | 10 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 
'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 11 | 12 | class CIFAR10: 13 | def __init__(self, preprocess, 14 | location=os.path.expanduser('~/data'), 15 | batch_size=128, 16 | num_workers=16): 17 | 18 | 19 | self.train_dataset = PyTorchCIFAR10( 20 | root=location, download=True, train=True, transform=preprocess 21 | ) 22 | 23 | self.train_loader = torch.utils.data.DataLoader( 24 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = PyTorchCIFAR10( 28 | root=location, download=True, train=False, transform=preprocess 29 | ) 30 | 31 | self.test_loader = torch.utils.data.DataLoader( 32 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 33 | ) 34 | 35 | self.classnames = self.test_dataset.classes 36 | 37 | def convert(x): 38 | if isinstance(x, np.ndarray): 39 | return torchvision.transforms.functional.to_pil_image(x) 40 | return x 41 | 42 | class BasicVisionDataset(VisionDataset): 43 | def __init__(self, images, targets, transform=None, target_transform=None): 44 | if transform is not None: 45 | transform.transforms.insert(0, convert) 46 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) 47 | assert len(images) == len(targets) 48 | 49 | self.images = images 50 | self.targets = targets 51 | 52 | def __getitem__(self, index): 53 | return self.transform(self.images[index]), self.targets[index] 54 | 55 | def __len__(self): 56 | return len(self.targets) 57 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 4 | 5 | class CIFAR100: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | self.train_dataset = PyTorchCIFAR100( 13 | root=location, download=True, train=True, transform=preprocess 14 | ) 15 | 16 | self.train_loader = torch.utils.data.DataLoader( 17 | self.train_dataset, batch_size=batch_size, num_workers=num_workers 18 | ) 19 | 20 | self.test_dataset = PyTorchCIFAR100( 21 | root=location, download=True, train=False, transform=preprocess 22 | ) 23 | 24 | self.test_loader = torch.utils.data.DataLoader( 25 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 26 | ) 27 | 28 | self.classnames = self.test_dataset.classes 29 | 30 | 31 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform, flip_label_prob=0.0): 28 | super().__init__(path, transform) 29 | self.flip_label_prob = flip_label_prob 30 | if self.flip_label_prob > 0: 31 | 
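# Optional label noise: each sample keeps its image but, with probability
# `flip_label_prob`, is reassigned a uniformly random class label.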
print(f'Flipping labels with probability {self.flip_label_prob}')
32 |             num_classes = len(self.classes)
33 |             for i in range(len(self.samples)):
34 |                 if random.random() < self.flip_label_prob:
35 |                     new_label = random.randint(0, num_classes-1)
36 |                     self.samples[i] = (
37 |                         self.samples[i][0],
38 |                         new_label
39 |                     )
40 | 
41 |     def __getitem__(self, index):
42 |         image, label = super(ImageFolderWithPaths, self).__getitem__(index)
43 |         return {
44 |             'images': image,
45 |             'labels': label,
46 |             'image_paths': self.samples[index][0]
47 |         }
48 | 
49 | 
50 | def maybe_dictionarize(batch):
51 |     if isinstance(batch, dict):
52 |         return batch
53 | 
54 |     if len(batch) == 2:
55 |         batch = {'images': batch[0], 'labels': batch[1]}
56 |     elif len(batch) == 3:
57 |         batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
58 |     else:
59 |         raise ValueError(f'Unexpected number of elements: {len(batch)}')
60 | 
61 |     return batch
62 | 
63 | 
64 | def get_features_helper(image_encoder, dataloader, device):
65 |     all_data = collections.defaultdict(list)
66 | 
67 |     image_encoder = image_encoder.to(device)
68 |     image_encoder = torch.nn.DataParallel(image_encoder, device_ids=list(range(torch.cuda.device_count())))
69 |     image_encoder.eval()
70 | 
71 |     with torch.no_grad():
72 |         for batch in tqdm(dataloader):
73 |             batch = maybe_dictionarize(batch)
74 |             features = image_encoder(batch['images'].to(device))
75 | 
76 |             all_data['features'].append(features.cpu())
77 | 
78 |             for key, val in batch.items():
79 |                 if key == 'images':
80 |                     continue
81 |                 if hasattr(val, 'cpu'):
82 |                     val = val.cpu()
83 |                     all_data[key].append(val)
84 |                 else:
85 |                     all_data[key].extend(val)
86 | 
87 |     for key, val in all_data.items():
88 |         if torch.is_tensor(val[0]):
89 |             all_data[key] = torch.cat(val).numpy()
90 | 
91 |     return all_data
92 | 
93 | 
94 | def get_features(is_train, image_encoder, dataset, device):
95 |     split = 'train' if is_train else 'val'
96 |     dname = type(dataset).__name__
97 |     # Default to no cache so the `else` branch below cannot raise a NameError.
98 |     cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' if image_encoder.cache_dir is not None else None
99 |     cached_files = glob.glob(f'{cache_dir}/*') if cache_dir is not None else []
100 |     if len(cached_files) > 0:
101 |         print(f'Getting features from {cache_dir}')
102 |         data = {}
103 |         for cached_file in cached_files:
104 |             name = os.path.splitext(os.path.basename(cached_file))[0]
105 |             data[name] = torch.load(cached_file)
106 |     else:
107 |         print(f'Did not find cached features at {cache_dir}. 
Building from scratch.') 108 | loader = dataset.train_loader if is_train else dataset.test_loader 109 | data = get_features_helper(image_encoder, loader, device) 110 | if image_encoder.cache_dir is None: 111 | print('Not caching because no cache directory was passed.') 112 | else: 113 | os.makedirs(cache_dir, exist_ok=True) 114 | print(f'Caching data at {cache_dir}') 115 | for name, val in data.items(): 116 | torch.save(val, f'{cache_dir}/{name}.pt') 117 | return data 118 | 119 | 120 | class FeatureDataset(Dataset): 121 | def __init__(self, is_train, image_encoder, dataset, device): 122 | self.data = get_features(is_train, image_encoder, dataset, device) 123 | 124 | def __len__(self): 125 | return len(self.data['features']) 126 | 127 | def __getitem__(self, idx): 128 | data = {k: v[idx] for k, v in self.data.items()} 129 | data['features'] = torch.from_numpy(data['features']).float() 130 | return data 131 | 132 | 133 | def get_dataloader(dataset, is_train, args, image_encoder=None): 134 | if image_encoder is not None: 135 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) 136 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 137 | else: 138 | dataloader = dataset.train_loader if is_train else dataset.test_loader 139 | return dataloader -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class DTD: 7 | def __init__(self, 8 | preprocess, 9 | location=os.path.expanduser('~/data'), 10 | batch_size=32, 11 | num_workers=16): 12 | # Data loading code 13 | traindir = os.path.join(location, 'dtd', 'train') 14 | valdir = os.path.join(location, 'dtd', 'val') 15 | 16 | # self.train_dataset = datasets.ImageFolder( 17 | # traindir, transform=preprocess) 18 | self.train_dataset = datasets.DTD(root=location, split="train", transform=preprocess) 19 | self.train_loader = torch.utils.data.DataLoader( 20 | self.train_dataset, 21 | shuffle=True, 22 | batch_size=batch_size, 23 | num_workers=num_workers, 24 | ) 25 | 26 | # self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 27 | self.test_dataset = datasets.DTD(root=location, split="val",transform=preprocess) 28 | self.test_loader = torch.utils.data.DataLoader( 29 | self.test_dataset, 30 | batch_size=batch_size, 31 | num_workers=num_workers 32 | ) 33 | idx_to_class = dict((v, k) 34 | for k, v in self.train_dataset.class_to_idx.items()) 35 | self.classnames = [idx_to_class[i].replace( 36 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | def pretify_classname(classname): 6 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) 7 | l = [i.lower() for i in l] 8 | out = ' '.join(l) 9 | if out.endswith('al'): 10 | return out + ' area' 11 | return out 12 | 13 | class EuroSATBase: 14 | def __init__(self, 15 | preprocess, 16 | test_split, 17 | location='~/datasets', 18 | batch_size=32, 19 | num_workers=16): 20 | # Data loading code 21 | # location = os.path.join(location, "eurosat") 22 | traindir = os.path.join(location, 'EuroSAT_splits', 'train') 23 | 
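# EuroSAT ships without an official train/val/test split; this loader expects
# pre-split ImageFolder directories, and `test_split` selects 'val' or 'test'.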
testdir = os.path.join(location, 'EuroSAT_splits', test_split) 24 | 25 | 26 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 27 | self.train_loader = torch.utils.data.DataLoader( 28 | self.train_dataset, 29 | shuffle=True, 30 | batch_size=batch_size, 31 | num_workers=num_workers, 32 | ) 33 | 34 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) 35 | self.test_loader = torch.utils.data.DataLoader( 36 | self.test_dataset, 37 | batch_size=batch_size, 38 | num_workers=num_workers 39 | ) 40 | idx_to_class = dict((v, k) 41 | for k, v in self.train_dataset.class_to_idx.items()) 42 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 43 | self.classnames = [pretify_classname(c) for c in self.classnames] 44 | ours_to_open_ai = { 45 | 'annual crop': 'annual crop land', 46 | 'forest': 'forest', 47 | 'herbaceous vegetation': 'brushland or shrubland', 48 | 'highway': 'highway or road', 49 | 'industrial area': 'industrial buildings or commercial buildings', 50 | 'pasture': 'pasture land', 51 | 'permanent crop': 'permanent crop land', 52 | 'residential area': 'residential buildings or homes or apartments', 53 | 'river': 'river', 54 | 'sea lake': 'lake or sea', 55 | } 56 | for i in range(len(self.classnames)): 57 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 58 | 59 | 60 | class EuroSAT(EuroSATBase): 61 | def __init__(self, 62 | preprocess, 63 | location='~/datasets', 64 | batch_size=32, 65 | num_workers=16): 66 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 67 | 68 | 69 | class EuroSATVal(EuroSATBase): 70 | def __init__(self, 71 | preprocess, 72 | location='~/datasets', 73 | batch_size=32, 74 | num_workers=16): 75 | super().__init__(preprocess, 'val', location, batch_size, num_workers) 76 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import (download_and_extract_archive, 11 | verify_str_arg) 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 15 | """Finds the class folders in a dataset. 16 | 17 | See :class:`DatasetFolder` for details. 18 | """ 19 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 20 | if not classes: 21 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 22 | 23 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 24 | return classes, class_to_idx 25 | 26 | class PyTorchGTSRB(VisionDataset): 27 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. 28 | 29 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 30 | 31 | Args: 32 | root (string): Root directory of the dataset. 33 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 34 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 35 | version. E.g, ``transforms.RandomCrop``. 36 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
37 | download (bool, optional): If True, downloads the dataset from the internet and 38 | puts it in root directory. If dataset is already downloaded, it is not 39 | downloaded again. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | split: str = "train", 46 | transform: Optional[Callable] = None, 47 | target_transform: Optional[Callable] = None, 48 | download: bool = False, 49 | ) -> None: 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path(root) / "gtsrb" 55 | self._target_folder = ( 56 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 57 | ) 58 | 59 | if download: 60 | self.download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. You can use download=True to download it") 64 | 65 | if self._split == "train": 66 | _, class_to_idx = find_classes(str(self._target_folder)) 67 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) 68 | else: 69 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 70 | samples = [ 71 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 72 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 73 | ] 74 | 75 | self._samples = samples 76 | self.transform = transform 77 | self.target_transform = target_transform 78 | 79 | def __len__(self) -> int: 80 | return len(self._samples) 81 | 82 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 83 | 84 | path, target = self._samples[index] 85 | sample = PIL.Image.open(path).convert("RGB") 86 | 87 | if self.transform is not None: 88 | sample = self.transform(sample) 89 | 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | return sample, target 94 | 95 | 96 | def _check_exists(self) -> bool: 97 | return self._target_folder.is_dir() 98 | 99 | def download(self) -> None: 100 | if self._check_exists(): 101 | return 102 | 103 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 104 | 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | f"{base_url}GTSRB-Training_fixed.zip", 108 | download_root=str(self._base_folder), 109 | md5="513f3c79a4c5141765e10e952eaa2478", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | f"{base_url}GTSRB_Final_Test_Images.zip", 114 | download_root=str(self._base_folder), 115 | md5="c7e4e6327067d32654124b0fe9e82185", 116 | ) 117 | download_and_extract_archive( 118 | f"{base_url}GTSRB_Final_Test_GT.zip", 119 | download_root=str(self._base_folder), 120 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 121 | ) 122 | 123 | 124 | class GTSRB: 125 | def __init__(self, 126 | preprocess, 127 | location=os.path.expanduser('~/data'), 128 | batch_size=128, 129 | num_workers=16): 130 | 131 | # to fit with repo conventions for location 132 | self.train_dataset = PyTorchGTSRB( 133 | root=location, 134 | download=True, 135 | split='train', 136 | transform=preprocess 137 | ) 138 | 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | batch_size=batch_size, 142 | shuffle=True, 143 | num_workers=num_workers 144 | ) 145 | 146 | self.test_dataset = PyTorchGTSRB( 147 | root=location, 148 | download=True, 149 | split='test', 150 | transform=preprocess 151 | ) 152 | 153 | self.test_loader = torch.utils.data.DataLoader( 154 | self.test_dataset, 155 | 
batch_size=batch_size, 156 | shuffle=False, 157 | num_workers=num_workers 158 | ) 159 | 160 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 161 | self.classnames = [ 162 | 'red and white circle 20 kph speed limit', 163 | 'red and white circle 30 kph speed limit', 164 | 'red and white circle 50 kph speed limit', 165 | 'red and white circle 60 kph speed limit', 166 | 'red and white circle 70 kph speed limit', 167 | 'red and white circle 80 kph speed limit', 168 | 'end / de-restriction of 80 kph speed limit', 169 | 'red and white circle 100 kph speed limit', 170 | 'red and white circle 120 kph speed limit', 171 | 'red and white circle red car and black car no passing', 172 | 'red and white circle red truck and black car no passing', 173 | 'red and white triangle road intersection warning', 174 | 'white and yellow diamond priority road', 175 | 'red and white upside down triangle yield right-of-way', 176 | 'stop', 177 | 'empty red and white circle', 178 | 'red and white circle no truck entry', 179 | 'red circle with white horizonal stripe no entry', 180 | 'red and white triangle with exclamation mark warning', 181 | 'red and white triangle with black left curve approaching warning', 182 | 'red and white triangle with black right curve approaching warning', 183 | 'red and white triangle with black double curve approaching warning', 184 | 'red and white triangle rough / bumpy road warning', 185 | 'red and white triangle car skidding / slipping warning', 186 | 'red and white triangle with merging / narrow lanes warning', 187 | 'red and white triangle with person digging / construction / road work warning', 188 | 'red and white triangle with traffic light approaching warning', 189 | 'red and white triangle with person walking warning', 190 | 'red and white triangle with child and person walking warning', 191 | 'red and white triangle with bicyle warning', 192 | 'red and white triangle with snowflake / ice warning', 193 | 'red and white triangle with deer warning', 194 | 'white circle with gray strike bar no speed limit', 195 | 'blue circle with white right turn arrow mandatory', 196 | 'blue circle with white left turn arrow mandatory', 197 | 'blue circle with white forward arrow mandatory', 198 | 'blue circle with white forward or right turn arrow mandatory', 199 | 'blue circle with white forward or left turn arrow mandatory', 200 | 'blue circle with white keep right arrow mandatory', 201 | 'blue circle with white keep left arrow mandatory', 202 | 'blue circle with white arrows indicating a traffic circle', 203 | 'white circle with gray strike bar indicating no passing for cars has ended', 204 | 'white circle with gray strike bar indicating no passing for trucks has ended', 205 | ] 206 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class MNIST: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | 13 | self.train_dataset = datasets.MNIST( 14 | root=location, 15 | download=True, 16 | train=True, 17 | transform=preprocess 18 | ) 19 | 20 | self.train_loader = torch.utils.data.DataLoader( 21 | self.train_dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = 
datasets.MNIST( 28 | root=location, 29 | download=True, 30 | train=False, 31 | transform=preprocess 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers 39 | ) 40 | 41 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import random 4 | import torch 5 | import copy 6 | 7 | from torch.utils.data.dataset import random_split 8 | 9 | from src.datasets.cars import Cars 10 | from src.datasets.cifar10 import CIFAR10 # no need 11 | from src.datasets.cifar100 import CIFAR100 # no need 12 | from src.datasets.dtd import DTD 13 | from src.datasets.eurosat import EuroSAT, EuroSATVal 14 | from src.datasets.gtsrb import GTSRB 15 | from src.datasets.imagenet import ImageNet 16 | from src.datasets.mnist import MNIST 17 | from src.datasets.resisc45 import RESISC45 18 | from src.datasets.stl10 import STL10 # no need 19 | from src.datasets.svhn import SVHN 20 | from src.datasets.sun397 import SUN397 21 | 22 | registry = { 23 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) 24 | } 25 | 26 | 27 | class GenericDataset(object): 28 | def __init__(self): 29 | self.train_dataset = None 30 | self.train_loader = None 31 | self.test_dataset = None 32 | self.test_loader = None 33 | self.classnames = None 34 | 35 | 36 | def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0): 37 | assert val_fraction > 0. and val_fraction < 1. 
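# Carve a validation set out of the training split with a fixed-seed
# random_split, so the synthesized '<Dataset>Val' variants are reproducible.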
38 | total_size = len(dataset.train_dataset) 39 | val_size = int(total_size * val_fraction) 40 | if max_val_samples is not None: 41 | val_size = min(val_size, max_val_samples) 42 | train_size = total_size - val_size 43 | 44 | assert val_size > 0 45 | assert train_size > 0 46 | 47 | lengths = [train_size, val_size] 48 | 49 | trainset, valset = random_split( 50 | dataset.train_dataset, 51 | lengths, 52 | generator=torch.Generator().manual_seed(seed) 53 | ) 54 | if new_dataset_class_name == 'MNISTVal': 55 | assert trainset.indices[0] == 36044 56 | 57 | 58 | new_dataset = None 59 | 60 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) 61 | new_dataset = new_dataset_class() 62 | 63 | new_dataset.train_dataset = trainset 64 | new_dataset.train_loader = torch.utils.data.DataLoader( 65 | new_dataset.train_dataset, 66 | shuffle=True, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | ) 70 | 71 | new_dataset.test_dataset = valset 72 | new_dataset.test_loader = torch.utils.data.DataLoader( 73 | new_dataset.test_dataset, 74 | batch_size=batch_size, 75 | num_workers=num_workers 76 | ) 77 | 78 | new_dataset.classnames = copy.copy(dataset.classnames) 79 | 80 | return new_dataset 81 | 82 | 83 | def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000): 84 | if dataset_name.endswith('Val'): 85 | # Handle val splits 86 | if dataset_name in registry: 87 | dataset_class = registry[dataset_name] 88 | else: 89 | base_dataset_name = dataset_name.split('Val')[0] 90 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) 91 | dataset = split_train_into_train_val( 92 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples) 93 | return dataset 94 | else: 95 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}' 96 | dataset_class = registry[dataset_name] 97 | dataset = dataset_class( 98 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 99 | ) 100 | return dataset 101 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import abc 5 | import os 6 | from typing import Any, Callable, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | from torchvision.datasets import ImageFolder 13 | from torchvision.datasets.folder import default_loader as pil_loader 14 | 15 | 16 | # modified from: https://github.com/microsoft/torchgeo 17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): 18 | """Abstract base class for datasets lacking geospatial information. 19 | This base class is designed for datasets with pre-defined image chips. 20 | """ 21 | 22 | @abc.abstractmethod 23 | def __getitem__(self, index: int) -> Dict[str, Any]: 24 | """Return an index within the dataset. 25 | Args: 26 | index: index to return 27 | Returns: 28 | data and labels at that index 29 | Raises: 30 | IndexError: if index is out of range of the dataset 31 | """ 32 | 33 | @abc.abstractmethod 34 | def __len__(self) -> int: 35 | """Return the length of the dataset. 36 | Returns: 37 | length of the dataset 38 | """ 39 | 40 | def __str__(self) -> str: 41 | """Return the informal string representation of the object. 
42 | Returns: 43 | informal string representation 44 | """ 45 | return f"""\ 46 | {self.__class__.__name__} Dataset 47 | type: VisionDataset 48 | size: {len(self)}""" 49 | 50 | 51 | class VisionClassificationDataset(VisionDataset, ImageFolder): 52 | """Abstract base class for classification datasets lacking geospatial information. 53 | This base class is designed for datasets with pre-defined image chips which 54 | are separated into separate folders per class. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | root: str, 60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 61 | loader: Optional[Callable[[str], Any]] = pil_loader, 62 | is_valid_file: Optional[Callable[[str], bool]] = None, 63 | ) -> None: 64 | """Initialize a new VisionClassificationDataset instance. 65 | Args: 66 | root: root directory where dataset can be found 67 | transforms: a function/transform that takes input sample and its target as 68 | entry and returns a transformed version 69 | loader: a callable function which takes as input a path to an image and 70 | returns a PIL Image or numpy array 71 | is_valid_file: A function that takes the path of an Image file and checks if 72 | the file is a valid file 73 | """ 74 | # When transform & target_transform are None, ImageFolder.__getitem__(index) 75 | # returns a PIL.Image and int for image and label, respectively 76 | super().__init__( 77 | root=root, 78 | transform=None, 79 | target_transform=None, 80 | loader=loader, 81 | is_valid_file=is_valid_file, 82 | ) 83 | 84 | # Must be set after calling super().__init__() 85 | self.transforms = transforms 86 | 87 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 88 | """Return an index within the dataset. 89 | Args: 90 | index: index to return 91 | Returns: 92 | data and label at that index 93 | """ 94 | image, label = self._load_image(index) 95 | 96 | if self.transforms is not None: 97 | return self.transforms(image), label 98 | 99 | return image, label 100 | 101 | def __len__(self) -> int: 102 | """Return the number of data points in the dataset. 103 | Returns: 104 | length of the dataset 105 | """ 106 | return len(self.imgs) 107 | 108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: 109 | """Load a single image and it's class label. 110 | Args: 111 | index: index to return 112 | Returns: 113 | the image 114 | the image class label 115 | """ 116 | img, label = ImageFolder.__getitem__(self, index) 117 | label = torch.tensor(label) 118 | return img, label 119 | 120 | 121 | class RESISC45Dataset(VisionClassificationDataset): 122 | """RESISC45 dataset. 123 | The `RESISC45 `_ 124 | dataset is a dataset for remote sensing image scene classification. 125 | Dataset features: 126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) 127 | * three spectral bands - RGB 128 | * 45 scene classes, 700 images per class 129 | * images extracted from Google Earth from over 100 countries 130 | * images conditions with high variability (resolution, weather, illumination) 131 | Dataset format: 132 | * images are three-channel jpgs 133 | Dataset classes: 134 | 0. airplane 135 | 1. airport 136 | 2. baseball_diamond 137 | 3. basketball_court 138 | 4. beach 139 | 5. bridge 140 | 6. chaparral 141 | 7. church 142 | 8. circular_farmland 143 | 9. cloud 144 | 10. commercial_area 145 | 11. dense_residential 146 | 12. desert 147 | 13. forest 148 | 14. freeway 149 | 15. golf_course 150 | 16. ground_track_field 151 | 17. harbor 152 | 18. industrial_area 153 | 19. intersection 154 | 20. 
island 155 | 21. lake 156 | 22. meadow 157 | 23. medium_residential 158 | 24. mobile_home_park 159 | 25. mountain 160 | 26. overpass 161 | 27. palace 162 | 28. parking_lot 163 | 29. railway 164 | 30. railway_station 165 | 31. rectangular_farmland 166 | 32. river 167 | 33. roundabout 168 | 34. runway 169 | 35. sea_ice 170 | 36. ship 171 | 37. snowberg 172 | 38. sparse_residential 173 | 39. stadium 174 | 40. storage_tank 175 | 41. tennis_court 176 | 42. terrace 177 | 43. thermal_power_station 178 | 44. wetland 179 | This dataset uses the train/val/test splits defined in the "In-domain representation 180 | learning for remote sensing" paper: 181 | * https://arxiv.org/abs/1911.06721 182 | If you use this dataset in your research, please cite the following paper: 183 | * https://doi.org/10.1109/jproc.2017.2675998 184 | """ 185 | 186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" 187 | # md5 = "d824acb73957502b00efd559fc6cfbbb" 188 | # filename = "NWPU-RESISC45.rar" 189 | directory = "resisc45/NWPU-RESISC45" 190 | 191 | splits = ["train", "val", "test"] 192 | split_urls = { 193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 196 | } 197 | split_md5s = { 198 | "train": "b5a4c05a37de15e4ca886696a85c403e", 199 | "val": "a0770cee4c5ca20b8c32bbd61e114805", 200 | "test": "3dda9e4988b47eb1de9f07993653eb08", 201 | } 202 | classes = [ 203 | "airplane", 204 | "airport", 205 | "baseball_diamond", 206 | "basketball_court", 207 | "beach", 208 | "bridge", 209 | "chaparral", 210 | "church", 211 | "circular_farmland", 212 | "cloud", 213 | "commercial_area", 214 | "dense_residential", 215 | "desert", 216 | "forest", 217 | "freeway", 218 | "golf_course", 219 | "ground_track_field", 220 | "harbor", 221 | "industrial_area", 222 | "intersection", 223 | "island", 224 | "lake", 225 | "meadow", 226 | "medium_residential", 227 | "mobile_home_park", 228 | "mountain", 229 | "overpass", 230 | "palace", 231 | "parking_lot", 232 | "railway", 233 | "railway_station", 234 | "rectangular_farmland", 235 | "river", 236 | "roundabout", 237 | "runway", 238 | "sea_ice", 239 | "ship", 240 | "snowberg", 241 | "sparse_residential", 242 | "stadium", 243 | "storage_tank", 244 | "tennis_court", 245 | "terrace", 246 | "thermal_power_station", 247 | "wetland", 248 | ] 249 | 250 | def __init__( 251 | self, 252 | root: str = "data", 253 | split: str = "train", 254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 255 | ) -> None: 256 | """Initialize a new RESISC45 dataset instance. 
257 | Args: 258 | root: root directory where dataset can be found 259 | split: one of "train", "val", or "test" 260 | transforms: a function/transform that takes input sample and its target as 261 | entry and returns a transformed version 262 | """ 263 | assert split in self.splits 264 | self.root = root 265 | 266 | valid_fns = set() 267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: 268 | for fn in f: 269 | valid_fns.add(fn.strip()) 270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename( 271 | x) in valid_fns 272 | 273 | super().__init__( 274 | root=os.path.join(root, self.directory), 275 | transforms=transforms, 276 | is_valid_file=is_in_split, 277 | ) 278 | 279 | 280 | 281 | class RESISC45: 282 | def __init__(self, 283 | preprocess, 284 | location=os.path.expanduser('~/data'), 285 | batch_size=32, 286 | num_workers=16): 287 | 288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) 289 | self.train_loader = torch.utils.data.DataLoader( 290 | self.train_dataset, 291 | shuffle=True, 292 | batch_size=batch_size, 293 | num_workers=num_workers, 294 | ) 295 | 296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) 297 | self.test_loader = torch.utils.data.DataLoader( 298 | self.test_dataset, 299 | batch_size=batch_size, 300 | num_workers=num_workers 301 | ) 302 | 303 | # class names have _ so split on this for better zero-shot head 304 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] 305 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class STL10: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16): 11 | 12 | location = os.path.join(location, 'stl10') 13 | self.train_dataset = datasets.STL10( 14 | root=location, 15 | download=True, 16 | split='train', 17 | transform=preprocess 18 | ) 19 | 20 | self.train_loader = torch.utils.data.DataLoader( 21 | self.train_dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = datasets.STL10( 28 | root=location, 29 | download=True, 30 | split='test', 31 | transform=preprocess 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers 39 | ) 40 | 41 | self.classnames = self.train_dataset.classes -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class SUN397: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=32, 10 | num_workers=16): 11 | # Data loading code 12 | traindir = os.path.join(location, 'SUN397_splits', 'train') 13 | valdir = os.path.join(location, 'SUN397_splits', 'val') 14 | 15 | self.train_dataset = datasets.SUN397(root=traindir, transform=preprocess) 16 | # self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 17 | self.train_loader = torch.utils.data.DataLoader( 18 | self.train_dataset, 19 | shuffle=True, 20 | 
batch_size=batch_size, 21 | num_workers=num_workers, 22 | ) 23 | self.test_dataset = datasets.SUN397(root=valdir, transform=preprocess) 24 | # self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 27 | batch_size=batch_size, 28 | num_workers=num_workers 29 | ) 30 | idx_to_class = dict((v, k) 31 | for k, v in self.train_dataset.class_to_idx.items()) 32 | self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] 33 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | 6 | 7 | class SVHN: 8 | def __init__(self, 9 | preprocess, 10 | location=os.path.expanduser('~/data'), 11 | batch_size=128, 12 | num_workers=16): 13 | 14 | # to fit with repo conventions for location 15 | modified_location = os.path.join(location, 'svhn') 16 | 17 | self.train_dataset = PyTorchSVHN( 18 | root=modified_location, 19 | download=True, 20 | split='train', 21 | transform=preprocess 22 | ) 23 | 24 | self.train_loader = torch.utils.data.DataLoader( 25 | self.train_dataset, 26 | batch_size=batch_size, 27 | shuffle=True, 28 | num_workers=num_workers 29 | ) 30 | 31 | self.test_dataset = PyTorchSVHN( 32 | root=modified_location, 33 | download=True, 34 | split='test', 35 | transform=preprocess 36 | ) 37 | 38 | self.test_loader = torch.utils.data.DataLoader( 39 | self.test_dataset, 40 | batch_size=batch_size, 41 | shuffle=False, 42 | num_workers=num_workers 43 | ) 44 | 45 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 46 | -------------------------------------------------------------------------------- /task_arithemetic/src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f'a photo of a {c}.', 3 | lambda c: f'a photo of the {c}.', 4 | lambda c: f'a photo of my {c}.', 5 | lambda c: f'i love my {c}!', 6 | lambda c: f'a photo of my dirty {c}.', 7 | lambda c: f'a photo of my clean {c}.', 8 | lambda c: f'a photo of my new {c}.', 9 | lambda c: f'a photo of my old {c}.', 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f'a photo of a {c}.', 14 | lambda c: f'a blurry photo of a {c}.', 15 | lambda c: f'a black and white photo of a {c}.', 16 | lambda c: f'a low contrast photo of a {c}.', 17 | lambda c: f'a high contrast photo of a {c}.', 18 | lambda c: f'a bad photo of a {c}.', 19 | lambda c: f'a good photo of a {c}.', 20 | lambda c: f'a photo of a small {c}.', 21 | lambda c: f'a photo of a big {c}.', 22 | lambda c: f'a photo of the {c}.', 23 | lambda c: f'a blurry photo of the {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a low contrast photo of the {c}.', 26 | lambda c: f'a high contrast photo of the {c}.', 27 | lambda c: f'a bad photo of the {c}.', 28 | lambda c: f'a good photo of the {c}.', 29 | lambda c: f'a photo of the small {c}.', 30 | lambda c: f'a photo of the big {c}.', 31 | ] 32 | 33 | cifar100_template = [ 34 | lambda c: f'a photo of a {c}.', 35 | lambda c: f'a blurry photo of a {c}.', 36 | lambda c: f'a black and white photo of a {c}.', 37 | lambda c: f'a low contrast photo of a {c}.', 38 | lambda c: f'a high contrast photo of a {c}.', 39 | lambda c: f'a bad photo of a {c}.', 40 | lambda 
c: f'a good photo of a {c}.', 41 | lambda c: f'a photo of a small {c}.', 42 | lambda c: f'a photo of a big {c}.', 43 | lambda c: f'a photo of the {c}.', 44 | lambda c: f'a blurry photo of the {c}.', 45 | lambda c: f'a black and white photo of the {c}.', 46 | lambda c: f'a low contrast photo of the {c}.', 47 | lambda c: f'a high contrast photo of the {c}.', 48 | lambda c: f'a bad photo of the {c}.', 49 | lambda c: f'a good photo of the {c}.', 50 | lambda c: f'a photo of the small {c}.', 51 | lambda c: f'a photo of the big {c}.', 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f'a photo of a {c} texture.', 56 | lambda c: f'a photo of a {c} pattern.', 57 | lambda c: f'a photo of a {c} thing.', 58 | lambda c: f'a photo of a {c} object.', 59 | lambda c: f'a photo of the {c} texture.', 60 | lambda c: f'a photo of the {c} pattern.', 61 | lambda c: f'a photo of the {c} thing.', 62 | lambda c: f'a photo of the {c} object.', 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f'a centered satellite photo of {c}.', 67 | lambda c: f'a centered satellite photo of a {c}.', 68 | lambda c: f'a centered satellite photo of the {c}.', 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f'a photo of {c}, a type of food.', 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f'a bad photo of a {c}.', 87 | lambda c: f'a photo of many {c}.', 88 | lambda c: f'a sculpture of a {c}.', 89 | lambda c: f'a photo of the hard to see {c}.', 90 | lambda c: f'a low resolution photo of the {c}.', 91 | lambda c: f'a rendering of a {c}.', 92 | lambda c: f'graffiti of a {c}.', 93 | lambda c: f'a bad photo of the {c}.', 94 | lambda c: f'a cropped photo of the {c}.', 95 | lambda c: f'a tattoo of a {c}.', 96 | lambda c: f'the embroidered {c}.', 97 | lambda c: f'a photo of a hard to see {c}.', 98 | lambda c: f'a bright photo of a {c}.', 99 | lambda c: f'a photo of a clean {c}.', 100 | lambda c: f'a photo of a dirty {c}.', 101 | lambda c: f'a dark photo of the {c}.', 102 | lambda c: f'a drawing of a {c}.', 103 | lambda c: f'a photo of my {c}.', 104 | lambda c: f'the plastic {c}.', 105 | lambda c: f'a photo of the cool {c}.', 106 | lambda c: f'a close-up photo of a {c}.', 107 | lambda c: f'a black and white photo of the {c}.', 108 | lambda c: f'a painting of the {c}.', 109 | lambda c: f'a painting of a {c}.', 110 | lambda c: f'a pixelated photo of the {c}.', 111 | lambda c: f'a sculpture of the {c}.', 112 | lambda c: f'a bright photo of the {c}.', 113 | lambda c: f'a cropped photo of a {c}.', 114 | lambda c: f'a plastic {c}.', 115 | lambda c: f'a photo of the dirty {c}.', 116 | lambda c: f'a jpeg corrupted photo of a {c}.', 117 | lambda c: f'a blurry photo of the {c}.', 118 | lambda c: f'a photo of the {c}.', 119 | lambda c: f'a good photo of the {c}.', 120 | lambda c: f'a rendering of the {c}.', 121 | lambda c: f'a {c} in a video game.', 122 | lambda c: f'a photo of one {c}.', 123 | lambda c: f'a doodle of a {c}.', 124 | lambda c: f'a close-up photo of the {c}.', 125 | lambda c: f'a photo of a {c}.', 126 | lambda c: f'the origami {c}.', 127 | lambda c: f'the {c} in a video game.', 128 | lambda c: f'a sketch of a {c}.', 129 | lambda c: f'a doodle of the {c}.', 130 | lambda c: f'a origami {c}.', 131 | lambda c: f'a low resolution 
photo of a {c}.', 132 | lambda c: f'the toy {c}.', 133 | lambda c: f'a rendition of the {c}.', 134 | lambda c: f'a photo of the clean {c}.', 135 | lambda c: f'a photo of a large {c}.', 136 | lambda c: f'a rendition of a {c}.', 137 | lambda c: f'a photo of a nice {c}.', 138 | lambda c: f'a photo of a weird {c}.', 139 | lambda c: f'a blurry photo of a {c}.', 140 | lambda c: f'a cartoon {c}.', 141 | lambda c: f'art of a {c}.', 142 | lambda c: f'a sketch of the {c}.', 143 | lambda c: f'a embroidered {c}.', 144 | lambda c: f'a pixelated photo of a {c}.', 145 | lambda c: f'itap of the {c}.', 146 | lambda c: f'a jpeg corrupted photo of the {c}.', 147 | lambda c: f'a good photo of a {c}.', 148 | lambda c: f'a plushie {c}.', 149 | lambda c: f'a photo of the nice {c}.', 150 | lambda c: f'a photo of the small {c}.', 151 | lambda c: f'a photo of the weird {c}.', 152 | lambda c: f'the cartoon {c}.', 153 | lambda c: f'art of the {c}.', 154 | lambda c: f'a drawing of the {c}.', 155 | lambda c: f'a photo of the large {c}.', 156 | lambda c: f'a black and white photo of a {c}.', 157 | lambda c: f'the plushie {c}.', 158 | lambda c: f'a dark photo of a {c}.', 159 | lambda c: f'itap of a {c}.', 160 | lambda c: f'graffiti of the {c}.', 161 | lambda c: f'a toy {c}.', 162 | lambda c: f'itap of my {c}.', 163 | lambda c: f'a photo of a cool {c}.', 164 | lambda c: f'a photo of a small {c}.', 165 | lambda c: f'a tattoo of the {c}.', 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f'satellite imagery of {c}.', 170 | lambda c: f'aerial imagery of {c}.', 171 | lambda c: f'satellite photo of {c}.', 172 | lambda c: f'aerial photo of {c}.', 173 | lambda c: f'satellite view of {c}.', 174 | lambda c: f'aerial view of {c}.', 175 | lambda c: f'satellite imagery of a {c}.', 176 | lambda c: f'aerial imagery of a {c}.', 177 | lambda c: f'satellite photo of a {c}.', 178 | lambda c: f'aerial photo of a {c}.', 179 | lambda c: f'satellite view of a {c}.', 180 | lambda c: f'aerial view of a {c}.', 181 | lambda c: f'satellite imagery of the {c}.', 182 | lambda c: f'aerial imagery of the {c}.', 183 | lambda c: f'satellite photo of the {c}.', 184 | lambda c: f'aerial photo of the {c}.', 185 | lambda c: f'satellite view of the {c}.', 186 | lambda c: f'aerial view of the {c}.', 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f'a photo of a {c}.', 191 | lambda c: f'a photo of the {c}.', 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f'a photo of a {c}.', 196 | lambda c: f'a photo of the {c}.', 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | 204 | dataset_to_template = { 205 | 'Cars': cars_template, 206 | 'CIFAR10': cifar10_template, 207 | 'CIFAR100': cifar100_template, 208 | 'DTD': dtd_template, 209 | 'EuroSAT': eurosat_template, 210 | 'Food101': food101_template, 211 | 'GTSRB': gtsrb_template, 212 | 'MNIST': mnist_template, 213 | 'ImageNet': imagenet_template, 214 | 'RESISC45': resisc45_template, 215 | 'STL10': stl10_template, 216 | 'SUN397': sun397_template, 217 | 'SVHN': svhn_template, 218 | } 219 | 220 | 221 | def get_templates(dataset_name): 222 | if dataset_name.endswith('Val'): 223 | return get_templates(dataset_name.replace('Val', '')) 224 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' 225 | return dataset_to_template[dataset_name] -------------------------------------------------------------------------------- /task_arithemetic/src/eval.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | 
5 | import torch
6 | import numpy as np
7 | 
8 | from src import utils
9 | from src.datasets.common import get_dataloader, maybe_dictionarize
10 | from src.heads import get_classification_head
11 | from src.modeling import ImageClassifier
12 | 
13 | from src.datasets.registry import get_dataset
14 | 
15 | 
16 | def eval_single_dataset(image_encoder, dataset_name, args):
17 |     classification_head = get_classification_head(args, dataset_name)
18 |     model = ImageClassifier(image_encoder, classification_head)
19 | 
20 |     model.eval()
21 | 
22 |     dataset = get_dataset(
23 |         dataset_name,
24 |         model.val_preprocess,
25 |         location=args.data_location,
26 |         batch_size=args.batch_size
27 |     )
28 |     dataloader = get_dataloader(
29 |         dataset, is_train=False, args=args, image_encoder=None)
30 |     device = args.device
31 | 
32 |     with torch.no_grad():
33 |         top1, correct, n = 0., 0., 0.
34 |         for i, data in enumerate(tqdm.tqdm(dataloader)):
35 |             data = maybe_dictionarize(data)
36 |             x = data['images'].to(device)
37 |             y = data['labels'].to(device)
38 | 
39 |             logits = utils.get_logits(x, model)
40 | 
41 |             pred = logits.argmax(dim=1, keepdim=True).to(device)
42 | 
43 |             correct += pred.eq(y.view_as(pred)).sum().item()
44 | 
45 |             n += y.size(0)
46 | 
47 |         top1 = correct / n
48 | 
49 |     metrics = {'top1': top1}
50 |     print(f'Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%')
51 | 
52 |     return metrics
53 | 
54 | def evaluate(image_encoder, args):
55 |     if args.eval_datasets is None:
56 |         return
57 |     info = vars(args)
58 |     for i, dataset_name in enumerate(args.eval_datasets):
59 |         print('Evaluating on', dataset_name)
60 | 
61 |         results = eval_single_dataset(image_encoder, dataset_name, args)
62 | 
63 |         if 'top1' in results:
64 |             print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
65 |         for key, val in results.items():
66 |             if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
67 |                 print(f"{dataset_name} {key}: {val:.4f}")
68 |             info[dataset_name + ':' + key] = val
69 | 
70 |     if args.results_db is not None:
71 |         dirname = os.path.dirname(args.results_db)
72 |         if dirname:
73 |             os.makedirs(dirname, exist_ok=True)
74 |         with open(args.results_db, 'a+') as f:
75 |             f.write(json.dumps(info) + '\n')
76 |         print(f'Results saved to {args.results_db}.')
77 |     else:
78 |         print('Results not saved (to do so, use --results-db to specify a path).')
79 | 
80 |     return info
--------------------------------------------------------------------------------
/task_arithemetic/src/finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | import torch
5 | 
6 | from src.args import parse_arguments
7 | from src.datasets.common import get_dataloader, maybe_dictionarize
8 | from src.datasets.registry import get_dataset
9 | from src.eval import evaluate
10 | from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier
11 | from src.utils import cosine_lr, LabelSmoothing
12 | from src.heads import get_classification_head
13 | 
14 | 
15 | import src.datasets as datasets
16 | 
17 | 
18 | def finetune(args):
19 |     train_dataset = args.train_dataset
20 |     ckpdir = os.path.join(args.save, train_dataset)
21 | 
22 |     # Check if checkpoints already exist (use the filenames this script saves below)
23 |     zs_path = os.path.join(args.save, train_dataset, 'zeroshot.pt')
24 |     ft_path = os.path.join(args.save, train_dataset, 'finetuned.pt')
25 |     if os.path.exists(zs_path) and 
os.path.exists(ft_path): 26 | print(f'Skipping fine-tuning because {ft_path} exists.') 27 | return zs_path, ft_path 28 | 29 | assert train_dataset is not None, "Please provide a training dataset." 30 | if args.load is not None and args.load.endswith('pt'): 31 | image_encoder = ImageEncoder.load(args.load) 32 | else: 33 | print('Building image encoder.') 34 | image_encoder = ImageEncoder(args, keep_lang=False) 35 | 36 | classification_head = get_classification_head(args, train_dataset) 37 | 38 | model = ImageClassifier(image_encoder, classification_head) 39 | 40 | model.freeze_head() 41 | 42 | preprocess_fn = model.train_preprocess 43 | print_every = 100 44 | 45 | dataset = get_dataset( 46 | train_dataset, 47 | preprocess_fn, 48 | location=args.data_location, 49 | batch_size=args.batch_size 50 | ) 51 | num_batches = len(dataset.train_loader) 52 | 53 | devices = list(range(torch.cuda.device_count())) 54 | print('Using devices', devices) 55 | model = torch.nn.DataParallel(model, device_ids=devices) 56 | 57 | if args.ls > 0: 58 | loss_fn = LabelSmoothing(args.ls) 59 | else: 60 | loss_fn = torch.nn.CrossEntropyLoss() 61 | 62 | params = [p for p in model.parameters() if p.requires_grad] 63 | optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) 64 | 65 | scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches) 66 | 67 | # Saving zero-shot model 68 | if args.save is not None: 69 | os.makedirs(ckpdir, exist_ok=True) 70 | model_path = os.path.join(ckpdir, f'zeroshot.pt') 71 | model.module.image_encoder.save(model_path) 72 | 73 | for epoch in range(args.epochs): 74 | model = model.cuda() 75 | model.train() 76 | data_loader = get_dataloader( 77 | dataset, is_train=True, args=args, image_encoder=None) 78 | 79 | for i, batch in enumerate(data_loader): 80 | start_time = time.time() 81 | 82 | step = i + epoch * num_batches 83 | scheduler(step) 84 | optimizer.zero_grad() 85 | 86 | batch = maybe_dictionarize(batch) 87 | inputs = batch['images'].to('cuda:0') 88 | labels = batch['labels'].to('cuda:0') 89 | data_time = time.time() - start_time 90 | 91 | logits = model(inputs) 92 | 93 | loss = loss_fn(logits, labels) 94 | 95 | loss.backward() 96 | 97 | torch.nn.utils.clip_grad_norm_(params, 1.0) 98 | 99 | optimizer.step() 100 | batch_time = time.time() - start_time 101 | 102 | if step % print_every == 0: 103 | percent_complete = 100 * i / len(data_loader) 104 | print( 105 | f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t" 106 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True 107 | ) 108 | 109 | # Evaluate 110 | image_encoder = model.module.image_encoder 111 | evaluate(image_encoder, args) 112 | 113 | if args.save is not None: 114 | zs_path = os.path.join(ckpdir, 'zeroshot.pt') 115 | ft_path = os.path.join(ckpdir, 'finetuned.pt') 116 | image_encoder.save(ft_path) 117 | return zs_path, ft_path 118 | 119 | 120 | if __name__ == '__main__': 121 | data_location = '' 122 | models = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14'] 123 | datasets = ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SUN397', 'SVHN'] 124 | epochs = { 125 | 'Cars': 35, 126 | 'DTD': 76, 127 | 'EuroSAT': 12, 128 | 'GTSRB': 11, 129 | 'MNIST': 5, 130 | 'RESISC45': 15, 131 | 'SUN397': 14, 132 | 'SVHN': 4, 133 | 'ImageNet': 4 134 | } 135 | 136 | for model in models: 137 | for dataset in datasets: 138 | print('='*100) 139 | print(f'Finetuning {model} on {dataset}') 140 | print('='*100) 141 | args = parse_arguments() 142 | 
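# Fine-tune every (model, dataset) pair; only the epoch budget varies per
# dataset (see the `epochs` dict above), the rest of the recipe is shared.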
args.lr = 1e-5 143 | args.epochs = epochs[dataset] 144 | args.data_location = data_location 145 | args.train_dataset = dataset + 'Val' 146 | args.batch_size = 128 147 | args.model = model 148 | args.save = f'checkpoints/{model}' 149 | finetune(args) 150 | -------------------------------------------------------------------------------- /task_arithemetic/src/heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | 5 | import open_clip 6 | 7 | from src.datasets.templates import get_templates 8 | from src.datasets.registry import get_dataset 9 | 10 | from src.modeling import ClassificationHead, ImageEncoder 11 | 12 | 13 | def build_classification_head(model, dataset_name, template, data_location, device): 14 | template = get_templates(dataset_name) 15 | 16 | logit_scale = model.logit_scale 17 | dataset = get_dataset( 18 | dataset_name, 19 | None, 20 | location=data_location 21 | ) 22 | model.eval() 23 | model.to(device) 24 | 25 | print('Building classification head.') 26 | with torch.no_grad(): 27 | zeroshot_weights = [] 28 | for classname in tqdm(dataset.classnames): 29 | texts = [] 30 | for t in template: 31 | texts.append(t(classname)) 32 | texts = open_clip.tokenize(texts).to(device) # tokenize 33 | embeddings = model.encode_text(texts) # embed with text encoder 34 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 35 | 36 | embeddings = embeddings.mean(dim=0, keepdim=True) 37 | embeddings /= embeddings.norm() 38 | 39 | zeroshot_weights.append(embeddings) 40 | 41 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 42 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 43 | 44 | zeroshot_weights *= logit_scale.exp() 45 | 46 | zeroshot_weights = zeroshot_weights.squeeze().float() 47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 48 | 49 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 50 | 51 | return classification_head 52 | 53 | 54 | def get_classification_head(args, dataset): 55 | filename = os.path.join(args.save, f'head_{dataset}.pt') 56 | if os.path.exists(filename): 57 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 58 | return ClassificationHead.load(filename) 59 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 60 | model = ImageEncoder(args, keep_lang=True).model 61 | template = get_templates(dataset) 62 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device) 63 | os.makedirs(args.save, exist_ok=True) 64 | classification_head.save(filename) 65 | return classification_head 66 | 67 | -------------------------------------------------------------------------------- /task_arithemetic/src/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import open_clip 4 | 5 | from src import utils 6 | 7 | 8 | class ImageEncoder(torch.nn.Module): 9 | def __init__(self, args, keep_lang=False): 10 | super().__init__() 11 | 12 | print(f'Loading {args.model} pre-trained weights.') 13 | if '__pretrained__' in args.model: 14 | name, pretrained = args.model.split('__pretrained__') 15 | else: 16 | name = args.model 17 | pretrained = 'openai' 18 | self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms( 19 | name, pretrained=pretrained, cache_dir=args.openclip_cachedir) 20 | 21 | 
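# open_clip returns the model together with its train- and val-time
# preprocessing transforms; all three are kept so downstream classifiers
# can reuse the matching preprocessing.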
self.cache_dir = args.cache_dir
 22 | 
 23 |         if not keep_lang and hasattr(self.model, 'transformer'):
 24 |             delattr(self.model, 'transformer')
 25 | 
 26 |     def forward(self, images):
 27 |         assert self.model is not None
 28 |         return self.model.encode_image(images)
 29 | 
 30 |     def __call__(self, inputs):
 31 |         return self.forward(inputs)
 32 | 
 33 |     def save(self, filename):
 34 |         print(f'Saving image encoder to {filename}')
 35 |         utils.torch_save(self, filename)
 36 | 
 37 |     @classmethod
 38 |     def load(cls, filename):
 39 |         print(f'Loading image encoder from {filename}')
 40 |         # save() stores the whole module via utils.torch_save, so load it back the same
 41 |         # way; this also matches the call site ImageEncoder.load(args.load) in finetune.py
 42 |         return utils.torch_load(filename)
 43 | 
 44 |     def load_from_state_dict(self, state_dict):
 45 |         """Load finetuned weights into the open_clip model built in __init__."""
 46 |         self.model.load_state_dict(state_dict)
 47 | 
 48 | 
 49 | 
 50 | 
 51 | 
 52 | class ClassificationHead(torch.nn.Linear):
 53 |     def __init__(self, normalize, weights, biases=None):
 54 |         output_size, input_size = weights.shape
 55 |         super().__init__(input_size, output_size)
 56 |         self.normalize = normalize
 57 |         if weights is not None:
 58 |             self.weight = torch.nn.Parameter(weights.clone())
 59 |         if biases is not None:
 60 |             self.bias = torch.nn.Parameter(biases.clone())
 61 |         else:
 62 |             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
 63 | 
 64 |     def forward(self, inputs):
 65 |         if self.normalize:
 66 |             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
 67 |         return super().forward(inputs)
 68 | 
 69 |     def __call__(self, inputs):
 70 |         return self.forward(inputs)
 71 | 
 72 |     def save(self, filename):
 73 |         print(f'Saving classification head to {filename}')
 74 |         utils.torch_save(self, filename)
 75 | 
 76 |     @classmethod
 77 |     def load(cls, filename):
 78 |         print(f'Loading classification head from {filename}')
 79 |         return utils.torch_load(filename)
 80 | 
 81 | 
 82 | class ImageClassifier(torch.nn.Module):
 83 |     def __init__(self, image_encoder, classification_head):
 84 |         super().__init__()
 85 |         self.image_encoder = image_encoder
 86 |         self.classification_head = classification_head
 87 |         if self.image_encoder is not None:
 88 |             self.train_preprocess = self.image_encoder.train_preprocess
 89 |             self.val_preprocess = self.image_encoder.val_preprocess
 90 | 
 91 |     def freeze_head(self):
 92 |         self.classification_head.weight.requires_grad_(False)
 93 |         self.classification_head.bias.requires_grad_(False)
 94 | 
 95 |     def forward(self, inputs):
 96 |         features = self.image_encoder(inputs)
 97 |         outputs = self.classification_head(features)
 98 |         return outputs
 99 | 
 100 |     def __call__(self, inputs):
 101 |         return self.forward(inputs)
 102 | 
 103 |     def save(self, filename):
 104 |         print(f'Saving image classifier to {filename}')
 105 |         utils.torch_save(self, filename)
 106 | 
 107 |     @classmethod
 108 |     def load(cls, filename):
 109 |         print(f'Loading image classifier from {filename}')
 110 |         return utils.torch_load(filename)
 111 | 
 112 | 
 113 | class MultiHeadImageClassifier(torch.nn.Module):
 114 |     def __init__(self, image_encoder, classification_heads):
 115 |         super().__init__()
 116 |         self.image_encoder = image_encoder
 117 |         self.classification_heads = torch.nn.ModuleList(classification_heads)
 118 |         if self.image_encoder is not None:
 119 |             self.train_preprocess = self.image_encoder.train_preprocess
 120 |             self.val_preprocess = self.image_encoder.val_preprocess
 121 | 
 122 |     def freeze_head(self):
 123 |         for idx in range(len(self.classification_heads)):
 124 |             self.classification_heads[idx].weight.requires_grad_(False)
 125 |             self.classification_heads[idx].bias.requires_grad_(False)
 126 | 
 127 |     def forward(self, inputs, head_idx):
 128 |         features = self.image_encoder(inputs)
 129 |         outputs = self.classification_heads[head_idx](features)
 130 |         return outputs
 131 | 
 132 |     def __call__(self, inputs, head_idx):
 133 |         return self.forward(inputs, head_idx)
 134 | 
 135 |     def save(self, filename):
 136 |         print(f'Saving image classifier to {filename}')
 137 |         utils.torch_save(self, filename)
 138 | 
 139 |     @classmethod
 140 |     def load(cls, filename):
 141 |         print(f'Loading image classifier from {filename}')
 142 |         return utils.torch_load(filename)
 143 | 
--------------------------------------------------------------------------------
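For orientation, here is a minimal sketch (not part of the repo) of how these classes compose; the checkpoint path is a hypothetical placeholder, and `args` comes from `src/args.py` defaults:

```python
# Hedged sketch: assembling a classifier from a saved encoder and a cached head.
import torch
from src.args import parse_arguments
from src.modeling import ImageEncoder, ImageClassifier
from src.heads import get_classification_head

args = parse_arguments()  # defaults from src/args.py
encoder = ImageEncoder.load('checkpoints/ViT-B-32/zeroshot.pt')  # whole-module load
head = get_classification_head(args, 'MNISTVal')  # built once, then loaded from cache
model = ImageClassifier(encoder, head)
model.freeze_head()  # only the encoder is updated during finetuning

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # dummy batch; shape (1, num_classes)
```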
/task_arithemetic/src/task_vectors.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | class TaskVector():
 5 |     def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
 6 |         """Initializes the task vector from a pretrained and a finetuned checkpoint.
 7 | 
 8 |         This can either be done by passing two state dicts (one corresponding to the
 9 |         pretrained model, and another to the finetuned model), or by directly passing in
 10 |         the task vector state dict.
 11 |         """
 12 |         if vector is not None:
 13 |             self.vector = vector
 14 |         else:
 15 |             assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
 16 |             with torch.no_grad():
 17 |                 pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict()
 18 |                 finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict()
 19 |                 self.vector = {}
 20 |                 for key in pretrained_state_dict:
 21 |                     if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
 22 |                         continue
 23 |                     self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]
 24 | 
 25 |     def __add__(self, other):
 26 |         """Add two task vectors together."""
 27 |         with torch.no_grad():
 28 |             new_vector = {}
 29 |             for key in self.vector:
 30 |                 if key not in other.vector:
 31 |                     print(f'Warning: key {key} is not present in both task vectors.')
 32 |                     continue
 33 |                 new_vector[key] = self.vector[key] + other.vector[key]
 34 |         return TaskVector(vector=new_vector)
 35 | 
 36 |     def __radd__(self, other):
 37 |         if other is None or isinstance(other, int):
 38 |             return self
 39 |         return self.__add__(other)
 40 | 
 41 |     def __neg__(self):
 42 |         """Negate a task vector."""
 43 |         with torch.no_grad():
 44 |             new_vector = {}
 45 |             for key in self.vector:
 46 |                 new_vector[key] = - self.vector[key]
 47 |         return TaskVector(vector=new_vector)
 48 | 
 49 |     def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
 50 |         """Apply a task vector to a pretrained model."""
 51 |         with torch.no_grad():
 52 |             pretrained_model = torch.load(pretrained_checkpoint)
 53 |             new_state_dict = {}
 54 |             pretrained_state_dict = pretrained_model.state_dict()
 55 |             for key in pretrained_state_dict:
 56 |                 if key not in self.vector:
 57 |                     print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
 58 |                     continue
 59 |                 new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
 60 |         pretrained_model.load_state_dict(new_state_dict, strict=False)
 61 |         return pretrained_model
 62 | 
 63 | 
--------------------------------------------------------------------------------
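A minimal sketch (not part of the repo) of the task arithmetic this class enables; the checkpoint paths are hypothetical placeholders:

```python
# Hedged sketch: addition and negation of task vectors.
from src.task_vectors import TaskVector

zs = 'checkpoints/ViT-B-32/zeroshot.pt'
tv_mnist = TaskVector(zs, 'checkpoints/ViT-B-32/MNISTVal/finetuned.pt')
tv_svhn = TaskVector(zs, 'checkpoints/ViT-B-32/SVHNVal/finetuned.pt')

# addition builds a multi-task encoder; negation forgets a task
multi_task = (tv_mnist + tv_svhn).apply_to(zs, scaling_coef=0.5)
forget_mnist = (-tv_mnist).apply_to(zs, scaling_coef=1.0)

# __radd__ accepts the int start value of sum(), so plain sum() also works
combined = sum([tv_mnist, tv_svhn]).apply_to(zs, scaling_coef=0.5)
```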
/task_arithemetic/src/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzp1012/Cross-Task-Linearity/0092799a5b74718673d18f3a4d14ac52d19b03eb/task_arithemetic/src/utils/.DS_Store
--------------------------------------------------------------------------------
/task_arithemetic/src/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from src.utils.utils import *
 2 | 
 3 | import os
 4 | import datetime
 5 | import argparse
 6 | import logging
 7 | import random
 8 | import numpy as np
 9 | import torch
 10 | 
 11 | def get_datetime() -> str:
 12 |     """get the current datetime string.
 13 |     Returns:
 14 |         datetime_ (str): the datetime, formatted as "%m%d-%H%M%S".
 15 |     """
 16 |     datetime_ = datetime.datetime.now().strftime("%m%d-%H%M%S")
 17 |     return datetime_
 18 | 
 19 | 
 20 | def set_logger(save_path: str) -> None:
 21 |     """set the logger.
 22 |     Args:
 23 |         save_path(str): the path for saving logfile.txt
 24 | 
 25 |     A console handler echoing to stderr is attached as well.
 26 | 
 27 |     Returns:
 28 |         None
 29 |     """
 30 |     # set the logger
 31 |     logfile = os.path.join(save_path, "logfile.txt")
 32 |     logging.basicConfig(filename=logfile,
 33 |                         filemode="w+",
 34 |                         format='%(name)-12s: %(levelname)-8s %(message)s',
 35 |                         datefmt="%H:%M:%S",
 36 |                         level=logging.INFO)
 37 |     # define a Handler which writes DEBUG messages or higher to the sys.stderr
 38 |     console = logging.StreamHandler()
 39 |     console.setLevel(logging.DEBUG)
 40 |     # tell the handler to use this format
 41 |     console.setFormatter(logging.Formatter(
 42 |         '%(name)-12s: %(levelname)-8s %(message)s'))
 43 |     # add the handler to the root logger
 44 |     logging.getLogger().addHandler(console)
 45 | 
 46 | 
 47 | def get_logger(name:str,
 48 |                verbose:bool = True) -> logging.Logger:
 49 |     """get the logger.
 50 |     Args:
 51 |         name (str): the name of the logger
 52 |         verbose (bool): if true, keep DEBUG messages; otherwise only INFO and above.
 53 |     Returns:
 54 |         logger (logging.Logger)
 55 |     """
 56 |     logger = logging.getLogger(name)
 57 | 
 58 |     logger.setLevel(logging.DEBUG)
 59 |     if not verbose:
 60 |         logger.setLevel(logging.INFO)
 61 |     return logger
 62 | 
 63 | 
 64 | def set_seed(seed: int = 0) -> None:
 65 |     """set the random seed for multiple packages.
 66 |     Args:
 67 |         seed (int): the seed.
 68 | 
 69 |     Returns:
 70 |         None
 71 |     """
 72 |     random.seed(seed)
 73 |     os.environ['PYTHONHASHSEED'] = str(seed)
 74 |     np.random.seed(seed)
 75 |     torch.manual_seed(seed)
 76 |     torch.cuda.manual_seed(seed)
 77 |     torch.backends.cudnn.deterministic = True
 78 | 
 79 | 
 80 | def set_device(device: int) -> torch.device:
 81 |     """set GPU device.
 82 |     Args:
 83 |         device (int): the GPU device ordinal
 84 | 
 85 |     Returns:
 86 |         device (torch.device)
 87 |     """
 88 |     logger = get_logger(__name__)
 89 |     if torch.cuda.is_available():
 90 |         if device >= torch.cuda.device_count():
 91 |             logger.error("CUDA error, invalid device ordinal")
 92 |             exit(1)
 93 |     else:
 94 |         logger.error("No GPU available; please run on a machine with CUDA support.")
 95 |         exit(1)
 96 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(device)
 97 |     device = torch.device("cuda:" + str(device))
 98 |     logger.info(device)
 99 |     return device
 100 | 
 101 | 
 102 | def log_settings(args: argparse.Namespace, config: dict = {}) -> None:
 103 |     """log the settings of the program.
 104 |     Args:
 105 |         args (argparse.Namespace): the arguments.
 106 |         config (dict): the config.
107 | """ 108 | logger = get_logger(__name__) 109 | hyperparameters = { 110 | **args.__dict__, 111 | **{key: value for key, value in config.items() \ 112 | if key.isupper() and type(value) in [int, float, str, bool, dict]} 113 | } 114 | logger.info(hyperparameters) 115 | 116 | 117 | def save_current_src(save_path: str) -> None: 118 | """save the current src. 119 | Args: 120 | save_path (str): the path to save the current src. 121 | src_path (str): the path to the current src. 122 | Returns: 123 | None 124 | """ 125 | logger = get_logger(__name__) 126 | logger.info("save the current src") 127 | src_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 128 | os.system("cp -r {} {}".format(src_path, save_path)) 129 | script_path = os.path.join(os.path.dirname(src_path), "scripts") 130 | os.system("cp -r {} {}".format(script_path, save_path)) -------------------------------------------------------------------------------- /task_arithemetic/src/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from collections import defaultdict 4 | from typing import Optional 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value. 8 | 9 | Examples:: 10 | >>> # Initialize a meter to record loss 11 | >>> losses = AverageMeter() 12 | >>> # Update meter after every minibatch update 13 | >>> losses.update(loss_value, batch_size) 14 | """ 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | class MetricMeter(object): 32 | """A collection of metrics. 33 | 34 | Source: https://github.com/KaiyangZhou/Dassl.pytorch 35 | 36 | Examples:: 37 | >>> # 1. Create an instance of MetricMeter 38 | >>> metric = MetricMeter() 39 | >>> # 2. Update using a dictionary as input 40 | >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} 41 | >>> metric.update(input_dict) 42 | >>> # 3. Convert to string and print 43 | >>> print(str(metric)) 44 | """ 45 | def __init__(self, delimiter='\n\t'): 46 | self.delimiter = delimiter 47 | self.reset() 48 | 49 | def reset(self): 50 | self.meters = defaultdict(AverageMeter) 51 | 52 | def update(self, input_dict: dict, n: int=1): 53 | """Update the meter with a dictionary. 54 | 55 | Args: 56 | input_dict (dict): A dictionary of metrics. 57 | n (int): The number of samples in the input. 
58 | """ 59 | if input_dict is None: 60 | return 61 | 62 | if not isinstance(input_dict, dict): 63 | raise TypeError( 64 | 'Input to MetricMeter.update() must be a dictionary' 65 | ) 66 | 67 | for k, v in input_dict.items(): 68 | if isinstance(v, torch.Tensor): 69 | v = v.item() 70 | self.meters[k].update(v, n) 71 | 72 | def __str__(self): 73 | output_str = [] 74 | for name, meter in self.meters.items(): 75 | output_str.append( 76 | '{} {:.4f} ({:.4f})'.format(name, meter.val, meter.avg) 77 | ) 78 | return self.delimiter + self.delimiter.join(output_str) 79 | 80 | 81 | class MetricTracker(object): 82 | """track metrics over time and compute average 83 | """ 84 | def __init__(self): 85 | self.reset() 86 | 87 | def reset(self): 88 | """reset metrics 89 | """ 90 | self.metrics = defaultdict(list) 91 | self.meter = MetricMeter() 92 | 93 | def update(self, input_dict: dict, n: int=1): 94 | """update metrics 95 | 96 | Args: 97 | input_dict (dict): A dictionary of metrics. 98 | n (int): The number of samples in the input. 99 | """ 100 | self.meter.update(input_dict, n) 101 | 102 | def track(self, input_dict: Optional[dict]=None): 103 | """track metrics 104 | """ 105 | if input_dict is not None: 106 | for k, v in input_dict.items(): 107 | assert k not in self.meter.meters, \ 108 | f'key {k} already exists in meter.keys() {self.meter.meters.keys()}' 109 | self.metrics[k].append(v) 110 | 111 | for k, v in self.meter.meters.items(): 112 | self.metrics[k].append(v.avg) 113 | self.meter.reset() 114 | 115 | def __str__(self): 116 | return str(self.meter) 117 | 118 | def save_to_csv(self, filename: str): 119 | """save metrics to csv file 120 | """ 121 | df = pd.DataFrame.from_dict(self.metrics) 122 | df.to_csv(filename, index=False) 123 | 124 | def get_metrics(self): 125 | """get metrics 126 | """ 127 | return self.metrics -------------------------------------------------------------------------------- /task_arithemetic/src/utils/dissimilarity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DissimilarityMetric: 4 | """compute the distance of two featuremaps 5 | """ 6 | def __init__(self, metric): 7 | self.__metric = metric 8 | 9 | def __call__(self, A, B, **kwargs): 10 | if self.__metric == "vanilla": 11 | return self.__vanilla(A, B) 12 | elif self.__metric == "cosine": 13 | return self.__cosine_similarity(A, B, **kwargs) 14 | elif self.__metric == "abs_div": 15 | return self.__abs_div(A, B, **kwargs) 16 | 17 | A = self.__preprocess(A) 18 | B = self.__preprocess(B) 19 | if self.__metric == "lin_cka": 20 | return self.__lin_cka_dist(A, B) 21 | elif self.__metric == "lin_cka_prime": 22 | return self.__lin_cka_prime_dist(A, B) 23 | elif self.__metric == "procrustes": 24 | return self.__procrustes(A, B) 25 | else: 26 | raise ValueError("Unknown metric") 27 | 28 | def __preprocess(self, X: torch.Tensor) -> torch.Tensor: 29 | """preprocess the featuremap 30 | 1. flatten the featuremap 31 | 2. transpose the featuremap 32 | 3. center the featuremap 33 | 4. normalize the featuremap 34 | 35 | Args: 36 | X (torch.Tensor): the featuremap, shape: (N, ...) 
 37 | 
 38 |         Return:
 39 |             X (torch.Tensor): the preprocessed featuremap, shape: (D, N)
 40 |         """
 41 |         # flatten the featuremap
 42 |         X = X.view(X.shape[0], -1) # shape: (N, D)
 43 |         # transpose the featuremap
 44 |         X = X.t() # shape: (D, N)
 45 |         # centering
 46 |         X = X - X.mean(dim=-1, keepdim=True) # shape: (D, N)
 47 |         # normalize with the Frobenius norm
 48 |         X = X / torch.norm(X, p="fro") # shape: (D, N)
 49 |         return X
 50 | 
 51 |     def __lin_cka_dist(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 52 |         """compute the linear CKA distance between two featuremaps
 53 |         Args:
 54 |             A (torch.Tensor): the featuremap A, shape: (D, N)
 55 |             B (torch.Tensor): the featuremap B, shape: (D', N)
 56 | 
 57 |         Return:
 58 |             dist (torch.Tensor): the distance between A and B
 59 |         """
 60 |         similarity = torch.norm(B @ A.t(), p="fro") ** 2
 61 |         normalization = torch.norm(A @ A.t(), p="fro") * torch.norm(B @ B.t(), p="fro")
 62 |         return 1 - similarity / normalization
 63 | 
 64 |     def __lin_cka_prime_dist(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 65 |         """Computes Linear CKA prime distance between representations A and B
 66 |         The version here is suited to D, D' >> N
 67 | 
 68 |         Args:
 69 |             A (torch.Tensor): the featuremap A, shape: (D, N)
 70 |             B (torch.Tensor): the featuremap B, shape: (D', N)
 71 | 
 72 |         Return:
 73 |             dist (torch.Tensor): the distance between A and B
 74 |         """
 75 |         if A.shape[0] > A.shape[1]: # D > N
 76 |             At_A = A.t() @ A # shape: (N, N) O(n * n * a)
 77 |             Bt_B = B.t() @ B # shape: (N, N) O(n * n * a)
 78 |             numerator = torch.sum((At_A - Bt_B) ** 2)
 79 |             denominator = torch.sum(A ** 2) ** 2 + torch.sum(B ** 2) ** 2
 80 |             return numerator / denominator
 81 |         else:
 82 |             similarity = torch.norm(B @ A.t(), p="fro") ** 2
 83 |             denominator = torch.sum(A ** 2) ** 2 + torch.sum(B ** 2) ** 2
 84 |             return 1 - 2 * similarity / denominator
 85 | 
 86 |     def __procrustes(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 87 |         """Compute the Procrustes distance between two featuremaps
 88 |         Args:
 89 |             A (torch.Tensor): the featuremap A, shape: (D, N)
 90 |             B (torch.Tensor): the featuremap B, shape: (D', N)
 91 | 
 92 |         Return:
 93 |             dist (torch.Tensor): the distance between A and B
 94 |         """
 95 |         A_sq_frob = torch.sum(A ** 2)
 96 |         B_sq_frob = torch.sum(B ** 2)
 97 |         nuc = torch.norm(A @ B.t(), p="nuc") # O(p * p * n)
 98 |         return A_sq_frob + B_sq_frob - 2 * nuc
 99 | 
 100 |     def __vanilla(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 101 |         """compute the vanilla distance between two featuremaps, Frobenius norm
 102 | 
 103 |         Args:
 104 |             A (torch.Tensor): the featuremap A, shape: (N, D)
 105 |             B (torch.Tensor): the featuremap B, shape: (N, D)
 106 |         """
 107 |         assert A.shape == B.shape, \
 108 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 109 |         A = A.view(A.shape[0], -1) # shape: (N, D)
 110 |         B = B.view(B.shape[0], -1) # shape: (N, D)
 111 | 
 112 |         def norm_square(A: torch.Tensor) -> torch.Tensor:
 113 |             return torch.sum(A ** 2) # shape: (1, )
 114 | 
 115 |         return norm_square(A - B) / torch.norm(A, p="fro") / torch.norm(B, p="fro")
 116 | 
 117 |     def __abs_div(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor:
 118 |         get_coef = kwargs.get("get_coef", False)
 119 | 
 120 |         assert A.shape == B.shape, \
 121 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 122 |         A = A.view(A.shape[0], -1).double() # shape: (N, D)
 123 |         B = B.view(B.shape[0], -1).double() # shape: (N, D)
 124 | 
 125 |         dist = torch.sum(torch.abs(A)) / torch.sum(torch.abs(B))
 126 | 
 127 |         coef = 0.
 128 | 
 129 |         if get_coef:
 130 |             return dist, coef
 131 |         else:
 132 |             return dist
 133 | 
 134 |     def __cosine_similarity(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor:
 135 |         """compute the cosine similarity between two matrices
 136 | 
 137 |         dist = 1 - |<A, B>| / (||A|| * ||B||)
 138 | 
 139 |         Args:
 140 |             A (torch.Tensor): the featuremap A, shape: (N, D)
 141 |             B (torch.Tensor): the featuremap B, shape: (N, D)
 142 |             kwargs: the keyword arguments
 143 | 
 144 |         Return:
 145 |             dist (torch.Tensor): the distance between A and B
 146 |         """
 147 |         get_coef = kwargs.get("get_coef", False)
 148 | 
 149 |         assert A.shape == B.shape, \
 150 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 151 | 
 152 |         A = A.view(A.shape[0], -1).double() # shape: (N, D)
 153 |         B = B.view(B.shape[0], -1).double() # shape: (N, D)
 154 | 
 155 |         # compute the frobenius inner product of A and B
 156 |         inner_product = torch.sum(A * B) # shape: (1, 1)
 157 |         # compute the frobenius norm of A and B
 158 |         A_norm = torch.norm(A, p="fro") # shape: (1, 1)
 159 |         B_norm = torch.norm(B, p="fro") # shape: (1, 1)
 160 | 
 161 |         # calculate the distance
 162 |         dist = 1 - torch.abs(inner_product) / (A_norm * B_norm)
 163 | 
 164 |         # compute the coefficient
 165 |         coef = inner_product / (B_norm ** 2)
 166 | 
 167 |         assert torch.abs(inner_product) <= A_norm * B_norm * (1 + 1e-10), \
 168 |             f"the inner product - {inner_product} should be less than the product of the norm - {A_norm * B_norm}"
 169 | 
 170 |         if get_coef:
 171 |             return dist, coef
 172 |         else:
 173 |             return dist
 174 | 
 175 | 
 176 | class DissimilarityMetricOverSamples:
 177 |     """compute the distance of two featuremaps, for each sample
 178 |     """
 179 |     def __init__(self, metric):
 180 |         self.__metric = metric
 181 | 
 182 |     def __call__(self, A, B, **kwargs):
 183 |         if self.__metric == "vanilla":
 184 |             return self.__vanilla(A, B)
 185 |         elif self.__metric == "cosine":
 186 |             return self.__cosine_similarity(A, B, **kwargs)
 187 |         elif self.__metric == "abs_div":
 188 |             return self.__abs_div(A, B, **kwargs)
 189 | 
 190 |     def __vanilla(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 191 |         """compute the vanilla distance between two featuremaps, Frobenius norm
 192 | 
 193 |         Args:
 194 |             A (torch.Tensor): the featuremap A, shape: (N, D)
 195 |             B (torch.Tensor): the featuremap B, shape: (N, D)
 196 |         """
 197 |         assert A.shape == B.shape, \
 198 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 199 |         A = A.view(A.shape[0], -1) # shape: (N, D)
 200 |         B = B.view(B.shape[0], -1) # shape: (N, D)
 201 | 
 202 |         def norm_square(A: torch.Tensor) -> torch.Tensor:
 203 |             return torch.sum(A ** 2, dim=-1) # shape: (N, )
 204 |         return norm_square(A - B) / torch.norm(A, p="fro", dim=-1) / torch.norm(B, p="fro", dim=-1) # shape: (N, )
 205 | 
 206 |     def __abs_div(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor:
 207 |         get_coef = kwargs.get("get_coef", False)
 208 | 
 209 |         assert A.shape == B.shape, \
 210 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 211 |         A = A.view(A.shape[0], -1).double() # shape: (N, D)
 212 |         B = B.view(B.shape[0], -1).double() # shape: (N, D)
 213 | 
 214 |         dist = torch.sum(torch.abs(A), dim=-1) / torch.sum(torch.abs(B), dim=-1)
 215 | 
 216 |         coef = 0.
 217 | 
 218 |         if get_coef:
 219 |             return dist, coef
 220 |         else:
 221 |             return dist
 222 | 
 223 |     def __cosine_similarity(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor:
 224 |         """compute the cosine similarity between two matrices
 225 | 
 226 |         dist = 1 - |<A, B>| / (||A|| * ||B||)
 227 | 
 228 |         Args:
 229 |             A (torch.Tensor): the featuremap A, shape: (N, D)
 230 |             B (torch.Tensor): the featuremap B, shape: (N, D)
 231 |             kwargs: the keyword arguments
 232 | 
 233 |         Return:
 234 |             dist (torch.Tensor): the distance between A and B
 235 |         """
 236 |         get_coef = kwargs.get("get_coef", False)
 237 | 
 238 |         assert A.shape == B.shape, \
 239 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
 240 | 
 241 |         A = A.view(A.shape[0], -1).double() # shape: (N, D)
 242 |         B = B.view(B.shape[0], -1).double() # shape: (N, D)
 243 | 
 244 |         # compute the frobenius inner product of A and B
 245 |         inner_product = torch.sum(A * B, dim=-1) # shape: (N, )
 246 |         # compute the frobenius norm of A and B
 247 |         A_norm = torch.norm(A, p="fro", dim=-1) # shape: (N, )
 248 |         B_norm = torch.norm(B, p="fro", dim=-1) # shape: (N, )
 249 | 
 250 |         # calculate the distance
 251 |         dist = 1 - torch.abs(inner_product) / (A_norm * B_norm) # shape: (N, )
 252 | 
 253 |         # compute the coefficient
 254 |         coef = inner_product / (B_norm ** 2) # shape: (N, )
 255 | 
 256 |         if get_coef:
 257 |             return dist, coef
 258 |         else:
 259 |             return dist
 260 | 
 261 | 
 262 | if __name__ == "__main__":
 263 |     # set the seed
 264 |     torch.manual_seed(0)
 265 |     A = torch.randn(2, 3, 4, 5)
 266 |     B = torch.randn(2, 3, 4, 5)
 267 | 
 268 |     metric = DissimilarityMetric("cosine")
 269 |     print(metric(A, B))
--------------------------------------------------------------------------------
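These metrics are what the CTL experiments plug interpolated features into. A minimal, self-contained sketch (not part of the repo) of the check, with random tensors standing in for real per-layer featuremaps:

```python
# Hedged sketch: the cross-task linearity check these metrics support.
import torch
from src.utils.dissimilarity import DissimilarityMetric

alpha, beta = 0.5, 0.5
f_A = torch.randn(64, 512)  # features of finetuned model A at some layer
f_B = torch.randn(64, 512)  # features of finetuned model B at the same layer
# stand-in for features of the weight-interpolated model (nearly linear here)
f_alpha = alpha * f_A + beta * f_B + 0.01 * torch.randn(64, 512)

metric = DissimilarityMetric("cosine")
# CTL holds at a layer when this distance is close to zero
dist = metric(f_alpha, alpha * f_A + beta * f_B)
print(dist.item())
```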
/task_arithemetic/src/utils/featuremap.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | from torch.utils.data import DataLoader
 4 | from functools import reduce
 5 | from collections import defaultdict
 6 | from tqdm import tqdm
 7 | 
 8 | # import internal libs
 9 | from src.utils import get_logger, get_logits
 10 | from src.datasets.common import maybe_dictionarize
 11 | 
 12 | class FeatureMap:
 13 |     """class used to extract feature map from a given model
 14 |     """
 15 |     def __init__(self,
 16 |                  device: torch.device,
 17 |                  model: nn.Module,) -> None:
 18 |         """initialize the feature map extractor
 19 | 
 20 |         Args:
 21 |             device (torch.device): device to run the model
 22 |             model (nn.Module): model to extract feature map
 23 |         """
 24 |         self.__device = device
 25 |         self.__model = model.to(device)
 26 |         self.__model_name = model.__class__.__name__
 27 | 
 28 |         # make sure the model is in eval mode
 29 |         self.__model.eval()
 30 |         # initialize the hooks
 31 |         self.__hooks = {}
 32 |         self.__featuremaps = defaultdict(list)
 33 | 
 34 |     def __get_module(self, module_name: str) -> nn.Module:
 35 |         """get the module from the model
 36 | 
 37 |         Args:
 38 |             module_name (str): name of the module
 39 | 
 40 |         Returns:
 41 |             nn.Module: the module
 42 |         """
 43 |         return reduce(getattr, module_name.split('.'), self.__model)
 44 | 
 45 |     def __get_conv_fc_layer_names(self) -> list:
 46 |         """get the names of the conv layers and the fc layers
 47 | 
 48 |         Returns:
 49 |             list: names of the conv and fc layers
 50 |         """
 51 |         layer_names = []
 52 |         for name, module in self.__model.named_modules():
 53 |             if isinstance(module, (nn.Conv2d, nn.Linear)):
 54 |                 layer_names.append(name)
 55 |         return layer_names
 56 | 
 57 |     def __get_conv_fc_modules(self, layer_names: list = None) -> dict:
 58 |         """get the conv and fc modules
 59 | 
 60 |         Args:
 61 |             layer_names (list, optional): names of the layers. Defaults to None.
 62 | 
 63 |         Returns:
 64 |             dict: conv and fc modules
 65 |         """
 66 |         if layer_names is None:
 67 |             layer_names = self.__get_conv_fc_layer_names()
 68 |         modules = dict()
 69 |         for name in layer_names:
 70 |             modules[name] = self.__get_module(name)
 71 |         return modules
 72 | 
 73 |     def __register_single_hook(self,
 74 |                                layer_name: str,
 75 |                                module: nn.Module,
 76 |                                get_input: bool) -> None:
 77 |         """register the hook for a single module
 78 | 
 79 |         Args:
 80 |             layer_name (str): name of the layer
 81 |             module (nn.Module): module to register the hook
 82 |             get_input (bool): whether to capture the module input instead of its output
 83 |         """
 84 |         def forward_hook(module, input, output):
 85 |             if get_input:
 86 |                 assert len(input) == 1, "only support single input"
 87 |                 self.__featuremaps[layer_name].append(input[0].detach().cpu())
 88 |             else:
 89 |                 if isinstance(output, tuple):
 90 |                     self.__featuremaps[layer_name].append(output[0].detach().cpu())
 91 |                 else:
 92 |                     self.__featuremaps[layer_name].append(output.detach().cpu())
 93 |         self.__hooks[layer_name] = module.register_forward_hook(forward_hook)
 94 | 
 95 |     def __register_hooks(self,
 96 |                          layer_names: list = None,
 97 |                          get_input: bool = False) -> None:
 98 |         """register the hooks
 99 | 
 100 |         Args:
 101 |             layer_names (list, optional): names of the layers to register the hooks. Defaults to None.
 102 |             get_input (bool, optional): whether to capture module inputs instead of outputs. Defaults to False.
 103 |         """
 104 |         conv_modules = self.__get_conv_fc_modules(layer_names)
 105 |         for name, module in conv_modules.items():
 106 |             self.__register_single_hook(name, module, get_input)
 107 | 
 108 |     def __remove_hooks(self) -> None:
 109 |         """remove the hooks
 110 |         """
 111 |         for hook in self.__hooks.values():
 112 |             hook.remove()
 113 |         self.__hooks = {}
 114 | 
 115 |     def __remove_featuremaps(self) -> None:
 116 |         """remove the feature maps
 117 |         """
 118 |         self.__featuremaps = defaultdict(list)
 119 | 
 120 |     def __convert_featuremaps_to_tensor(self) -> dict:
 121 |         """convert the feature maps to tensor
 122 | 
 123 |         Returns:
 124 |             dict: feature maps
 125 |         """
 126 |         featuremaps = dict()
 127 |         for key, value in self.__featuremaps.items():
 128 |             if "visual.transformer.resblocks" in key and len(value[0].shape) == 3:
 129 |                 value = [value[i].permute(1, 0, 2) for i in range(len(value))]
 130 |             featuremaps[key] = torch.cat(value, dim=0)
 131 |         return featuremaps
 132 | 
 133 |     def get_featuremaps(self,
 134 |                         dataloader: DataLoader,
 135 |                         layer_names: list = None,
 136 |                         get_input: bool = False) -> dict:
 137 |         """get the feature maps from the conv and fc layers
 138 |         tips: use the hook to get the feature maps
 139 | 
 140 |         Args:
 141 |             dataloader (DataLoader): data loader to load the data
 142 |             layer_names (list): names of the layers to extract feature map. Defaults to None.
 143 |             get_input (bool): whether to get the input of the model. Defaults to False.
144 | 145 | Returns: 146 | (feature maps, preds): feature maps and preds 147 | """ 148 | logger = get_logger( 149 | f"{__name__}.{self.__class__.__name__}.get_featuremaps" 150 | ) 151 | 152 | # clear the feature maps 153 | self.__remove_featuremaps() 154 | 155 | # register the hooks 156 | self.__register_hooks(layer_names, get_input) 157 | 158 | # extract the feature maps 159 | with torch.no_grad(): 160 | features_lst, logits_lst = [], [] 161 | for _, data in enumerate(tqdm(dataloader)): 162 | # move the data to the device 163 | data = maybe_dictionarize(data) 164 | x = data['images'].to(self.__device) 165 | # forward 166 | features = self.__model.image_encoder(x) 167 | logits = self.__model.classification_head(features) 168 | # get the logits 169 | logits_lst.append(logits.detach().cpu()) 170 | # get the features 171 | features_lst.append(features.detach().cpu()) 172 | 173 | # concat the features & logits 174 | features = torch.cat(features_lst, dim=0) 175 | logits = torch.cat(logits_lst, dim=0) 176 | 177 | # convert the feature maps to tensor 178 | featuremaps = self.__convert_featuremaps_to_tensor() 179 | featuremaps = {**featuremaps, "features": features, "logits": logits} 180 | # clear the feature_maps 181 | self.__remove_featuremaps() 182 | # remove the hooks 183 | self.__remove_hooks() 184 | 185 | logger.info(f"feature maps: {featuremaps.keys()}") 186 | return featuremaps -------------------------------------------------------------------------------- /task_arithemetic/src/utils/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | from typing import Dict 5 | 6 | # import internal libs 7 | from src.task_vectors import TaskVector 8 | 9 | def decode_model_name(model_name: str) -> tuple: 10 | """Decode the model name into scaling coefficient and task vectors info. 11 | 12 | Args: 13 | model_name: The model name. 14 | In the format of "scaling_coef+dataset1+dataset2+...+datasetN" or "scaling_coef_dataset1_dataset2_..._datasetN". 15 | 16 | Returns: 17 | scaling_coef: The scaling coefficient. 18 | task_vectors_info: A dictionary containing the task vectors information. 19 | where keys are the task names, and values are the labels indicating "add" or "minus". 20 | """ 21 | # extract the scaling coefficient and task vectors info 22 | scaling_coef = re.search(r"(\d+\.\d+)", model_name).group() 23 | assert model_name[:len(scaling_coef)] == scaling_coef, \ 24 | f"Invalid model_name {model_name}." 25 | 26 | # extract the task vectors info 27 | model_name = model_name.replace(scaling_coef, "") 28 | task_vectors_info = {} 29 | while len(model_name) > 0: 30 | # first char is the label, the first combination of letter and digit is the dataset name 31 | label, dataset_name = model_name[0], re.search(r"[a-zA-Z0-9]+", model_name).group() 32 | if label == "+": 33 | task_vectors_info[dataset_name] = "add" 34 | elif label == "_": 35 | task_vectors_info[dataset_name] = "minus" 36 | else: 37 | raise ValueError(f"Invalid model_name {model_name}.") 38 | model_name = model_name[1+len(dataset_name):] 39 | 40 | return float(scaling_coef), task_vectors_info 41 | 42 | 43 | def load_image_encoder(model_root: str, 44 | task_vectors_info: Dict[str, str], 45 | scaling_coef: float) -> torch.nn.Module: 46 | """Loads the image encoder from a pretrained model. 47 | 48 | Args: 49 | model_root: Path to the folder containing the pretrained model. 50 | task_vectors_info: A dictionary containing the task vectors information. 
51 | where keys are the task names, and values are the labels indicating "add" or "minus". 52 | 53 | Returns: 54 | The image encoder. 55 | """ 56 | assert scaling_coef >= -1.0 and scaling_coef <= 1.0, \ 57 | f"scaling_coef should be in [-1.0, 1.0], but got {scaling_coef}." 58 | 59 | # create an empty task vector 60 | pretrained_checkpoint = os.path.join(model_root, "zeroshot.pt") 61 | task_vector = TaskVector(pretrained_checkpoint, pretrained_checkpoint) 62 | 63 | # iterate over the task vectors info 64 | for dataset_name, label in task_vectors_info.items(): 65 | # get the task vector 66 | finetuned_checkpoint = os.path.join(model_root, f"{dataset_name}/finetuned.pt") 67 | task_vector_curr = TaskVector(pretrained_checkpoint, finetuned_checkpoint) 68 | 69 | # add or minus the task vector 70 | if label == "add": 71 | task_vector = task_vector + task_vector_curr 72 | elif label == "minus": 73 | task_vector = task_vector + (-task_vector_curr) 74 | else: 75 | raise ValueError(f"Unknown label {label}.") 76 | 77 | # apply the task vector to the pretrained model 78 | return task_vector.apply_to(pretrained_checkpoint, scaling_coef) 79 | 80 | 81 | def load_image_encoder_single_task(model_root: str, 82 | dataset_name: str, 83 | scaling_coef: float) -> torch.nn.Module: 84 | """Loads the image encoder from a pretrained model. 85 | 86 | Args: 87 | model_root: Path to the folder containing the pretrained model. 88 | task_vectors_info: A dictionary containing the task vectors information. 89 | where keys are the task names, and values are the labels indicating "add" or "minus". 90 | scaling_coef: The scaling coefficient. 91 | 92 | Returns: 93 | The image encoder. 94 | """ 95 | assert scaling_coef >= -1.0 and scaling_coef <= 1.0, \ 96 | f"scaling_coef should be in [-1.0, 1.0], but got {scaling_coef}." 97 | 98 | # create an empty task vector 99 | pretrained_checkpoint = os.path.join(model_root, "zeroshot.pt") 100 | 101 | # get the task vector 102 | finetuned_checkpoint = os.path.join(model_root, f"{dataset_name}/finetuned.pt") 103 | task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint) 104 | 105 | return task_vector.apply_to(pretrained_checkpoint, scaling_coef) 106 | -------------------------------------------------------------------------------- /task_arithemetic/src/utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy.stats import linregress 5 | 6 | def scatter_plot_with_regression(y: np.ndarray, 7 | x: np.ndarray, 8 | save_path: str, 9 | fname: str, 10 | xlabel: str=None, 11 | ylabel: str=None, 12 | title: str=None, 13 | xlim: tuple=None, 14 | ylim: tuple=None, 15 | figsize: tuple=(7, 7)): 16 | """plot the scatter plot with linear regression 17 | 18 | Args: 19 | y (np.ndarray): the y values. 20 | x (np.ndarray): the x values. 21 | save_path (str): the save path. 22 | fname (str): the filename. 23 | xlabel (str, optional): the xlabel. Defaults to None. 24 | ylabel (str, optional): the ylabel. Defaults to None. 25 | title (str, optional): the title. Defaults to None. 26 | xlim (tuple, optional): the xlim. Defaults to None. 27 | ylim (tuple, optional): the ylim. Defaults to None. 28 | figsize (tuple, optional): the figsize. Defaults to (7, 7). 29 | """ 30 | # Initialise the figure and axes. 
31 | fig, ax = plt.subplots(figsize=figsize) 32 | # Perform linear regression to get slope and intercept 33 | slope, intercept, _, _, _ = linregress(x, y) 34 | 35 | # Create a scatter plot 36 | ax.scatter(x, y, label="Data points") 37 | 38 | # Plot the regression line 39 | regression_line = slope * np.array(x) + intercept 40 | ax.plot(x, regression_line, color='red', 41 | label=f'Linear Regression (slope={slope:.2f}, intercept={intercept:.2f})') 42 | 43 | # Add labels and legend 44 | ax.grid() 45 | ax.set(xlabel=xlabel, ylabel=ylabel, title=title) 46 | 47 | # set the xlim and ylim 48 | if xlim is not None: 49 | ax.set_xlim(xlim) 50 | if ylim is not None: 51 | ax.set_ylim(ylim) 52 | 53 | # Add a legend, and position it on the lower right (with no box) 54 | plt.legend(frameon=True, prop={'size': 10}) 55 | 56 | # save the fig 57 | path = os.path.join(save_path, fname) 58 | fig.savefig(path) 59 | plt.close() 60 | 61 | 62 | def plot_multiple_curves(Y: dict, 63 | x: np.ndarray, 64 | save_path: str, 65 | fname: str, 66 | xlabel: str, 67 | ylabel: str, 68 | title: str = None, 69 | ylim: list = None, 70 | figsize: tuple = (7, 5)) -> None: 71 | """plot curves in one figure for each key in dictionary. 72 | Args: 73 | Y (dict): the dictionary of the curves. 74 | x (np.ndarray): the x axis. 75 | save_path (str): the path to save the figure. 76 | fname (str): the file name of the figure. 77 | xlabel (str): the label of x axis. 78 | ylabel (str): the label of y axis. 79 | title (str): the title of the figure. 80 | ylim (list): the range of y axis. 81 | figsize (tuple): the size of the figure. 82 | """ 83 | # Initialise the figure and axes. 84 | fig, ax = plt.subplots(figsize=figsize) 85 | 86 | # Draw all the lines in the same plot, assigning a label for each one to be 87 | # shown in the legend. 
 88 |     for label, y in Y.items():
 89 |         ax.plot(x, y, label=label)
 90 | 
 91 |     # Add labels and legend
 92 |     ax.grid()
 93 |     ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
 94 | 
 95 |     # Add a legend
 96 |     plt.legend(frameon=True, prop={'size': 10})
 97 | 
 98 |     # set the ylim
 99 |     plt.ylim(ylim)
 100 | 
 101 |     # save the fig
 102 |     path = os.path.join(save_path, fname)
 103 |     fig.savefig(path)
 104 |     plt.close()
--------------------------------------------------------------------------------
/task_arithemetic/src/utils/tools.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import torch
 3 | import torch.nn as nn
 4 | import numpy as np
 5 | from torch.utils.data import DataLoader
 6 | from collections import OrderedDict
 7 | from tqdm import tqdm
 8 | from functools import reduce
 9 | 
 10 | # import internal libs
 11 | from src.datasets.common import maybe_dictionarize
 12 | from src.utils import get_logits
 13 | 
 14 | def search_by_suffix(directory: str,
 15 |                      suffix: str) -> list:
 16 |     """find all the files with the suffix under the directory
 17 | 
 18 |     Args:
 19 |         directory (str): the directory to find the files
 20 |         suffix (str): the suffix of the files
 21 | 
 22 |     Returns:
 23 |         list: the list of the files
 24 |     """
 25 |     file_paths = []
 26 |     for root, dirs, files in os.walk(directory):
 27 |         for file in files:
 28 |             if file.endswith(suffix):
 29 |                 file_paths.append(os.path.join(root, file))
 30 |     return file_paths
 31 | 
 32 | 
 33 | def interpolate_weights(A: OrderedDict,
 34 |                         B: OrderedDict,
 35 |                         alpha: float,
 36 |                         beta: float,) -> OrderedDict:
 37 |     """interpolate the weights
 38 |     Args:
 39 |         A: the weights of model A
 40 |         B: the weights of model B
 41 |         alpha: the interpolation coefficient
 42 |         beta: the interpolation coefficient
 43 | 
 44 |     Returns:
 45 |         the interpolated weights
 46 |     """
 47 |     assert A.keys() == B.keys(), "the keys of A and B should be the same"
 48 |     C = OrderedDict()
 49 |     for k, v in A.items():
 50 |         # look up B with the original key; only strip a DataParallel prefix from the output key
 51 |         out_k = k[7:] if k.startswith("module.") else k
 52 |         C[out_k] = alpha * v + beta * B[k]
 53 |     return C
 54 | 
 55 | 
 56 | def get_module(model: nn.Module,
 57 |                module_name: str) -> nn.Module:
 58 |     """get the module from the model
 59 | 
 60 |     Args:
 61 |         model (nn.Module): the model to extract featuremaps.
 62 |         module_name (str): name of the module
 63 | 
 64 |     Returns:
 65 |         nn.Module: the module
 66 |     """
 67 |     return reduce(getattr, module_name.split('.'), model)
 68 | 
 69 | 
 70 | def evaluate(device: torch.device,
 71 |              model: nn.Module,
 72 |              dataloader: DataLoader,) -> tuple:
 73 |     """evaluate the model over the dataset
 74 | 
 75 |     Args:
 76 |         device (torch.device): the device to run the model.
 77 |         model (nn.Module): the model to be evaluated.
 78 |         dataloader (DataLoader): usually the test loader.
 79 | 
 80 |     Return:
 81 |         (avg_acc, avg_loss, predictions)
 82 |     """
 83 |     # init the loss function
 84 |     loss_fn = nn.CrossEntropyLoss(reduction="none").to(device)
 85 |     # set the model to eval mode
 86 |     model.eval()
 87 |     # evaluate
 88 |     with torch.no_grad():
 89 |         predictions, loss_lst, corrects, n = [], [], 0., 0.
90 | for _, data in enumerate(tqdm(dataloader)): 91 | # put data to the device 92 | data = maybe_dictionarize(data) 93 | x = data['images'].to(device) 94 | y = data['labels'].to(device) 95 | 96 | # forward 97 | logits = get_logits(x, model) 98 | # get the loss 99 | losses = loss_fn(logits, y) 100 | # get the preds 101 | preds = logits.argmax(dim=1).to(device) 102 | 103 | # update 104 | predictions.extend(preds.tolist()) 105 | loss_lst.extend(losses.tolist()) 106 | corrects += preds.eq(y.view_as(preds)).sum().item() 107 | n += y.size(0) 108 | 109 | # get the average acc 110 | avg_acc = corrects / n 111 | avg_loss = sum(loss_lst) / len(loss_lst) 112 | 113 | return avg_acc, avg_loss, np.array(predictions) 114 | -------------------------------------------------------------------------------- /task_arithemetic/src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import pickle 5 | from tqdm import tqdm 6 | import math 7 | 8 | import numpy as np 9 | 10 | 11 | def assign_learning_rate(param_group, new_lr): 12 | param_group["lr"] = new_lr 13 | 14 | 15 | def _warmup_lr(base_lr, warmup_length, step): 16 | return base_lr * (step + 1) / warmup_length 17 | 18 | 19 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 20 | if not isinstance(base_lrs, list): 21 | base_lrs = [base_lrs for _ in optimizer.param_groups] 22 | assert len(base_lrs) == len(optimizer.param_groups) 23 | def _lr_adjuster(step): 24 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs): 25 | if step < warmup_length: 26 | lr = _warmup_lr(base_lr, warmup_length, step) 27 | else: 28 | e = step - warmup_length 29 | es = steps - warmup_length 30 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 31 | assign_learning_rate(param_group, lr) 32 | return _lr_adjuster 33 | 34 | 35 | def accuracy(output, target, topk=(1,)): 36 | pred = output.topk(max(topk), 1, True, True)[1].t() 37 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 38 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 39 | 40 | 41 | def torch_load_old(save_path, device=None): 42 | with open(save_path, 'rb') as f: 43 | classifier = pickle.load(f) 44 | if device is not None: 45 | classifier = classifier.to(device) 46 | return classifier 47 | 48 | 49 | def torch_save(model, save_path): 50 | if os.path.dirname(save_path) != '': 51 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 52 | torch.save(model.cpu(), save_path) 53 | 54 | 55 | def torch_load(save_path, device=None): 56 | model = torch.load(save_path) 57 | if device is not None: 58 | model = model.to(device) 59 | return model 60 | 61 | 62 | def get_logits(inputs, classifier): 63 | assert callable(classifier) 64 | if hasattr(classifier, 'to'): 65 | classifier = classifier.to(inputs.device) 66 | return classifier(inputs) 67 | 68 | 69 | def get_probs(inputs, classifier): 70 | if hasattr(classifier, 'predict_proba'): 71 | probs = classifier.predict_proba(inputs.detach().cpu().numpy()) 72 | return torch.from_numpy(probs) 73 | logits = get_logits(inputs, classifier) 74 | return logits.softmax(dim=1) 75 | 76 | 77 | class LabelSmoothing(torch.nn.Module): 78 | def __init__(self, smoothing=0.0): 79 | super(LabelSmoothing, self).__init__() 80 | self.confidence = 1.0 - smoothing 81 | self.smoothing = smoothing 82 | 83 | def forward(self, x, target): 84 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 85 | 86 | nll_loss = -logprobs.gather(dim=-1, 
index=target.unsqueeze(1))
 87 |         nll_loss = nll_loss.squeeze(1)
 88 |         smooth_loss = -logprobs.mean(dim=-1)
 89 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
 90 |         return loss.mean()
 91 | 
--------------------------------------------------------------------------------
/task_arithemetic_t5/README.md:
--------------------------------------------------------------------------------
 1 | # Editing Models with Task Arithmetic
 2 | 
 3 | This directory contains our T5 implementation for the ICLR 2023 paper [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089). **`cross-task-linearity.py` contains the main implementation of cross-task linearity.**
 4 | 
 5 | You should download the datasets from [Hugging Face](https://huggingface.co/), including [IMDB](https://huggingface.co/datasets/stanfordnlp/imdb/), [QASC](https://huggingface.co/datasets/allenai/qasc), etc. The fine-tuned T5 models can be downloaded from https://huggingface.co/mrm8488
 6 | 
 7 | 
 8 | Here is an example script to run the cross-task linearity evaluation:
 9 | 
 10 | ```bash
 11 | python cross-task-linearity.py \
 12 |     --save_root ./outs/ctl_addition/qasc \
 13 |     --modelA_path "t5-base-finetuned-qasc" \
 14 |     --modelB_path "t5-base-finetuned-imdb" \
 15 |     --modelA_coef 0.8 \
 16 |     --modelB_coef 0.8 \
 17 |     --sample_num 500 \
 18 |     --dataset qasc
 19 | ```
--------------------------------------------------------------------------------
/task_arithemetic_t5/avgmeter.py:
--------------------------------------------------------------------------------
 1 | import pandas as pd
 2 | import torch
 3 | from collections import defaultdict
 4 | from typing import Optional
 5 | 
 6 | class AverageMeter(object):
 7 |     """Computes and stores the average and current value.
 8 | 
 9 |     Examples::
 10 |         >>> # Initialize a meter to record loss
 11 |         >>> losses = AverageMeter()
 12 |         >>> # Update meter after every minibatch update
 13 |         >>> losses.update(loss_value, batch_size)
 14 |     """
 15 |     def __init__(self):
 16 |         self.reset()
 17 | 
 18 |     def reset(self):
 19 |         self.val = 0
 20 |         self.avg = 0
 21 |         self.sum = 0
 22 |         self.count = 0
 23 | 
 24 |     def update(self, val, n=1):
 25 |         self.val = val
 26 |         self.sum += val * n
 27 |         self.count += n
 28 |         self.avg = self.sum / self.count
 29 | 
 30 | 
 31 | class MetricMeter(object):
 32 |     """A collection of metrics.
 33 | 
 34 |     Source: https://github.com/KaiyangZhou/Dassl.pytorch
 35 | 
 36 |     Examples::
 37 |         >>> # 1. Create an instance of MetricMeter
 38 |         >>> metric = MetricMeter()
 39 |         >>> # 2. Update using a dictionary as input
 40 |         >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
 41 |         >>> metric.update(input_dict)
 42 |         >>> # 3. Convert to string and print
 43 |         >>> print(str(metric))
 44 |     """
 45 |     def __init__(self, delimiter='\n\t'):
 46 |         self.delimiter = delimiter
 47 |         self.reset()
 48 | 
 49 |     def reset(self):
 50 |         self.meters = defaultdict(AverageMeter)
 51 | 
 52 |     def update(self, input_dict: dict, n: int=1):
 53 |         """Update the meter with a dictionary.
 54 | 
 55 |         Args:
 56 |             input_dict (dict): A dictionary of metrics.
 57 |             n (int): The number of samples in the input.
 58 |         """
 59 |         if input_dict is None:
 60 |             return
 61 | 
 62 |         if not isinstance(input_dict, dict):
 63 |             raise TypeError(
 64 |                 'Input to MetricMeter.update() must be a dictionary'
 65 |             )
 66 | 
 67 |         for k, v in input_dict.items():
 68 |             if isinstance(v, torch.Tensor):
 69 |                 v = v.item()
 70 |             self.meters[k].update(v, n)
 71 | 
 72 |     def __str__(self):
 73 |         output_str = []
 74 |         for name, meter in self.meters.items():
 75 |             output_str.append(
 76 |                 '{} {:.4f} ({:.4f})'.format(name, meter.val, meter.avg)
 77 |             )
 78 |         return self.delimiter + self.delimiter.join(output_str)
 79 | 
 80 | 
 81 | class MetricTracker(object):
 82 |     """track metrics over time and compute average
 83 |     """
 84 |     def __init__(self):
 85 |         self.reset()
 86 | 
 87 |     def reset(self):
 88 |         """reset metrics
 89 |         """
 90 |         self.metrics = defaultdict(list)
 91 |         self.meter = MetricMeter()
 92 | 
 93 |     def update(self, input_dict: dict, n: int=1):
 94 |         """update metrics
 95 | 
 96 |         Args:
 97 |             input_dict (dict): A dictionary of metrics.
 98 |             n (int): The number of samples in the input.
 99 |         """
 100 |         self.meter.update(input_dict, n)
 101 | 
 102 |     def track(self, input_dict: Optional[dict]=None):
 103 |         """track metrics
 104 |         """
 105 |         if input_dict is not None:
 106 |             for k, v in input_dict.items():
 107 |                 assert k not in self.meter.meters, \
 108 |                     f'key {k} already exists in meter.keys() {self.meter.meters.keys()}'
 109 |                 self.metrics[k].append(v)
 110 | 
 111 |         for k, v in self.meter.meters.items():
 112 |             self.metrics[k].append(v.avg)
 113 |         self.meter.reset()
 114 | 
 115 |     def __str__(self):
 116 |         return str(self.meter)
 117 | 
 118 |     def save_to_csv(self, filename: str):
 119 |         """save metrics to csv file
 120 |         """
 121 |         df = pd.DataFrame.from_dict(self.metrics)
 122 |         df.to_csv(filename, index=False)
--------------------------------------------------------------------------------
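A minimal sketch (not part of the repo) of the `MetricTracker` lifecycle that `eval_linearity` below relies on; the layer names and values are hypothetical:

```python
# Hedged sketch: one tracked row per layer, then dump everything to CSV.
from avgmeter import MetricTracker

tracker = MetricTracker()
for layer, dist in [("encoder.block.0", 0.012), ("encoder.block.1", 0.018)]:
    tracker.track({"layer": layer, "dist_alpha_int": dist})  # appends one row
tracker.save_to_csv("sub_linearity.csv")
```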
/task_arithemetic_t5/cross-task-linearity.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
 3 | import argparse
 4 | import torch
 5 | import torch.nn as nn
 6 | import copy
 7 | from torch.utils.data import DataLoader, Subset
 8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 9 | from utils import get_datetime, get_logger, set_logger, set_device, set_seed, \
 10 |     log_settings, interpolate_weights, save_current_src
 11 | from data_utils import load_hf_dataset, get_evaluate_fn
 12 | from avgmeter import MetricTracker
 13 | from dissimilarity import DissimilarityMetricOverSamples, DissimilarityMetric
 14 | from task_vectors import TaskVector
 15 | 
 16 | def eval_linearity(save_path: str,
 17 |                    featuremaps: dict,
 18 |                    alpha: float,
 19 |                    beta: float,
 20 |                    metric: str) -> None:
 21 |     """evaluate the linearity over different layers and alpha and beta.
 22 | 
 23 |     Args:
 24 |         save_path (str): the path to save the dissimilarity.
 25 |         featuremaps (dict): the featuremaps of model A and model B.
 26 |         alpha, beta (float): the interpolation coefficients.
 27 |         metric (str): the dissimilarity metric to use (e.g. "cosine").
28 | 29 | Return: 30 | None 31 | """ 32 | logger = get_logger(f"{__name__}.eval_linearity") 33 | if not os.path.exists(save_path): 34 | os.makedirs(save_path) 35 | 36 | distance_fn = DissimilarityMetric(metric=metric) 37 | distance_fn_over_samples = DissimilarityMetricOverSamples(metric=metric) 38 | 39 | # initialize the MetricTracker 40 | tracker = MetricTracker() 41 | 42 | # get the layer_names 43 | layer_names = list(featuremaps["A"].keys()) 44 | # calculate the dissimilarity 45 | for layer_name in layer_names: 46 | logger.info(f"layer: {layer_name}") 47 | 48 | # get the featuremap of model A and model B 49 | featuremapA = featuremaps["A"][layer_name] 50 | featuremapB = featuremaps["B"][layer_name] 51 | featuremap_alpha = featuremaps["alpha"][layer_name] 52 | featuremap_int = alpha * featuremapA + beta * featuremapB 53 | 54 | if metric == "cosine": 55 | dist_A_B, coef_A_B = \ 56 | distance_fn(featuremapA.cpu(), featuremapB.cpu(), get_coef=True) 57 | dist_alpha_A, coef_alpha_A = \ 58 | distance_fn(featuremap_alpha.cpu(), featuremapA.cpu(), get_coef=True) 59 | dist_alpha_B, coef_alpha_B = \ 60 | distance_fn(featuremap_alpha.cpu(), featuremapB.cpu(), get_coef=True) 61 | dist_alpha_int, coef_alpha_int = \ 62 | distance_fn(featuremap_alpha.cpu(), featuremap_int.cpu(), get_coef=True) 63 | 64 | # over samples 65 | dist_A_B_over_samples, coef_A_B_over_samples = \ 66 | distance_fn_over_samples(featuremapA.cpu(), featuremapB.cpu(), get_coef=True) 67 | dist_alpha_A_over_samples, coef_alpha_A_over_samples = \ 68 | distance_fn_over_samples(featuremap_alpha.cpu(), featuremapA.cpu(), get_coef=True) 69 | dist_alpha_B_over_samples, coef_alpha_B_over_samples = \ 70 | distance_fn_over_samples(featuremap_alpha.cpu(), featuremapB.cpu(), get_coef=True) 71 | dist_alpha_int_over_samples, coef_alpha_int_over_samples = \ 72 | distance_fn_over_samples(featuremap_alpha.cpu(), featuremap_int.cpu(), get_coef=True) 73 | 74 | else: 75 | dist_A_B = distance_fn(featuremapA.cpu(), featuremapB.cpu()) 76 | dist_alpha_A = distance_fn(featuremap_alpha.cpu(), featuremapA.cpu()) 77 | dist_alpha_B = distance_fn(featuremap_alpha.cpu(), featuremapB.cpu()) 78 | dist_alpha_int = distance_fn(featuremap_alpha.cpu(), featuremap_int.cpu()) 79 | # over samples 80 | dist_A_B_over_samples = distance_fn_over_samples(featuremapA.cpu(), featuremapB.cpu()) 81 | dist_alpha_A_over_samples = distance_fn_over_samples(featuremap_alpha.cpu(), featuremapA.cpu()) 82 | dist_alpha_B_over_samples = distance_fn_over_samples(featuremap_alpha.cpu(), featuremapB.cpu()) 83 | dist_alpha_int_over_samples = distance_fn_over_samples(featuremap_alpha.cpu(), featuremap_int.cpu()) 84 | 85 | layer_save_path = os.path.join(save_path, layer_name) 86 | if not os.path.exists(layer_save_path): 87 | os.makedirs(layer_save_path) 88 | 89 | torch.save(dist_A_B_over_samples, os.path.join(layer_save_path, f"dist_A_B.pt")) 90 | torch.save(dist_alpha_A_over_samples, os.path.join(layer_save_path, f"dist_alpha_A.pt")) 91 | torch.save(dist_alpha_B_over_samples, os.path.join(layer_save_path, f"dist_alpha_B.pt")) 92 | torch.save(dist_alpha_int_over_samples, os.path.join(layer_save_path, f"dist_alpha_int.pt")) 93 | 94 | if metric == "cosine": 95 | torch.save(coef_A_B_over_samples, os.path.join(layer_save_path, f"coef_A_B.pt")) 96 | torch.save(coef_alpha_int_over_samples, os.path.join(layer_save_path, f"coef_alpha_int.pt")) 97 | torch.save(coef_alpha_A_over_samples, os.path.join(layer_save_path, f"coef_alpha_A.pt")) 98 | torch.save(coef_alpha_B_over_samples, 
os.path.join(layer_save_path, f"coef_alpha_B.pt"))
 99 | 
 100 |         tracker.track({
 101 |             "layer": layer_name,
 102 |             "dist_A_B": dist_A_B.item(),
 103 |             **({"coef_A_B": coef_A_B.item()} if metric == "cosine" else {}),
 104 |             "dist_alpha_A": dist_alpha_A.item(),
 105 |             **({"coef_alpha_A": coef_alpha_A.item()} if metric == "cosine" else {}),
 106 |             "dist_alpha_B": dist_alpha_B.item(),
 107 |             **({"coef_alpha_B": coef_alpha_B.item()} if metric == "cosine" else {}),
 108 |             "dist_alpha_int": dist_alpha_int.item(),
 109 |             **({"coef_alpha_int": coef_alpha_int.item()} if metric == "cosine" else {}),
 110 |             # over samples
 111 |             "dist_A_B_over_samples_mean": dist_A_B_over_samples.mean().item(),
 112 |             "dist_A_B_over_samples_std": dist_A_B_over_samples.std().item(),
 113 |             **({"coef_A_B_over_samples_mean": coef_A_B_over_samples.mean().item()} if metric == "cosine" else {}),
 114 |             **({"coef_A_B_over_samples_std": coef_A_B_over_samples.std().item()} if metric == "cosine" else {}),
 115 | 
 116 |             "dist_alpha_A_over_samples_mean": dist_alpha_A_over_samples.mean().item(),
 117 |             "dist_alpha_A_over_samples_std": dist_alpha_A_over_samples.std().item(),
 118 |             **({"coef_alpha_A_over_samples_mean": coef_alpha_A_over_samples.mean().item()} if metric == "cosine" else {}),
 119 |             **({"coef_alpha_A_over_samples_std": coef_alpha_A_over_samples.std().item()} if metric == "cosine" else {}),
 120 | 
 121 |             "dist_alpha_B_over_samples_mean": dist_alpha_B_over_samples.mean().item(),
 122 |             "dist_alpha_B_over_samples_std": dist_alpha_B_over_samples.std().item(),
 123 |             **({"coef_alpha_B_over_samples_mean": coef_alpha_B_over_samples.mean().item()} if metric == "cosine" else {}),
 124 |             **({"coef_alpha_B_over_samples_std": coef_alpha_B_over_samples.std().item()} if metric == "cosine" else {}),
 125 | 
 126 | 
 127 |             "dist_alpha_int_over_samples_mean": dist_alpha_int_over_samples.mean().item(),
 128 |             "dist_alpha_int_over_samples_std": dist_alpha_int_over_samples.std().item(),
 129 |             **({"coef_alpha_int_over_samples_mean": coef_alpha_int_over_samples.mean().item()} if metric == "cosine" else {}),
 130 |             **({"coef_alpha_int_over_samples_std": coef_alpha_int_over_samples.std().item()} if metric == "cosine" else {}),
 131 |         })
 132 | 
 133 |     tracker.save_to_csv(os.path.join(save_path, "sub_linearity.csv"))
 134 | 
 135 | 
 136 | def get_featuremaps(device: torch.device,
 137 |                     modelA,
 138 |                     modelB,
 139 |                     base_model: nn.Module,
 140 |                     tokenizer,
 141 |                     dataset,
 142 |                     dataset_name: str,
 143 |                     alpha: float,
 144 |                     beta: float,):
 145 |     """get the featuremaps of model A, model B, and the interpolated model
 146 | 
 147 |     Args:
 148 |         device (torch.device): the device to run the model.
 149 |         modelA: the finetuned (task-vector-scaled) model A.
 150 |         modelB: the finetuned (task-vector-scaled) model B.
 151 |         base_model (nn.Module): the pretrained base model used for interpolation.
 152 |         tokenizer / dataset / dataset_name: the evaluation data and its tokenizer.
 153 |         alpha (float): the interpolation coefficient.
 154 |         beta (float): the interpolation coefficient.
155 | 
156 |     Returns:
157 |         featuremaps (dict): the featuremaps keyed by "A", "B" and "alpha" (the interpolated model).
158 |     """
159 |     logger = get_logger(f"{__name__}.get_featuremaps")
160 | 
161 |     # prepare the weights of model A and B
162 |     weightA, weightB = modelA.state_dict(), modelB.state_dict()
163 |     weight_alpha = interpolate_weights(weightA, weightB, alpha, beta)
164 |     # build the interpolated model on a copy, keeping base_model intact
165 |     model_alpha = copy.deepcopy(base_model)
166 |     model_alpha.load_state_dict(weight_alpha)
167 | 
168 |     # set the layers
169 |     model_name = base_model.__class__.__name__
170 |     logger.info(f"model_name: {model_name}")
171 |     if model_name.startswith("T5"):
172 |         from transformers.models.t5.modeling_t5 import T5Block
173 |         layers = [name for name, module in base_model.named_modules() if isinstance(module, T5Block)]
174 |         logger.info(f"feature layers: {layers}")
175 |     else:
176 |         raise NotImplementedError(f"the model - {model_name} is not implemented")
177 | 
178 |     # get the featuremaps of the two finetuned models and the interpolated model
179 |     logger.info("get the featuremaps of models A, B and the interpolated model")
180 |     featuremaps = dict()
181 |     for key, model in [("A", modelA), ("B", modelB), ("alpha", model_alpha)]:
182 |         logger.info(f"key: {key}")
183 | 
184 |         eval_fn = get_evaluate_fn(dataset_name)
185 |         model.eval()
186 |         model.to(device)
187 |         encoder_featuremap = eval_fn(model=model,
188 |                                      tokenizer=tokenizer,
189 |                                      dataset=dataset,
190 |                                      device=device)
191 | 
192 |         featuremaps[key] = encoder_featuremap
193 |     return featuremaps
194 | 
195 | def add_args() -> argparse.Namespace:
196 |     parser = argparse.ArgumentParser(
197 |         description="cross-task linearity verification for T5")
198 |     ## the basic setting of exp
199 | 
200 |     parser.add_argument('--device', default=0, type=int,
201 |                         help="set the device.")
202 |     parser.add_argument("--seed", default=0, type=int,
203 |                         help="set the seed.")
204 |     parser.add_argument("--save_root", default="./outs/llfc/", type=str,
205 |                         help='the path of saving results.')
206 |     parser.add_argument("--base_model", default="./t5-base", type=str,
207 |                         help='the name or path of the pretrained base model.')
208 |     parser.add_argument("--modelA_path", default=None, type=str,
209 |                         help='the path of finetuned model A.')
210 |     parser.add_argument("--modelA_coef", default=1.0, type=float,
211 |                         help='the scaling coefficient of task vector A.')
212 |     parser.add_argument("--modelB_path", default=None, type=str,
213 |                         help='the path of finetuned model B.')
214 |     parser.add_argument("--modelB_coef", default=1.0, type=float,
215 |                         help='the scaling coefficient of task vector B.')
216 |     parser.add_argument("--dataset", default="imdb", type=str, choices=["imdb", "race", "qasc", "multi_news", "squad", "common_gen"],
217 |                         help='the dataset name.')
218 |     parser.add_argument("--model", default="T5", type=str,
219 |                         help='the model name.')
220 |     parser.add_argument("--alpha", default=0.5, type=float,
221 |                         help="set the alpha to interpolate the weights.")
222 |     parser.add_argument("--beta", default=0.5, type=float,
223 |                         help="set the beta to interpolate the weights.")
224 |     parser.add_argument("--sample_num", default=5000, type=int,
225 |                         help="set the sample number.")
226 |     parser.add_argument("--batch-size", default=128, type=int,
227 |                         help="set the batch size.")
228 |     parser.add_argument("-v", "--verbose", action="store_true", dest="verbose",
229 |                         help="enable debug info output.")
230 |     args = parser.parse_args()
231 | 
232 |     if not os.path.exists(args.save_root):
233 |         os.makedirs(args.save_root)
234 | 
235 |     # set the save_path
236 |     modelA_name = args.modelA_path.split("/")[-1]
237 |     modelB_name = args.modelB_path.split("/")[-1]
238 |     exp_name = "-".join([get_datetime(),
"-".join([get_datetime(), 239 | f"seed{args.seed}", 240 | f"{args.dataset}", 241 | f"{args.model}", 242 | f"modelA_{modelA_name}", 243 | f"modelB_{modelB_name}", 244 | f"modelA_coef{args.modelA_coef}", 245 | f"modelB_coef{args.modelB_coef}", 246 | f"alpha{args.alpha}", 247 | f"beta{args.beta}", 248 | f"sample_num{args.sample_num}", 249 | f"bs{args.batch_size}",]) 250 | args.save_path = os.path.join(args.save_root, exp_name) 251 | if not os.path.exists(args.save_path): 252 | os.makedirs(args.save_path) 253 | return args 254 | 255 | 256 | def main(): 257 | args = add_args() 258 | 259 | set_logger(args.save_path) 260 | logger = get_logger(__name__, args.verbose) 261 | set_seed(args.seed) 262 | args.device = set_device(args.device) 263 | logger.info("#########parameters settings....") 264 | log_settings(args) 265 | 266 | # prepare datasets 267 | assert args.model == "T5", "only support T5 now." 268 | 269 | # prepare dataset 270 | logger.info("#########prepare dataset....") 271 | dataset = load_hf_dataset(args.dataset) 272 | indices = torch.randperm(len(dataset))[:args.sample_num].numpy().tolist() 273 | subset = Subset(dataset, indices) 274 | # dataloader = DataLoader(subset, batch_size=args.batch_size, shuffle=False) 275 | 276 | 277 | # prepare model 278 | logger.info("#########prepare model....") 279 | base_model, base_tokenizer = AutoModelForSeq2SeqLM.from_pretrained(args.base_model), \ 280 | AutoTokenizer.from_pretrained(args.base_model) 281 | modelA_origin, tokenizerA = AutoModelForSeq2SeqLM.from_pretrained(args.modelA_path), \ 282 | AutoTokenizer.from_pretrained(args.modelA_path) 283 | modelB_origin, tokenizerB = AutoModelForSeq2SeqLM.from_pretrained(args.modelB_path), \ 284 | AutoTokenizer.from_pretrained(args.modelB_path) 285 | 286 | 287 | task_vector_A = TaskVector(pretrained_checkpoint=base_model, finetuned_checkpoint=modelA_origin) 288 | modelA = task_vector_A.apply_to(copy.deepcopy(base_model), scaling_coef=args.modelA_coef) 289 | task_vector_B = TaskVector(pretrained_checkpoint=base_model, finetuned_checkpoint=modelB_origin) 290 | modelB = task_vector_B.apply_to(copy.deepcopy(base_model), scaling_coef=args.modelB_coef) 291 | 292 | 293 | # modelA = modelA_origin 294 | # modelB = modelB_origin 295 | 296 | logger.info("#########get the featuremaps....") 297 | featuremaps = get_featuremaps(device=args.device, 298 | modelA=modelA, 299 | modelB=modelB, 300 | base_model = base_model, 301 | tokenizer=base_tokenizer, 302 | dataset=subset, 303 | dataset_name=args.dataset, 304 | alpha=args.alpha, 305 | beta=args.beta,) 306 | 307 | eval_linearity(save_path=os.path.join(args.save_path, "exp"), 308 | featuremaps=featuremaps, 309 | alpha=args.alpha, 310 | beta=args.beta, 311 | metric="cosine") 312 | 313 | if __name__ == "__main__": 314 | main() 315 | -------------------------------------------------------------------------------- /task_arithemetic_t5/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 3 | 4 | import torch 5 | from datasets import load_dataset 6 | from collections import defaultdict 7 | from utils import get_logger 8 | from featuremap import * 9 | dataset_names = [("imdb",), ("race", "all"), ("qasc",), ("multi_news",), ("squad",), ("allenai/common_gen",)] 10 | 11 | 12 | def get_evaluate_fn(dataset_name): 13 | if len([name[0] for name in dataset_names if dataset_name in name[0]]) == 0: 14 | raise ValueError(f"dataset_name should be in {dataset_names[:0]}") 15 | return 
eval(f"eval_{dataset_name}") 16 | 17 | 18 | 19 | def process_hf_dataset_to_torch_dataset(hf_ds, dataset_name): 20 | if dataset_name == "imdb": 21 | hf_ds = hf_ds['test'] 22 | elif dataset_name == "race": 23 | hf_ds = hf_ds['test'] 24 | elif dataset_name == "qasc": 25 | hf_ds = hf_ds['test'] 26 | elif dataset_name == "multi_news": 27 | hf_ds = hf_ds['test'] 28 | elif dataset_name == "multi_news": 29 | hf_ds = hf_ds['test'] 30 | elif dataset_name == "squad": 31 | hf_ds = hf_ds['validation'] 32 | elif dataset_name == "common_gen": 33 | hf_ds = hf_ds['validation'] 34 | else: 35 | raise NotImplementedError(f"dataset_name {dataset_name} not implemented.") 36 | ds_torch = hf_ds.with_format("torch") 37 | return ds_torch 38 | 39 | def load_hf_dataset(dataset_name): 40 | if len([name[0] for name in dataset_names if dataset_name in name[0]]) == 0: 41 | raise ValueError(f"dataset_name should be in {dataset_names[:0]}") 42 | dataset_tuple = [name for name in dataset_names if dataset_name in name[0]][0] 43 | hf_ds = load_dataset(*dataset_tuple) 44 | torch_ds = process_hf_dataset_to_torch_dataset(hf_ds, dataset_name) 45 | return torch_ds -------------------------------------------------------------------------------- /task_arithemetic_t5/dissimilarity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DissimilarityMetric: 4 | """compute the distance of two featuremaps 5 | """ 6 | def __init__(self, metric): 7 | self.__metric = metric 8 | 9 | def __call__(self, A, B, **kwargs): 10 | if self.__metric == "vanilla": 11 | return self.__vanilla(A, B) 12 | elif self.__metric == "cosine": 13 | return self.__cosine_similarity(A, B, **kwargs) 14 | 15 | def __vanilla(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: 16 | """compute the vanilla distance between two featuremaps, Frobenius norm 17 | 18 | Args: 19 | A (torch.Tensor): the featuremap A, shape: (N, D) 20 | B (torch.Tensor): the featuremap B, shape: (N, D) 21 | """ 22 | assert A.shape == B.shape, \ 23 | f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}" 24 | A = A.view(A.shape[0], -1) # shape: (N, D) 25 | B = B.view(B.shape[0], -1) # shape: (N, D) 26 | 27 | def norm_square(A: torch.Tensor) -> torch.Tensor: 28 | return torch.sum(A ** 2) # shape: (1, ) 29 | 30 | return norm_square(A - B) / torch.norm(A, p="fro") / torch.norm(B, p="fro") 31 | 32 | 33 | def __cosine_similarity(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor: 34 | """compute the cosine similarity between two matrices 35 | 36 | dist = 1 - / (||A|| * ||B||) 37 | 38 | Args: 39 | A (torch.Tensor): the featuremap A, shape: (N, D) 40 | B (torch.Tensor): the featuremap B, shape: (N, D) 41 | kwargs: the keyword arguments 42 | 43 | Return: 44 | dist (torch.Tensor): the distance between A and B 45 | """ 46 | get_coef = kwargs.get("get_coef", False) 47 | 48 | assert A.shape == B.shape, \ 49 | f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}" 50 | 51 | A = A.view(A.shape[0], -1).double() # shape: (N, D) 52 | B = B.view(B.shape[0], -1).double() # shape: (N, D) 53 | 54 | # compute the frobenius inner product of A and B 55 | inner_product = torch.sum(A * B) # shape: (1, 1) 56 | # compute the frobenius norm of A and B 57 | A_norm = torch.norm(A, p="fro") # shape: (1, 1) 58 | B_norm = torch.norm(B, p="fro") # shape: (1, 1) 59 | 60 | # cal the distance 61 | dist = 1 - torch.abs(inner_product) / (A_norm * B_norm) 62 | 63 | # compute the coefficient 64 | coef = 
65 | 
66 |         assert torch.abs(inner_product) <= A_norm * B_norm * (1 + 1e-10), \
67 |             f"the inner product - {inner_product} should be at most the product of the norms - {A_norm * B_norm}"
68 | 
69 |         if get_coef:
70 |             return dist, coef
71 |         else:
72 |             return dist
73 | 
74 | 
75 | class DissimilarityMetricOverSamples:
76 |     """compute the dissimilarity of two featuremaps, for each sample
77 |     """
78 |     def __init__(self, metric):
79 |         self.__metric = metric
80 | 
81 |     def __call__(self, A, B, **kwargs):
82 |         if self.__metric == "vanilla":
83 |             return self.__vanilla(A, B)
84 |         elif self.__metric == "cosine":
85 |             return self.__cosine_similarity(A, B, **kwargs)
86 | 
87 |     def __vanilla(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
88 |         """compute the vanilla distance between two featuremaps, per sample
89 | 
90 |         Args:
91 |             A (torch.Tensor): the featuremap A, shape: (N, D)
92 |             B (torch.Tensor): the featuremap B, shape: (N, D)
93 |         """
94 |         assert A.shape == B.shape, \
95 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
96 |         A = A.view(A.shape[0], -1)  # shape: (N, D)
97 |         B = B.view(B.shape[0], -1)  # shape: (N, D)
98 | 
99 |         def norm_square(A: torch.Tensor) -> torch.Tensor:
100 |             return torch.sum(A ** 2, dim=-1)  # shape: (N, )
101 |         return norm_square(A - B) / torch.norm(A, p="fro", dim=-1) / torch.norm(B, p="fro", dim=-1)  # shape: (N, )
102 | 
103 | 
104 |     def __cosine_similarity(self, A: torch.Tensor, B: torch.Tensor, **kwargs) -> torch.Tensor:
105 |         """compute the cosine dissimilarity between two matrices, per sample
106 | 
107 |         dist_i = 1 - |<A_i, B_i>| / (||A_i|| * ||B_i||)
108 | 
109 |         Args:
110 |             A (torch.Tensor): the featuremap A, shape: (N, D)
111 |             B (torch.Tensor): the featuremap B, shape: (N, D)
112 |             kwargs: the keyword arguments
113 | 
114 |         Return:
115 |             dist (torch.Tensor): the distance between A and B
116 |         """
117 |         get_coef = kwargs.get("get_coef", False)
118 | 
119 |         assert A.shape == B.shape, \
120 |             f"the shape of A - {A.shape} should be the same as the shape of B - {B.shape}"
121 | 
122 |         A = A.view(A.shape[0], -1).double()  # shape: (N, D)
123 |         B = B.view(B.shape[0], -1).double()  # shape: (N, D)
124 | 
125 |         # compute the inner product of A and B, row by row
126 |         inner_product = torch.sum(A * B, dim=-1)  # shape: (N, )
127 |         # compute the norm of A and B, row by row
128 |         A_norm = torch.norm(A, p="fro", dim=-1)  # shape: (N, )
129 |         B_norm = torch.norm(B, p="fro", dim=-1)  # shape: (N, )
130 | 
131 |         # cal the distance
132 |         dist = 1 - torch.abs(inner_product) / (A_norm * B_norm)  # shape: (N, )
133 | 
134 |         # compute the least-squares coefficient of projecting A onto B, row by row
135 |         coef = inner_product / (B_norm ** 2)  # shape: (N, )
136 | 
137 |         if get_coef:
138 |             return dist, coef
139 |         else:
140 |             return dist
--------------------------------------------------------------------------------
/task_arithemetic_t5/featuremap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import get_logger
3 | from collections import defaultdict
4 | 
5 | 
6 | def get_t5_featuremaps(featuremaps, encoder_hidden_states, decoder_hidden_states):
7 |     enc_id_to_layer_name = {}
8 |     enc_id_to_layer_name[0] = "encoder.init_embedding"
9 |     for i in range(1, len(encoder_hidden_states)):
10 |         enc_id_to_layer_name[i] = f"encoder.T5Block.{i-1}"
11 |     dec_id_to_layer_name = {}
12 |     dec_id_to_layer_name[0] = "decoder.init_embedding"
13 |     for i in range(1, len(decoder_hidden_states[0])):
14 |         dec_id_to_layer_name[i] = f"decoder.T5Block.{i-1}"
15 | 
16 |     for layer_id, layer_name in enc_id_to_layer_name.items():
17 |         avg_embedding = encoder_hidden_states[layer_id].mean(dim=1).squeeze()  # mean-pool over the sequence
18 |         featuremaps[layer_name].append(avg_embedding)
19 |     for layer_id, layer_name in dec_id_to_layer_name.items():
20 |         attention_pooling_embedding = decoder_hidden_states[0][layer_id]
21 |         attention_pooling_embedding = attention_pooling_embedding.squeeze()
22 |         featuremaps[layer_name].append(attention_pooling_embedding)
23 |     return featuremaps
24 | 
25 | 
26 | def eval_imdb(model, tokenizer, dataset, device):
27 |     model.to(device)
28 |     featuremaps = defaultdict(list)
29 |     logger = get_logger(f"{__name__}.eval_imdb")
30 |     for i, item in enumerate(dataset):
31 |         inputs_ids = tokenizer.encode(item['text'] + '</s>', return_tensors="pt").to(device)  # append the EOS token
32 |         len_inputs_seq = len(inputs_ids[0])
33 |         raw_outputs = model.generate(input_ids=inputs_ids, max_length=2, output_hidden_states=True, return_dict_in_generate=True)
34 |         output_texts = [tokenizer.decode(ids) for ids in raw_outputs['sequences']]
35 |         # logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
36 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
37 | 
38 |     enc_fm = {}
39 |     for key, val in featuremaps.items():
40 |         enc_fm[key] = torch.stack(val, dim=0)
41 |     return enc_fm
42 | 
43 | def eval_race(model, tokenizer, dataset, device):
44 |     model.to(device)
45 |     featuremaps = defaultdict(list)
46 |     logger = get_logger(f"{__name__}.eval_race")
47 |     option_alphabet = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)", "(H)", "(I)", "(J)", "(K)", \
48 |                        "(L)", "(M)", "(N)", "(O)", "(P)", "(Q)", "(R)", "(S)", "(T)", "(U)", "(V)", "(W)", "(X)", "(Y)", "(Z)"]
49 |     for i, item in enumerate(dataset):
50 |         article = item['article']
51 |         question = item['question']
52 |         options = item['options']
53 | 
54 |         options_with_alphabet = " ".join([f"{alphabet} {option}" for alphabet, option in zip(option_alphabet, options)])
55 |         context = f"{article} {options_with_alphabet}"
56 |         input_text = f"question: {question} context: {context}"
57 |         inputs_ids = tokenizer([input_text], return_tensors="pt").to(device)
58 |         len_inputs_seq = len(inputs_ids[0])
59 | 
60 |         raw_outputs = model.generate(input_ids=inputs_ids['input_ids'], attention_mask=inputs_ids['attention_mask'], max_length=128,
61 |                                      output_hidden_states=True, return_dict_in_generate=True)
62 |         output_texts = [tokenizer.decode(ids) for ids in raw_outputs['sequences']]
63 |         logger.info(f"{i}th options: {options_with_alphabet}")
64 |         logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
65 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
66 | 
67 |     enc_fm = {}
68 |     for key, val in featuremaps.items():
69 |         enc_fm[key] = torch.stack(val, dim=0)
70 |     return enc_fm
71 | 
72 | def eval_qasc(model, tokenizer, dataset, device):
73 |     """
74 |     def get_response(question, context, max_length=64):
75 |         input_text = 'question: %s context: %s' % (question, context)
76 |         features = tokenizer([input_text], return_tensors='pt')
77 | 
78 |         output = model.generate(input_ids=features['input_ids'],
79 |                                 attention_mask=features['attention_mask'],
80 |                                 max_length=max_length)
81 | 
82 |         return tokenizer.decode(output[0])
83 | 
84 |     fact_1 = 'a watch is used for measuring time'
85 |     fact_2 = 'Times are measured in seconds.'
86 |     context = fact_1 + ' ' + fact_2
87 |     question = 'What can be used to measure seconds? (A) Watch (B) seconds (C) fluid (D) Ruler (E) goggles (F) glasses (G) Drill (H) Scale'
88 | 
89 |     get_response(question, context)
90 |     """
91 |     model.to(device)
92 |     featuremaps = defaultdict(list)
93 |     logger = get_logger(f"{__name__}.eval_qasc")
94 |     for i, item in enumerate(dataset):
95 |         question = item['formatted_question']
96 |         context = " ".join([item["fact1"], item["fact2"], item["combinedfact"]])
97 |         input_text = f"question: {question} context: {context}"
98 |         inputs_ids = tokenizer([input_text], return_tensors="pt").to(device)
99 |         len_inputs_seq = len(inputs_ids[0])
100 | 
101 |         raw_outputs = model.generate(input_ids=inputs_ids['input_ids'], attention_mask=inputs_ids['attention_mask'], max_length=128,
102 |                                      output_hidden_states=True, return_dict_in_generate=True)
103 |         output_texts = [tokenizer.decode(ids) for ids in raw_outputs['sequences']]
104 |         logger.info(f"{i}th question: {question}")
105 |         logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
106 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
107 | 
108 |     enc_fm = {}
109 |     for key, val in featuremaps.items():
110 |         enc_fm[key] = torch.stack(val, dim=0)
111 |     return enc_fm
112 | 
113 | 
114 | def eval_multi_news(model, tokenizer, dataset, device):
115 |     """
116 |     def summarize(text, max_length=150):
117 |         input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
118 |         generated_ids = model.generate(input_ids=input_ids, num_beams=2, max_length=max_length, repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)
119 |         preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
120 |         return preds[0]
121 |     """
122 |     model.to(device)
123 |     featuremaps = defaultdict(list)
124 |     logger = get_logger(f"{__name__}.eval_multi_news")
125 |     for i, item in enumerate(dataset):
126 |         document = item['document']
127 |         if len(document) > 4096:  # truncate overly long documents (character level)
128 |             document = document[:4096]
129 |         inputs_ids = tokenizer.encode(document, return_tensors="pt", add_special_tokens=True).to(device)
130 |         len_inputs_seq = len(inputs_ids[0])
131 | 
132 |         raw_outputs = model.generate(input_ids=inputs_ids, num_beams=2, max_length=128, repetition_penalty=2.5, \
133 |                                      length_penalty=1.0, early_stopping=True, output_hidden_states=True, return_dict_in_generate=True)
134 |         output_texts = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in raw_outputs['sequences']]
135 |         logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
136 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
137 | 
138 |     enc_fm = {}
139 |     for key, val in featuremaps.items():
140 |         enc_fm[key] = torch.stack(val, dim=0)
141 |     return enc_fm
142 | 
143 | 
144 | def eval_squad(model, tokenizer, dataset, device):
145 |     """
146 |     def get_question(answer, context, max_length=64):
147 |         input_text = "answer: %s context: %s " % (answer, context)
148 |         features = tokenizer([input_text], return_tensors='pt')
149 | 
150 |         output = model.generate(input_ids=features['input_ids'],
151 |                                 attention_mask=features['attention_mask'],
152 |                                 max_length=max_length)
153 | 
154 |         return tokenizer.decode(output[0])
155 | 
156 |     context = "Manuel has created RuPERTa-base with the support of HF-Transformers and Google"
157 |     answer = "Manuel"
158 | 
159 |     get_question(answer, context)
160 |     """
161 |     model.to(device)
162 |     featuremaps = defaultdict(list)
163 |     logger = get_logger(f"{__name__}.eval_squad")
164 | 
165 |     for i, item in enumerate(dataset):
166 |         context = item['context']
167 |         answer = item['answers']["text"][0]
168 |         input_text = f"answer: {answer} context: {context} "
169 |         inputs_ids = tokenizer([input_text], return_tensors="pt").to(device)
170 |         len_inputs_seq = len(inputs_ids[0])
171 | 
172 |         raw_outputs = model.generate(input_ids=inputs_ids['input_ids'], attention_mask=inputs_ids['attention_mask'], max_length=128,
173 |                                      output_hidden_states=True, return_dict_in_generate=True)
174 |         output_texts = [tokenizer.decode(ids) for ids in raw_outputs['sequences']]
175 |         logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
176 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
177 | 
178 |     enc_fm = {}
179 |     for key, val in featuremaps.items():
180 |         enc_fm[key] = torch.stack(val, dim=0)
181 |     return enc_fm
182 | 
183 | 
184 | def eval_common_gen(model, tokenizer, dataset, device):
185 |     """
186 |     def gen_sentence(words, max_length=32):
187 |         input_text = words
188 |         features = tokenizer([input_text], return_tensors='pt')
189 | 
190 |         output = model.generate(input_ids=features['input_ids'],
191 |                                 attention_mask=features['attention_mask'],
192 |                                 max_length=max_length)
193 | 
194 |         return tokenizer.decode(output[0], skip_special_tokens=True)
195 | 
196 |     words = "tree plant ground hole dig"
197 | 
198 |     gen_sentence(words)
199 |     """
200 |     model.to(device)
201 |     featuremaps = defaultdict(list)
202 |     logger = get_logger(f"{__name__}.eval_common_gen")
203 |     for i, item in enumerate(dataset):
204 |         input_text = " ".join(item['concepts'])
205 |         inputs_ids = tokenizer([input_text], return_tensors="pt").to(device)
206 |         len_inputs_seq = len(inputs_ids[0])
207 | 
208 |         raw_outputs = model.generate(input_ids=inputs_ids['input_ids'], attention_mask=inputs_ids['attention_mask'], max_length=64,
209 |                                      output_hidden_states=True, return_dict_in_generate=True)
210 |         output_texts = [tokenizer.decode(ids) for ids in raw_outputs['sequences']]
211 |         logger.info(f"{i}th target sentence: {item['target']}")
212 |         logger.info(f"{i}th token len: {len_inputs_seq}, {i}th output text: {output_texts}")
213 |         get_t5_featuremaps(featuremaps, raw_outputs['encoder_hidden_states'], raw_outputs['decoder_hidden_states'])
214 | 
215 |     enc_fm = {}
216 |     for key, val in featuremaps.items():
217 |         enc_fm[key] = torch.stack(val, dim=0)
218 |     return enc_fm
--------------------------------------------------------------------------------
/task_arithemetic_t5/task_vectors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class TaskVector:
5 |     def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
6 |         """Initializes the task vector from a pretrained and a finetuned checkpoint.
7 | 
8 |         This can either be done by passing two models (one corresponding to the
9 |         pretrained checkpoint, and another to the finetuned one), or by directly passing in
10 |         the task vector state dict.
11 | """ 12 | if vector is not None: 13 | self.vector = vector 14 | else: 15 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None 16 | with torch.no_grad(): 17 | pretrained_state_dict = pretrained_checkpoint.state_dict() 18 | finetuned_state_dict = finetuned_checkpoint.state_dict() 19 | # import ipdb; ipdb.set_trace() 20 | self.vector = {} 21 | for key in pretrained_state_dict: 22 | # if not (key.startswith("encoder.block") or key.startswith("decoder.block")): 23 | # continue 24 | # if key.startswith("shared"): 25 | # continue 26 | if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]: 27 | continue 28 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key] 29 | 30 | def __add__(self, other): 31 | """Add two task vectors together.""" 32 | with torch.no_grad(): 33 | new_vector = {} 34 | for key in self.vector: 35 | if key not in other.vector: 36 | print(f'Warning, key {key} is not present in both task vectors.') 37 | continue 38 | new_vector[key] = self.vector[key] + other.vector[key] 39 | return TaskVector(vector=new_vector) 40 | 41 | def __radd__(self, other): 42 | if other is None or isinstance(other, int): 43 | return self 44 | return self.__add__(other) 45 | 46 | def __neg__(self): 47 | """Negate a task vector.""" 48 | with torch.no_grad(): 49 | new_vector = {} 50 | for key in self.vector: 51 | new_vector[key] = - self.vector[key] 52 | return TaskVector(vector=new_vector) 53 | 54 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): 55 | """Apply a task vector to a pretrained model.""" 56 | with torch.no_grad(): 57 | pretrained_model = pretrained_checkpoint 58 | new_state_dict = {} 59 | pretrained_state_dict = pretrained_model.state_dict() 60 | for key in pretrained_state_dict: 61 | if key not in self.vector: 62 | print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector') 63 | continue 64 | new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key] 65 | pretrained_model.load_state_dict(new_state_dict, strict=False) 66 | return pretrained_model 67 | 68 | 69 | -------------------------------------------------------------------------------- /task_arithemetic_t5/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import argparse 4 | import logging 5 | import random 6 | import numpy as np 7 | import torch 8 | from typing import OrderedDict 9 | 10 | def interpolate_weights(A: OrderedDict, 11 | B: OrderedDict, 12 | alpha: float, 13 | beta: float,) -> OrderedDict: 14 | """interpolate the weights 15 | Args: 16 | A: the weights of model A 17 | B: the weights of model B 18 | alpha: the interpolation coefficient 19 | beta: the interpolation coefficient 20 | 21 | Returns: 22 | the interpolated weights 23 | """ 24 | assert A.keys() == B.keys(), "the keys of A and B should be the same" 25 | C = OrderedDict() 26 | for k, v in A.items(): 27 | if k.startswith("module."): 28 | k = k[7:] 29 | C[k] = alpha * v + beta * B[k] 30 | return C 31 | 32 | 33 | def get_datetime() -> str: 34 | """get the date. 35 | Returns: 36 | date (str): the date. 37 | """ 38 | datetime_ = datetime.datetime.now().strftime("%m%d-%H%M%S") 39 | return datetime_ 40 | 41 | 42 | def set_logger(save_path: str) -> None: 43 | """set the logger. 44 | Args: 45 | save_path(str): the path for saving logfile.txt 46 | name(str): the name of the logger 47 | verbose(bool): if true, will print to console. 
48 | 
49 |     Returns:
50 |         None
51 |     """
52 |     # set the logger
53 |     logfile = os.path.join(save_path, "logfile.txt")
54 |     logging.basicConfig(filename=logfile,
55 |                         filemode="w+",
56 |                         format='%(name)-12s: %(levelname)-8s %(message)s',
57 |                         datefmt="%H:%M:%S",
58 |                         level=logging.INFO)
59 |     # define a Handler which writes DEBUG messages or higher to the sys.stderr
60 |     console = logging.StreamHandler()
61 |     console.setLevel(logging.DEBUG)
62 |     # tell the handler to use this format
63 |     console.setFormatter(logging.Formatter(
64 |         '%(name)-12s: %(levelname)-8s %(message)s'))
65 |     # add the handler to the root logger
66 |     logging.getLogger().addHandler(console)
67 | 
68 | 
69 | def get_logger(name: str,
70 |                verbose: bool = True) -> logging.Logger:
71 |     """get the logger.
72 |     Args:
73 |         name (str): the name of the logger
74 |         verbose (bool): if true, will print to console.
75 |     Returns:
76 |         logger (logging.Logger)
77 |     """
78 |     logger = logging.getLogger(name)
79 | 
80 |     logger.setLevel(logging.DEBUG)
81 |     if not verbose:
82 |         logger.setLevel(logging.INFO)
83 |     return logger
84 | 
85 | 
86 | def set_seed(seed: int = 0) -> None:
87 |     """set the random seed for multiple packages.
88 |     Args:
89 |         seed (int): the seed.
90 | 
91 |     Returns:
92 |         None
93 |     """
94 |     random.seed(seed)
95 |     os.environ['PYTHONHASHSEED'] = str(seed)
96 |     np.random.seed(seed)
97 |     torch.manual_seed(seed)
98 |     torch.cuda.manual_seed(seed)
99 |     torch.backends.cudnn.deterministic = True
100 | 
101 | 
102 | def set_device(device: int) -> torch.device:
103 |     """set the GPU device.
104 |     Args:
105 |         device (int): the index of the GPU device
106 | 
107 |     Returns:
108 |         device (torch.device)
109 |     """
110 |     logger = get_logger(__name__)
111 |     if torch.cuda.is_available():
112 |         if device >= torch.cuda.device_count():
113 |             logger.error("CUDA error, invalid device ordinal")
114 |             exit(1)
115 |     else:
116 |         logger.error("Please choose a machine with a GPU to run the program")
117 |         exit(1)
118 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(device)  # note: only takes effect if set before CUDA is initialized
119 |     device = torch.device("cuda:" + str(device))
120 |     logger.info(device)
121 |     return device
122 | 
123 | 
124 | def log_settings(args: argparse.Namespace, config: dict = {}) -> None:
125 |     """log the settings of the program.
126 |     Args:
127 |         args (argparse.Namespace): the arguments.
128 |         config (dict): the config.
129 |     """
130 |     logger = get_logger(__name__)
131 |     hyperparameters = {
132 |         **args.__dict__,
133 |         **{key: value for key, value in config.items() \
134 |            if key.isupper() and type(value) in [int, float, str, bool, dict]}
135 |     }
136 |     logger.info(hyperparameters)
137 | 
138 | 
139 | def save_current_src(save_path: str) -> None:
140 |     """save the current src.
141 |     Args:
142 |         save_path (str): the path to save the current src (the src and scripts directories are copied there).
143 | 
144 |     Returns:
145 |         None
146 |     """
147 |     logger = get_logger(__name__)
148 |     logger.info("save the current src")
149 |     src_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
150 |     os.system("cp -r {} {}".format(src_path, save_path))
151 |     script_path = os.path.join(os.path.dirname(src_path), "scripts")
152 |     os.system("cp -r {} {}".format(script_path, save_path))
--------------------------------------------------------------------------------
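
For quick reference, below is a minimal usage sketch of the TaskVector API defined in task_arithemetic_t5/task_vectors.py, mirroring how main() in cross-task-linearity.py builds its models. It assumes it is run from the task_arithemetic_t5 directory; the finetuned checkpoint paths are hypothetical placeholders for T5 models saved with save_pretrained.

import copy
from transformers import AutoModelForSeq2SeqLM
from task_vectors import TaskVector

# pretrained base and two finetuned checkpoints (hypothetical local paths)
base = AutoModelForSeq2SeqLM.from_pretrained("./t5-base")
finetunedA = AutoModelForSeq2SeqLM.from_pretrained("./ckpts/t5-imdb")
finetunedB = AutoModelForSeq2SeqLM.from_pretrained("./ckpts/t5-squad")

# a task vector is the difference between finetuned and pretrained weights
tvA = TaskVector(pretrained_checkpoint=base, finetuned_checkpoint=finetunedA)
tvB = TaskVector(pretrained_checkpoint=base, finetuned_checkpoint=finetunedB)

# addition steers one model toward both tasks; negation moves away from task A
merged = (tvA + tvB).apply_to(copy.deepcopy(base), scaling_coef=0.5)
negated = (-tvA).apply_to(copy.deepcopy(base), scaling_coef=1.0)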