├── README.md
├── data
│   └── data_links.txt
└── src
    ├── args.py
    ├── datasets
    │   ├── cars.py
    │   ├── cifar10.py
    │   ├── cifar100.py
    │   ├── common.py
    │   ├── dtd.py
    │   ├── eurosat.py
    │   ├── flowers.py
    │   ├── gtsrb.py
    │   ├── imagenet.py
    │   ├── imagenet100.py
    │   ├── mnist.py
    │   ├── pets.py
    │   ├── registry.py
    │   ├── resisc45.py
    │   ├── stl10.py
    │   ├── sun397.py
    │   ├── svhn.py
    │   └── templates.py
    ├── eval.py
    ├── figures
    │   ├── comparison.png
    │   ├── exp.png
    │   ├── figures.txt
    │   ├── main_table.png
    │   ├── neulig_overview.png
    │   └── neulig_train_pip.png
    ├── finetune_clean.py
    ├── heads.py
    ├── modeling.py
    ├── neulig_main.py
    ├── pgbar.py
    ├── task_vectors.py
    └── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Towards Performance Consistency in Multi-Level Model Collaboration

[arXiv Paper](https://arxiv.org/abs/2503.01268)

Qi Li, Runpeng Yu, Xinchao Wang

xML-Lab, National University of Singapore
corresponding author
------------------

TL;DR (1) - Achieve performance consistency between merging and ensembling in a unified framework.

TL;DR (2) - Provide theoretical support for realizing this performance consistency.

## Graphical Abstract
Figure 1. An illustration of Portland, which consists of a linear layer followed by a softmax function.
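For orientation, here is a minimal PyTorch sketch of the structure the caption describes: one linear layer followed by a softmax that emits a weight per collaborating model. The class name, the input features, and all dimensions are illustrative assumptions, not the repository's actual implementation (see `src/neulig_main.py` for that).

```python
import torch
import torch.nn as nn

class PortlandSketch(nn.Module):
    """Hypothetical sketch of Figure 1: a linear layer plus a softmax
    that turns some conditioning features into one weight per
    collaborating model. The input choice and sizes are assumptions."""

    def __init__(self, in_dim: int, num_co_models: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, num_co_models)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # The softmax keeps the weights positive and summing to one, so
        # they can weight either model parameters (merging) or model
        # outputs (ensembling).
        return torch.softmax(self.linear(features), dim=-1)

# Example: weights for two collaborating models from a 512-d feature.
coeffs = PortlandSketch(in_dim=512, num_co_models=2)(torch.randn(1, 512))
```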
Figure 2. The training process of Portland.
Figure 3. A toy experiment to verify theoretical feasibility. In this experiment, we merge two models that were fine-tuned on different datasets. Marker shapes represent different methods, while colors indicate different experimental groups, with each group using a distinct combination of datasets. In total, we conduct 10 groups (represented by 10 different colors). For each method, hollow markers indicate the average result across these 10 groups.
Table 1. The asterisk indicates that the condition is partially satisfied. For Simple-Averaging, the theoretical discussion is limited to the relationship between the performance of merging two models and that of ensembling. Furthermore, although both Simple-Averaging and Task-Arithmetic can be applied to CNN-based models, their performance is suboptimal. In the case of Diverse-Origin Models, all previous methods yield performance close to random guessing, but our conclusions remain applicable.
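For reference alongside these tables, the two quantities being compared can be written generically as follows; this is the standard formulation from the model-merging literature, with notation assumed here rather than taken from the paper. Given collaborating models $f(x; \theta_k)$, $k = 1, \dots, K$, and weights $\lambda_k$ that sum to one:

```latex
% Merging: combine parameters first, then run a single forward pass.
\hat{y}_{\mathrm{merge}}(x) = f\Big(x;\ \sum_{k=1}^{K} \lambda_k \theta_k\Big)

% Ensembling: run every model, then combine their outputs.
\hat{y}_{\mathrm{ens}}(x) = \sum_{k=1}^{K} \lambda_k\, f(x;\ \theta_k)
```

Performance consistency, in the sense of the TL;DR above, asks that these two predictors achieve nearly the same accuracy; the performance gap reported in Table 2 measures exactly this.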
Table 2. Results of various methods across multiple datasets, including the merging performance, the ensembling performance, and the performance gap for both CLIP-RN50 and CLIP-ViT-B/32.

## Installation & Preparation

1. Clone the repo and prepare the virtual environment.

```
git clone https://github.com/LiQiiiii/Neural-Ligand.git
```

```
cd Neural-Ligand
```

```
conda create -n neulig python=3.8.10
```

```
conda activate neulig
```

The code is tested on torch 2.0.0 and torchvision 0.15.1.

2. Prepare the datasets and models. Download links for the datasets used in the paper are listed in `./data/data_links.txt`; save the datasets in the `./data` folder. Then run:

```
python ./src/finetune_clean.py
```

to obtain the fine-tuned models used for training and evaluation.

## Training & Evaluation

```
python ./src/neulig_main.py --num_co_models 2 --global_epoch 1000 --alignment_type sup --model RN50
```

where `--num_co_models` is the number of collaborating models, `--alignment_type` selects the alignment term (`sup` for supervised, `semi` for semi-supervised), and `--model` selects the model type (`RN50`, `ViT-B-32`, or `ViT-L-14`).
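As an illustration, a semi-supervised run in which four ViT-B/32 models collaborate could look as follows; the flag values here are examples, while the flags themselves are defined in `src/args.py`:

```
python ./src/neulig_main.py --num_co_models 4 --global_epoch 1000 --alignment_type semi --model ViT-B-32
```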
## Citation

If you find our work interesting or helpful, please cite it as follows:

```
@article{li2025multi,
  title={Multi-Level Collaboration in Model Merging},
  author={Li, Qi and Yu, Runpeng and Wang, Xinchao},
  journal={arXiv preprint arXiv:2503.01268},
  year={2025}
}
```

--------------------------------------------------------------------------------
/data/data_links.txt:
--------------------------------------------------------------------------------

CIFAR10, CIFAR100, MNIST, GTSRB, SVHN can be automatically downloaded via torchvision.

# RESISC45
https://huggingface.co/datasets/timm/resisc45

# STL10
https://ai.stanford.edu/~acoates/stl10

--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------

import os
import argparse
import torch

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('./data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=lambda x: x.split(","),
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT."
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=lambda x: x.split(","),
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only."
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results; if None, results are not stored.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        help="Learning rate."
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.1,
        help="Weight decay."
    )
    parser.add_argument(
        "--ls",
        type=float,
        default=0.0,
        help="Label smoothing."
    )
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=lambda x: x.split(","),
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder.",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default='./open_clip',
        help='Directory for caching models from OpenCLIP.'
    )
    parser.add_argument(
        "--ckpt-dir",
        type=str,
        default='./checkpoints',
    )
    parser.add_argument(
        "--logs-dir",
        type=str,
        default='./logs/',
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default='Val',
    )
    parser.add_argument(
        "--ada_name",
        type=str,
        default='lambda.pt',
    )
    parser.add_argument(
        "--scaling-coef-",
        type=float,
        default=0.3,
        help="Scaling coefficient."
    )
    parser.add_argument(
        "--model",
        type=str,
        default='RN50',
        help="The type of model (e.g. RN50, ViT-B-32, ViT-L-14).",
    )
    parser.add_argument(
        "--num_co_models",
        type=int,
        default=2,
        help="Number of collaborating models."
    )
    parser.add_argument(
        "--global_epoch",
        type=int,
        default=1000,
        help="Number of global epochs."
    )
    parser.add_argument(
        "--scaling",
        type=float,
        default=100.0,
        help="Scaling parameter."
    )
    parser.add_argument(
        "--alignment_type",
        type=str,
        default='sup',
        help="sup for supervised and semi for semi-supervised."
    )

    parsed_args = parser.parse_args()
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args

--------------------------------------------------------------------------------
/src/datasets/cars.py:
--------------------------------------------------------------------------------

1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import pathlib 5 | from typing import Callable, Optional, Any, Tuple 6 | from PIL import Image 7 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | class PytorchStanfordCars(VisionDataset): 11 | """`Stanford Cars `_ Dataset 12 | 13 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 14 | split into 8,144 training images and 8,041 testing images, where each class 15 | has been split roughly in a 50-50 split 16 | 17 | .. note:: 18 | 19 | This class needs `scipy `_ to load target files from `.mat` format.
20 | 21 | Args: 22 | root (string): Root directory of dataset 23 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If True, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again.""" 31 | 32 | def __init__( 33 | self, 34 | root: str, 35 | split: str = "train", 36 | transform: Optional[Callable] = None, 37 | target_transform: Optional[Callable] = None, 38 | download: bool = False, 39 | ) -> None: 40 | 41 | try: 42 | import scipy.io as sio 43 | except ImportError: 44 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 45 | 46 | super().__init__(root, transform=transform, target_transform=target_transform) 47 | 48 | self._split = verify_str_arg(split, "split", ("train", "test")) 49 | self._base_folder = pathlib.Path('./data') / "stanford_cars" 50 | devkit = self._base_folder / "devkit" 51 | 52 | if self._split == "train": 53 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 54 | self._images_base_path = self._base_folder / "cars_train" 55 | else: 56 | self._annotations_mat_path = devkit / "cars_test_annos_withlabels.mat" 57 | self._images_base_path = self._base_folder / "cars_test" 58 | 59 | # if download: 60 | # self.download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. You can use download=True to download it") 64 | 65 | self._samples = [ 66 | ( 67 | str(self._images_base_path / annotation["fname"]), 68 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 69 | ) 70 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 71 | ] 72 | 73 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 74 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 75 | 76 | def __len__(self) -> int: 77 | return len(self._samples) 78 | 79 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 80 | """Returns pil_image and class_id for given index""" 81 | image_path, target = self._samples[idx] 82 | pil_image = Image.open(image_path).convert("RGB") 83 | 84 | if self.transform is not None: 85 | pil_image = self.transform(pil_image) 86 | if self.target_transform is not None: 87 | target = self.target_transform(target) 88 | return pil_image, target, idx 89 | 90 | 91 | def download(self) -> None: 92 | if self._check_exists(): 93 | return 94 | 95 | download_and_extract_archive( 96 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 97 | download_root=str(self._base_folder), 98 | md5="c3b158d763b6e2245038c8ad08e45376", 99 | ) 100 | if self._split == "train": 101 | download_and_extract_archive( 102 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 103 | download_root=str(self._base_folder), 104 | md5="065e5b463ae28d29e77c1b4b166cfe61", 105 | ) 106 | else: 107 | download_and_extract_archive( 108 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 109 | download_root=str(self._base_folder), 110 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 111 | ) 112 | download_url( 113 | 
url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 114 | root=str(self._base_folder), 115 | md5="b0a2b23655a3edd16d84508592a98d10", 116 | ) 117 | 118 | def _check_exists(self) -> bool: 119 | print(self._base_folder / "devkit") 120 | if not (self._base_folder / "devkit").is_dir(): 121 | return False 122 | 123 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 124 | 125 | 126 | class Cars: 127 | def __init__(self, 128 | preprocess, 129 | location=os.path.expanduser('./data'), 130 | batch_size=32, 131 | num_workers=16): 132 | # Data loading code 133 | 134 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=True) 135 | self.train_loader = torch.utils.data.DataLoader( 136 | self.train_dataset, 137 | shuffle=True, 138 | batch_size=batch_size, 139 | num_workers=num_workers, 140 | ) 141 | 142 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=True) 143 | self.test_loader = torch.utils.data.DataLoader( 144 | self.test_dataset, 145 | batch_size=batch_size, 146 | num_workers=num_workers 147 | ) 148 | self.test_loader_shuffle = torch.utils.data.DataLoader( 149 | self.test_dataset, 150 | shuffle=True, 151 | batch_size=batch_size, 152 | num_workers=num_workers 153 | ) 154 | idx_to_class = dict((v, k) 155 | for k, v in self.train_dataset.class_to_idx.items()) 156 | self.classnames = [idx_to_class[i].replace( 157 | '_', ' ') for i in range(len(idx_to_class))] 158 | -------------------------------------------------------------------------------- /src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 8 | from torchvision.datasets import VisionDataset 9 | from PIL import Image 10 | 11 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 12 | 13 | class MyPyTorchCIFAR10(PyTorchCIFAR10): 14 | def __init__(self, root, download, train, transform): 15 | super().__init__(root=root, download=download, train=train, transform=transform) 16 | 17 | def __getitem__(self, index: int): 18 | """ 19 | Args: 20 | index (int): Index 21 | 22 | Returns: 23 | tuple: (image, target) where target is index of the target class. 
24 | """ 25 | img, target = self.data[index], self.targets[index] 26 | 27 | # doing this so that it is consistent with all other datasets 28 | # to return a PIL Image 29 | img = Image.fromarray(img) 30 | 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | 34 | if self.target_transform is not None: 35 | target = self.target_transform(target) 36 | 37 | return img, target, index 38 | 39 | class CIFAR10: 40 | def __init__(self, preprocess, 41 | location=os.path.expanduser('./data'), 42 | batch_size=128, 43 | num_workers=16): 44 | 45 | 46 | self.train_dataset = MyPyTorchCIFAR10( 47 | root=location, download=True, train=True, transform=preprocess 48 | ) 49 | 50 | self.train_loader = torch.utils.data.DataLoader( 51 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 52 | ) 53 | 54 | self.test_dataset = MyPyTorchCIFAR10( 55 | root=location, download=True, train=False, transform=preprocess 56 | ) 57 | 58 | self.test_loader = torch.utils.data.DataLoader( 59 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 60 | ) 61 | 62 | self.test_loader_shuffle = torch.utils.data.DataLoader( 63 | self.test_dataset, 64 | shuffle=True, 65 | batch_size=batch_size, 66 | num_workers=num_workers 67 | ) 68 | 69 | self.classnames = self.test_dataset.classes 70 | 71 | def convert(x): 72 | if isinstance(x, np.ndarray): 73 | return torchvision.transforms.functional.to_pil_image(x) 74 | return x 75 | 76 | class BasicVisionDataset(VisionDataset): 77 | def __init__(self, images, targets, transform=None, target_transform=None): 78 | if transform is not None: 79 | transform.transforms.insert(0, convert) 80 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) 81 | assert len(images) == len(targets) 82 | 83 | self.images = images 84 | self.targets = targets 85 | 86 | def __getitem__(self, index): 87 | return self.transform(self.images[index]), self.targets[index] 88 | 89 | def __len__(self): 90 | return len(self.targets) 91 | -------------------------------------------------------------------------------- /src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 4 | from PIL import Image 5 | 6 | class MyPyTorchCIFAR100(PyTorchCIFAR100): 7 | def __init__(self, root, download, train, transform): 8 | super().__init__(root=root, download=download, train=train, transform=transform) 9 | 10 | def __getitem__(self, index: int): 11 | """ 12 | Args: 13 | index (int): Index 14 | 15 | Returns: 16 | tuple: (image, target) where target is index of the target class. 
17 | """ 18 | img, target = self.data[index], self.targets[index] 19 | 20 | # doing this so that it is consistent with all other datasets 21 | # to return a PIL Image 22 | img = Image.fromarray(img) 23 | 24 | if self.transform is not None: 25 | img = self.transform(img) 26 | 27 | if self.target_transform is not None: 28 | target = self.target_transform(target) 29 | 30 | return img, target, index 31 | 32 | class CIFAR100: 33 | def __init__(self, 34 | preprocess, 35 | location=os.path.expanduser('./data'), 36 | batch_size=128, 37 | num_workers=16): 38 | 39 | self.train_dataset = MyPyTorchCIFAR100( 40 | root=location, download=True, train=True, transform=preprocess 41 | ) 42 | 43 | self.train_loader = torch.utils.data.DataLoader( 44 | self.train_dataset, batch_size=batch_size, num_workers=num_workers 45 | ) 46 | 47 | self.test_dataset = MyPyTorchCIFAR100( 48 | root=location, download=True, train=False, transform=preprocess 49 | ) 50 | 51 | self.test_loader = torch.utils.data.DataLoader( 52 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 53 | ) 54 | 55 | self.test_loader_shuffle = torch.utils.data.DataLoader( 56 | self.test_dataset, 57 | shuffle=True, 58 | batch_size=batch_size, 59 | num_workers=num_workers 60 | ) 61 | 62 | self.classnames = self.test_dataset.classes 63 | 64 | 65 | -------------------------------------------------------------------------------- /src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torchvision.datasets as datasets 10 | from torch.utils.data import Dataset, DataLoader, Sampler 11 | 12 | class SubsetSampler(Sampler): 13 | def __init__(self, indices): 14 | self.indices = indices 15 | 16 | def __iter__(self): 17 | return (i for i in self.indices) 18 | 19 | def __len__(self): 20 | return len(self.indices) 21 | 22 | class ImageFolderWithPaths(datasets.ImageFolder): 23 | def __init__(self, path, transform, flip_label_prob=0.0): 24 | super().__init__(path, transform) 25 | self.flip_label_prob = flip_label_prob 26 | if self.flip_label_prob > 0: 27 | print(f'Flipping labels with probability {self.flip_label_prob}') 28 | num_classes = len(self.classes) 29 | for i in range(len(self.samples)): 30 | if random.random() < self.flip_label_prob: 31 | new_label = random.randint(0, num_classes-1) 32 | self.samples[i] = ( 33 | self.samples[i][0], 34 | new_label 35 | ) 36 | 37 | def __getitem__(self, index): 38 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 39 | return { 40 | 'images': image, 41 | 'labels': label, 42 | 'image_paths': self.samples[index][0] 43 | } 44 | 45 | def maybe_dictionarize(batch): # double check 46 | if isinstance(batch, dict): 47 | return batch 48 | 49 | if len(batch) ==2: 50 | batch = {'images': batch[0], 'labels': batch[1]} 51 | elif len(batch) == 3: 52 | batch = {'images': batch[0], 'labels': batch[1], 'indices': batch[2]} 53 | elif len(batch) == 4: 54 | batch = {'images': batch[0], 'labels': batch[1], 'indices': batch[2], 'metadata': batch[3]} 55 | else: 56 | raise ValueError(f'Unexpected number of elements: {len(batch)}') 57 | 58 | return batch 59 | 60 | def get_features_helper(image_encoder, dataloader, device): 61 | all_data = collections.defaultdict(list) 62 | 63 | image_encoder = image_encoder.to(device) 64 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x 
in range(torch.cuda.device_count())]) 65 | image_encoder.eval() 66 | 67 | with torch.no_grad(): 68 | for batch in tqdm(dataloader): 69 | batch = maybe_dictionarize(batch) 70 | features = image_encoder(batch['images'].cuda()) 71 | 72 | all_data['features'].append(features.cpu()) 73 | 74 | for key, val in batch.items(): 75 | if key == 'images': 76 | continue 77 | if hasattr(val, 'cpu'): 78 | val = val.cpu() 79 | all_data[key].append(val) 80 | else: 81 | all_data[key].extend(val) 82 | 83 | for key, val in all_data.items(): 84 | if torch.is_tensor(val[0]): 85 | all_data[key] = torch.cat(val).numpy() 86 | 87 | return all_data 88 | 89 | def get_features(is_train, image_encoder, dataset, device): 90 | split = 'train' if is_train else 'val' 91 | dname = type(dataset).__name__ 92 | if image_encoder.cache_dir is not None: 93 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' 94 | cached_files = glob.glob(f'{cache_dir}/*') 95 | if image_encoder.cache_dir is not None and len(cached_files) > 0: 96 | print(f'Getting features from {cache_dir}') 97 | data = {} 98 | for cached_file in cached_files: 99 | name = os.path.splitext(os.path.basename(cached_file))[0] 100 | data[name] = torch.load(cached_file) 101 | else: 102 | print(f'Did not find cached features at {cache_dir}. Building from scratch.') 103 | loader = dataset.train_loader if is_train else dataset.test_loader 104 | data = get_features_helper(image_encoder, loader, device) 105 | if image_encoder.cache_dir is None: 106 | print('Not caching because no cache directory was passed.') 107 | else: 108 | os.makedirs(cache_dir, exist_ok=True) 109 | print(f'Caching data at {cache_dir}') 110 | for name, val in data.items(): 111 | torch.save(val, f'{cache_dir}/{name}.pt') 112 | return data 113 | 114 | class FeatureDataset(Dataset): 115 | def __init__(self, is_train, image_encoder, dataset, device): 116 | self.data = get_features(is_train, image_encoder, dataset, device) 117 | 118 | def __len__(self): 119 | return len(self.data['features']) 120 | 121 | def __getitem__(self, idx): 122 | data = {k: v[idx] for k, v in self.data.items()} 123 | data['features'] = torch.from_numpy(data['features']).float() 124 | return data 125 | 126 | def get_dataloader(dataset, split): 127 | if split=='train': 128 | dataloader = dataset.train_loader 129 | elif split=='test': 130 | dataloader = dataset.test_loader 131 | elif split=='test_shuffled': 132 | dataloader = dataset.test_loader_shuffle 133 | elif split=='dev': 134 | dataloader = dataset.test_loader_shuffle 135 | elif split=='shadowtrain': 136 | dataloader = dataset.shadowtrain_loader 137 | return dataloader -------------------------------------------------------------------------------- /src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class ImageFolderDataset(datasets.ImageFolder): 6 | def __init__(self, root, transform): 7 | super().__init__(root, transform) 8 | 9 | def __getitem__(self, index: int): 10 | path, target = self.samples[index] 11 | sample = self.loader(path) 12 | if self.transform is not None: 13 | sample = self.transform(sample) 14 | if self.target_transform is not None: 15 | target = self.target_transform(target) 16 | return sample, target, index 17 | 18 | class DTD: 19 | def __init__(self, 20 | preprocess, 21 | location=os.path.expanduser('./data'), 22 | batch_size=32, 23 | num_workers=16): 24 | # Data loading code 25 | location = './data' 26 | traindir = 
os.path.join(location, 'dtd', 'train') 27 | valdir = os.path.join(location, 'dtd', 'test') 28 | 29 | self.train_dataset = ImageFolderDataset( 30 | traindir, transform=preprocess) 31 | self.train_loader = torch.utils.data.DataLoader( 32 | self.train_dataset, 33 | shuffle=True, 34 | batch_size=batch_size, 35 | num_workers=num_workers, 36 | ) 37 | 38 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess) 39 | self.test_loader = torch.utils.data.DataLoader( 40 | self.test_dataset, 41 | batch_size=batch_size, 42 | num_workers=num_workers 43 | ) 44 | self.test_loader_shuffle = torch.utils.data.DataLoader( 45 | self.test_dataset, 46 | shuffle=True, 47 | batch_size=batch_size, 48 | num_workers=num_workers 49 | ) 50 | idx_to_class = dict((v, k) 51 | for k, v in self.train_dataset.class_to_idx.items()) 52 | self.classnames = [idx_to_class[i].replace( 53 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | import numpy as np 6 | 7 | def pretify_classname(classname): 8 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) 9 | l = [i.lower() for i in l] 10 | out = ' '.join(l) 11 | if out.endswith('al'): 12 | return out + ' area' 13 | return out 14 | 15 | class ImageFolderDataset(datasets.ImageFolder): 16 | def __init__(self, root, transform): 17 | super().__init__(root, transform) 18 | self.indices = np.arange(len(self.samples)) 19 | 20 | def __getitem__(self, index: int): 21 | path, target = self.samples[index] 22 | sample = self.loader(path) 23 | if self.transform is not None: 24 | sample = self.transform(sample) 25 | if self.target_transform is not None: 26 | target = self.target_transform(target) 27 | return sample, target, index 28 | 29 | class EuroSATBase: 30 | def __init__(self, 31 | preprocess, 32 | test_split, 33 | location='./data', 34 | batch_size=32, 35 | num_workers=16): 36 | # Data loading code 37 | location = './data' 38 | traindir = os.path.join(location, 'EuroSAT_splits', 'train') 39 | testdir = os.path.join(location, 'EuroSAT_splits', test_split) 40 | 41 | 42 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess) 43 | self.train_loader = torch.utils.data.DataLoader( 44 | self.train_dataset, 45 | shuffle=True, 46 | batch_size=batch_size, 47 | num_workers=num_workers, 48 | ) 49 | 50 | self.test_dataset = ImageFolderDataset(testdir, transform=preprocess) 51 | self.test_loader = torch.utils.data.DataLoader( 52 | self.test_dataset, 53 | batch_size=batch_size, 54 | num_workers=num_workers 55 | ) 56 | self.test_loader_shuffle = torch.utils.data.DataLoader( 57 | self.test_dataset, 58 | shuffle=True, 59 | batch_size=batch_size, 60 | num_workers=num_workers 61 | ) 62 | idx_to_class = dict((v, k) 63 | for k, v in self.train_dataset.class_to_idx.items()) 64 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 65 | self.classnames = [pretify_classname(c) for c in self.classnames] 66 | ours_to_open_ai = { 67 | 'annual crop': 'annual crop land', 68 | 'forest': 'forest', 69 | 'herbaceous vegetation': 'brushland or shrubland', 70 | 'highway': 'highway or road', 71 | 'industrial area': 'industrial buildings or commercial buildings', 72 | 'pasture': 'pasture land', 73 | 'permanent crop': 'permanent crop land', 74 | 'residential area': 'residential 
buildings or homes or apartments', 75 | 'river': 'river', 76 | 'sea lake': 'lake or sea', 77 | } 78 | for i in range(len(self.classnames)): 79 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 80 | 81 | 82 | class EuroSAT(EuroSATBase): 83 | def __init__(self, 84 | preprocess, 85 | location='~/datasets', 86 | batch_size=32, 87 | num_workers=16): 88 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 89 | 90 | 91 | class EuroSATVal(EuroSATBase): 92 | def __init__(self, 93 | preprocess, 94 | location='~/datasets', 95 | batch_size=32, 96 | num_workers=16): 97 | super().__init__(preprocess, 'val', location, batch_size, num_workers) -------------------------------------------------------------------------------- /src/datasets/flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | import numpy as np 6 | 7 | class ImageFolderDataset(datasets.ImageFolder): 8 | def __init__(self, root, transform): 9 | super().__init__(root, transform) 10 | self.indices = np.arange(len(self.samples)) 11 | 12 | def __getitem__(self, index: int): 13 | path, target = self.samples[index] 14 | sample = self.loader(path) 15 | if self.transform is not None: 16 | sample = self.transform(sample) 17 | if self.target_transform is not None: 18 | target = self.target_transform(target) 19 | return sample, target, index 20 | 21 | class FlowersBase: 22 | def __init__(self, 23 | preprocess, 24 | test_split, 25 | location='./data', 26 | batch_size=32, 27 | num_workers=16): 28 | # Data loading code 29 | location = './data' 30 | traindir = os.path.join(location, 'flowers', 'train') 31 | testdir = os.path.join(location, 'flowers', test_split) 32 | 33 | 34 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess) 35 | self.train_loader = torch.utils.data.DataLoader( 36 | self.train_dataset, 37 | shuffle=True, 38 | batch_size=batch_size, 39 | num_workers=num_workers, 40 | ) 41 | 42 | self.test_dataset = ImageFolderDataset(testdir, transform=preprocess) 43 | self.test_loader = torch.utils.data.DataLoader( 44 | self.test_dataset, 45 | batch_size=batch_size, 46 | num_workers=num_workers 47 | ) 48 | self.test_loader_shuffle = torch.utils.data.DataLoader( 49 | self.test_dataset, 50 | shuffle=True, 51 | batch_size=batch_size, 52 | num_workers=num_workers 53 | ) 54 | 55 | idx_to_class = dict((v, k) 56 | for k, v in self.train_dataset.class_to_idx.items()) 57 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 58 | 59 | 60 | class Flowers(FlowersBase): 61 | def __init__(self, 62 | preprocess, 63 | location='~/datasets', 64 | batch_size=32, 65 | num_workers=16): 66 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 67 | 68 | 69 | class FlowersVal(FlowersBase): 70 | def __init__(self, 71 | preprocess, 72 | location='~/datasets', 73 | batch_size=32, 74 | num_workers=16): 75 | super().__init__(preprocess, 'val', location, batch_size, num_workers) 76 | -------------------------------------------------------------------------------- /src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import 
(download_and_extract_archive, 11 | verify_str_arg) 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 15 | """Finds the class folders in a dataset. 16 | 17 | See :class:`DatasetFolder` for details. 18 | """ 19 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 20 | if not classes: 21 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 22 | 23 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 24 | return classes, class_to_idx 25 | 26 | class PyTorchGTSRB(VisionDataset): 27 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. 28 | 29 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 30 | 31 | Args: 32 | root (string): Root directory of the dataset. 33 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 34 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 35 | version. E.g, ``transforms.RandomCrop``. 36 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 37 | download (bool, optional): If True, downloads the dataset from the internet and 38 | puts it in root directory. If dataset is already downloaded, it is not 39 | downloaded again. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | split: str = "train", 46 | transform: Optional[Callable] = None, 47 | target_transform: Optional[Callable] = None, 48 | download: bool = False, 49 | ) -> None: 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path("./data") / "gtsrb" 55 | self._target_folder = ( 56 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 57 | ) 58 | 59 | if download: 60 | self.download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. 
You can use download=True to download it") 64 | 65 | if self._split == "train": 66 | _, class_to_idx = find_classes(str(self._target_folder)) 67 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) 68 | else: 69 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 70 | samples = [ 71 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 72 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 73 | ] 74 | 75 | self._samples = samples 76 | self.transform = transform 77 | self.target_transform = target_transform 78 | 79 | def __len__(self) -> int: 80 | return len(self._samples) 81 | 82 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 83 | 84 | path, target = self._samples[index] 85 | sample = PIL.Image.open(path).convert("RGB") 86 | 87 | if self.transform is not None: 88 | sample = self.transform(sample) 89 | 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | return sample, target, index 94 | 95 | 96 | def _check_exists(self) -> bool: 97 | return self._target_folder.is_dir() 98 | 99 | def download(self) -> None: 100 | if self._check_exists(): 101 | return 102 | 103 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 104 | 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | f"{base_url}GTSRB-Training_fixed.zip", 108 | download_root=str(self._base_folder), 109 | md5="513f3c79a4c5141765e10e952eaa2478", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | f"{base_url}GTSRB_Final_Test_Images.zip", 114 | download_root=str(self._base_folder), 115 | md5="c7e4e6327067d32654124b0fe9e82185", 116 | ) 117 | download_and_extract_archive( 118 | f"{base_url}GTSRB_Final_Test_GT.zip", 119 | download_root=str(self._base_folder), 120 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 121 | ) 122 | 123 | 124 | class GTSRB: 125 | def __init__(self, 126 | preprocess, 127 | location=os.path.expanduser('./data'), 128 | batch_size=128, 129 | num_workers=16): 130 | 131 | # to fit with repo conventions for location 132 | self.train_dataset = PyTorchGTSRB( 133 | root=location, 134 | download=True, 135 | split='train', 136 | transform=preprocess 137 | ) 138 | 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | batch_size=batch_size, 142 | shuffle=True, 143 | num_workers=num_workers 144 | ) 145 | 146 | self.test_dataset = PyTorchGTSRB( 147 | root=location, 148 | download=True, 149 | split='test', 150 | transform=preprocess 151 | ) 152 | 153 | self.test_loader = torch.utils.data.DataLoader( 154 | self.test_dataset, 155 | batch_size=batch_size, 156 | shuffle=False, 157 | num_workers=num_workers 158 | ) 159 | 160 | self.test_loader_shuffle = torch.utils.data.DataLoader( 161 | self.test_dataset, 162 | shuffle=True, 163 | batch_size=batch_size, 164 | num_workers=num_workers 165 | ) 166 | 167 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 168 | self.classnames = [ 169 | 'red and white circle 20 kph speed limit', 170 | 'red and white circle 30 kph speed limit', 171 | 'red and white circle 50 kph speed limit', 172 | 'red and white circle 60 kph speed limit', 173 | 'red and white circle 70 kph speed limit', 174 | 'red and white circle 80 kph speed limit', 175 | 'end / de-restriction of 80 kph speed limit', 176 | 'red and white circle 100 kph speed limit', 177 | 'red and white circle 120 kph speed limit', 178 | 'red and white circle red car 
and black car no passing', 179 | 'red and white circle red truck and black car no passing', 180 | 'red and white triangle road intersection warning', 181 | 'white and yellow diamond priority road', 182 | 'red and white upside down triangle yield right-of-way', 183 | 'stop', 184 | 'empty red and white circle', 185 | 'red and white circle no truck entry', 186 | 'red circle with white horizonal stripe no entry', 187 | 'red and white triangle with exclamation mark warning', 188 | 'red and white triangle with black left curve approaching warning', 189 | 'red and white triangle with black right curve approaching warning', 190 | 'red and white triangle with black double curve approaching warning', 191 | 'red and white triangle rough / bumpy road warning', 192 | 'red and white triangle car skidding / slipping warning', 193 | 'red and white triangle with merging / narrow lanes warning', 194 | 'red and white triangle with person digging / construction / road work warning', 195 | 'red and white triangle with traffic light approaching warning', 196 | 'red and white triangle with person walking warning', 197 | 'red and white triangle with child and person walking warning', 198 | 'red and white triangle with bicyle warning', 199 | 'red and white triangle with snowflake / ice warning', 200 | 'red and white triangle with deer warning', 201 | 'white circle with gray strike bar no speed limit', 202 | 'blue circle with white right turn arrow mandatory', 203 | 'blue circle with white left turn arrow mandatory', 204 | 'blue circle with white forward arrow mandatory', 205 | 'blue circle with white forward or right turn arrow mandatory', 206 | 'blue circle with white forward or left turn arrow mandatory', 207 | 'blue circle with white keep right arrow mandatory', 208 | 'blue circle with white keep left arrow mandatory', 209 | 'blue circle with white arrows indicating a traffic circle', 210 | 'white circle with gray strike bar indicating no passing for cars has ended', 211 | 'white circle with gray strike bar indicating no passing for trucks has ended', 212 | ] 213 | -------------------------------------------------------------------------------- /src/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .common import ImageFolderWithPaths, SubsetSampler 5 | import numpy as np 6 | 7 | def get_imagenet_classnames(): 8 | imagenet_classnames = [ 9 | "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 10 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 11 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 12 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 13 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 14 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 15 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 16 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 17 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 18 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 19 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 20 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 
21 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 22 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 23 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 24 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 25 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 26 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 27 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 28 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 29 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 30 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 31 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 32 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 33 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 34 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 35 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 36 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 37 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 38 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 39 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 40 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 41 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 42 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 43 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 44 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 45 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 46 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 47 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 48 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 49 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 50 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 51 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 52 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 53 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 54 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 55 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 56 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 57 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 58 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 59 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 60 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 61 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 62 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 63 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 64 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 65 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 66 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 67 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 68 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 69 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 70 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 71 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 72 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 73 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 74 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 75 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 76 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 77 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 78 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 79 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 80 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 81 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 82 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 83 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 84 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 85 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 86 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 87 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 88 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 89 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 90 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 91 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 92 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 93 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 94 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 95 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 96 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 97 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 98 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 99 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 100 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 101 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 102 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 103 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 104 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 105 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 106 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 107 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 108 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 109 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 110 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 111 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 112 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 113 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 114 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 115 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 116 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 117 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 118 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 119 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 120 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 121 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 122 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 123 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 124 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 125 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 126 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 127 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 128 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 129 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 130 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", 
"Pickelhaube", 131 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 132 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 133 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 134 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 135 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 136 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 137 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 138 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 139 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 140 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 141 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 142 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 143 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 144 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 145 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 146 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 147 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 148 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 149 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 150 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 151 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 152 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 153 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 154 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 155 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 156 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 157 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 158 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 159 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 160 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 161 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 162 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 163 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 164 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 165 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 166 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 167 | "lemon", "fig", "pineapple", "banana", 
"jackfruit", "cherimoya (custard apple)", "pomegranate", 168 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 169 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 170 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 171 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 172 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 173 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" 174 | ] 175 | return imagenet_classnames 176 | 177 | class ImageNet: 178 | def __init__(self, 179 | preprocess, 180 | location=os.path.expanduser('./data'), 181 | batch_size=32, 182 | num_workers=32): 183 | self.preprocess = preprocess 184 | self.location = location 185 | self.batch_size = batch_size 186 | self.num_workers = num_workers 187 | self.classnames = get_imagenet_classnames() 188 | 189 | self.populate_train() 190 | self.populate_test() 191 | 192 | def populate_train(self): 193 | traindir = os.path.join(self.location, self.name(), 'train') 194 | self.train_dataset = ImageFolderWithPaths( 195 | traindir, 196 | transform=self.preprocess) 197 | sampler = self.get_train_sampler() 198 | kwargs = {'shuffle' : True} if sampler is None else {} 199 | self.train_loader = torch.utils.data.DataLoader( 200 | self.train_dataset, 201 | sampler=sampler, 202 | batch_size=self.batch_size, 203 | num_workers=self.num_workers, 204 | **kwargs, 205 | ) 206 | 207 | def populate_test(self): 208 | self.test_dataset = self.get_test_dataset() 209 | self.test_loader = torch.utils.data.DataLoader( 210 | self.test_dataset, 211 | batch_size=self.batch_size, 212 | num_workers=self.num_workers, 213 | sampler=self.get_test_sampler() 214 | ) 215 | self.test_loader_shuffle = torch.utils.data.DataLoader( 216 | self.test_dataset, 217 | shuffle=True, 218 | batch_size=self.batch_size, 219 | num_workers=self.num_workers, 220 | sampler=self.get_test_sampler() 221 | ) 222 | 223 | def get_test_path(self): 224 | test_path = os.path.join(self.location, self.name(), 'val_in_folder') 225 | if not os.path.exists(test_path): 226 | test_path = os.path.join(self.location, self.name(), 'val') 227 | return test_path 228 | 229 | def get_train_sampler(self): 230 | return None 231 | 232 | def get_test_sampler(self): 233 | return None 234 | 235 | def get_test_dataset(self): 236 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 237 | 238 | def name(self): 239 | return 'imagenet' 240 | 241 | class ImageNetTrain(ImageNet): 242 | 243 | def get_test_dataset(self): 244 | pass 245 | 246 | class ImageNetK(ImageNet): 247 | 248 | def get_train_sampler(self): 249 | idxs = np.zeros(len(self.train_dataset.targets)) 250 | target_array = np.array(self.train_dataset.targets) 251 | for c in range(1000): 252 | m = target_array == c 253 | n = len(idxs[m]) 254 | arr = np.zeros(n) 255 | arr[:self.k()] = 1 256 | np.random.shuffle(arr) 257 | idxs[m] = arr 258 | 259 | idxs = idxs.astype('int') 260 | sampler = SubsetSampler(np.where(idxs)[0]) 261 | return sampler -------------------------------------------------------------------------------- /src/datasets/imagenet100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class ImageFolderDataset(datasets.ImageFolder): 6 | def 
__init__(self, root, transform): 7 | super().__init__(root, transform) 8 | 9 | def __getitem__(self, index: int): 10 | path, target = self.samples[index] 11 | sample = self.loader(path) 12 | if self.transform is not None: 13 | sample = self.transform(sample) 14 | if self.target_transform is not None: 15 | target = self.target_transform(target) 16 | return sample, target, index 17 | 18 | class ImageNet100: 19 | def __init__(self, 20 | preprocess, 21 | location=os.path.expanduser('./data'), 22 | batch_size=32, 23 | num_workers=16): 24 | # Data loading code 25 | location = './data' 26 | traindir = os.path.join(location, 'ImageNet100', 'train') 27 | valdir = os.path.join(location, 'ImageNet100', 'val') 28 | 29 | self.train_dataset = ImageFolderDataset( 30 | traindir, transform=preprocess) 31 | self.train_loader = torch.utils.data.DataLoader( 32 | self.train_dataset, 33 | shuffle=True, 34 | batch_size=batch_size, 35 | num_workers=num_workers, 36 | ) 37 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess) 38 | self.test_loader = torch.utils.data.DataLoader( 39 | self.test_dataset, 40 | batch_size=batch_size, 41 | num_workers=num_workers 42 | ) 43 | self.test_loader_shuffle = torch.utils.data.DataLoader( 44 | self.test_dataset, 45 | shuffle=True, 46 | batch_size=batch_size, 47 | num_workers=num_workers 48 | ) 49 | idx_to_class = dict((v, k) 50 | for k, v in self.train_dataset.class_to_idx.items()) 51 | self.classnames = [idx_to_class[i].replace( 52 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from PIL import Image 5 | 6 | class MyMNIST(datasets.MNIST): 7 | def __init__(self, root, download, train, transform): 8 | super().__init__(root=root, download=download, train=train, transform=transform) 9 | 10 | def __getitem__(self, index: int): 11 | """ 12 | Args: 13 | index (int): Index 14 | 15 | Returns: 16 | tuple: (image, target) where target is index of the target class. 
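Note: unlike the stock torchvision MNIST, this wrapper's __getitem__ also returns the sample index, so the tuple is actually (image, target, index).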
17 | """ 18 | img, target = self.data[index], int(self.targets[index]) 19 | 20 | # doing this so that it is consistent with all other datasets 21 | # to return a PIL Image 22 | img = Image.fromarray(img.numpy(), mode="L") 23 | 24 | if self.transform is not None: 25 | img = self.transform(img) 26 | 27 | if self.target_transform is not None: 28 | target = self.target_transform(target) 29 | 30 | return img, target, index 31 | 32 | class MNIST: 33 | def __init__(self, 34 | preprocess, 35 | location=os.path.expanduser('./data'), 36 | batch_size=128, 37 | num_workers=16): 38 | 39 | 40 | self.train_dataset = MyMNIST( 41 | root=location, 42 | download=True, 43 | train=True, 44 | transform=preprocess 45 | ) 46 | 47 | self.train_loader = torch.utils.data.DataLoader( 48 | self.train_dataset, 49 | batch_size=batch_size, 50 | shuffle=True, 51 | num_workers=num_workers 52 | ) 53 | 54 | self.test_dataset = MyMNIST( 55 | root=location, 56 | download=True, 57 | train=False, 58 | transform=preprocess 59 | ) 60 | 61 | self.test_loader = torch.utils.data.DataLoader( 62 | self.test_dataset, 63 | batch_size=batch_size, 64 | shuffle=False, 65 | num_workers=num_workers 66 | ) 67 | 68 | self.test_loader_shuffle = torch.utils.data.DataLoader( 69 | self.test_dataset, 70 | shuffle=True, 71 | batch_size=batch_size, 72 | num_workers=num_workers 73 | ) 74 | 75 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -------------------------------------------------------------------------------- /src/datasets/pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class ImageFolderDataset(datasets.ImageFolder): 6 | def __init__(self, root, transform): 7 | super().__init__(root, transform) 8 | 9 | def __getitem__(self, index: int): 10 | path, target = self.samples[index] 11 | sample = self.loader(path) 12 | if self.transform is not None: 13 | sample = self.transform(sample) 14 | if self.target_transform is not None: 15 | target = self.target_transform(target) 16 | return sample, target, index 17 | 18 | class PETS: 19 | def __init__(self, 20 | preprocess, 21 | location=os.path.expanduser('./data'), 22 | batch_size=32, 23 | num_workers=16): 24 | # Data loading code 25 | location = './data' 26 | traindir = os.path.join(location, 'pets', 'train') 27 | valdir = os.path.join(location, 'pets', 'test') 28 | 29 | self.train_dataset = ImageFolderDataset( 30 | traindir, transform=preprocess) 31 | self.train_loader = torch.utils.data.DataLoader( 32 | self.train_dataset, 33 | shuffle=True, 34 | batch_size=batch_size, 35 | num_workers=num_workers, 36 | ) 37 | 38 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess) 39 | self.test_loader = torch.utils.data.DataLoader( 40 | self.test_dataset, 41 | batch_size=batch_size, 42 | num_workers=num_workers 43 | ) 44 | self.test_loader_shuffle = torch.utils.data.DataLoader( 45 | self.test_dataset, 46 | shuffle=True, 47 | batch_size=batch_size, 48 | num_workers=num_workers 49 | ) 50 | idx_to_class = dict((v, k) 51 | for k, v in self.train_dataset.class_to_idx.items()) 52 | self.classnames = [idx_to_class[i].replace( 53 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import random 4 | import torch 5 | import copy 6 | from 
torch.utils.data.dataset import random_split 7 | from src.datasets.cars import Cars 8 | from src.datasets.cifar10 import CIFAR10 9 | from src.datasets.cifar100 import CIFAR100 10 | from src.datasets.dtd import DTD 11 | from src.datasets.eurosat import EuroSAT, EuroSATVal 12 | from src.datasets.gtsrb import GTSRB 13 | from src.datasets.imagenet import ImageNet 14 | from src.datasets.mnist import MNIST 15 | from src.datasets.resisc45 import RESISC45 16 | from src.datasets.stl10 import STL10 17 | from src.datasets.svhn import SVHN 18 | from src.datasets.sun397 import SUN397 19 | from src.datasets.pets import PETS 20 | from src.datasets.flowers import Flowers, FlowersVal 21 | from src.datasets.imagenet100 import ImageNet100 22 | from src.datasets.common import get_dataloader, maybe_dictionarize 23 | registry = { 24 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) 25 | } 26 | 27 | class GenericDataset(object): 28 | def __init__(self): 29 | self.train_dataset = None 30 | self.train_loader = None 31 | self.test_dataset = None 32 | self.test_loader = None 33 | self.classnames = None 34 | 35 | 36 | def split_train_into_train_dev_cifar_mnist(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, test_length, max_val_samples=None, seed=0): 37 | assert val_fraction > 0. and val_fraction < 1. 38 | total_size = len(dataset.train_dataset) 39 | val_size = test_length # shadow train = shadow test 40 | if max_val_samples is not None: 41 | val_size = min(val_size, max_val_samples) 42 | train_size = total_size - val_size 43 | target_train_size = int(train_size/2) 44 | target_test_size = train_size - target_train_size 45 | assert val_size > 0 46 | assert train_size > 0 47 | lengths = [target_train_size, target_test_size, val_size] 48 | print(lengths) 49 | trainset, valset, shadowset = random_split( 50 | dataset.train_dataset, 51 | lengths, 52 | generator=torch.Generator().manual_seed(seed) # same split 53 | ) 54 | 55 | new_dataset = None 56 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) 57 | new_dataset = new_dataset_class() 58 | new_dataset.train_dataset = trainset 59 | new_dataset.train_loader = torch.utils.data.DataLoader( 60 | new_dataset.train_dataset, 61 | shuffle=True, 62 | batch_size=batch_size, 63 | num_workers=num_workers, 64 | ) 65 | new_dataset.test_dataset = valset 66 | new_dataset.test_loader = torch.utils.data.DataLoader( 67 | new_dataset.test_dataset, 68 | batch_size=batch_size, 69 | num_workers=num_workers 70 | ) 71 | new_dataset.test_loader_shuffle = torch.utils.data.DataLoader( 72 | new_dataset.test_dataset, 73 | batch_size=batch_size, 74 | num_workers=num_workers, 75 | shuffle=True 76 | ) 77 | 78 | new_dataset.shadowtrain_dataset = shadowset 79 | new_dataset.shadowtrain_loader = torch.utils.data.DataLoader( 80 | new_dataset.shadowtrain_dataset, 81 | shuffle=True, 82 | batch_size=batch_size, 83 | num_workers=num_workers, 84 | ) 85 | 86 | new_dataset.classnames = copy.copy(dataset.classnames) 87 | return new_dataset 88 | 89 | def get_dataset_classnames(dataset_name, preprocess, location, batch_size=128, num_workers=16): 90 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. 
Supported datasets: {list(registry.keys())}' 91 | dataset_class = registry[dataset_name] 92 | dataset = dataset_class( 93 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 94 | ) 95 | return dataset.classnames 96 | 97 | 98 | def get_dataset_cifar_mnist(dataset_name, split, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.4, max_val_samples=500000): 99 | print(location) 100 | # if dataset_name == 'MNIST': 101 | # val_fraction = 0.5 102 | if split=='train': 103 | if dataset_name=='EuroSAT': 104 | dataset_class = registry[dataset_name+"Val"] 105 | dataset = dataset_class( 106 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 107 | ) 108 | else: 109 | dataset_class = registry[dataset_name] 110 | base_dataset = dataset_class( 111 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 112 | ) 113 | # print("base_dataset: ", len(base_dataset.test_dataset)) 114 | if dataset_name == 'PETS': 115 | len_val = 1400 116 | elif dataset_name == 'STL10': 117 | len_val = 1600 118 | else: 119 | len_val = len(base_dataset.test_dataset) 120 | dataset = split_train_into_train_dev_cifar_mnist( 121 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples) 122 | return dataset.train_dataset, get_dataloader(dataset, split=split) 123 | 124 | elif split=='test': 125 | if dataset_name=='EuroSAT': 126 | dataset_class = registry[dataset_name+"Val"] 127 | dataset = dataset_class( 128 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 129 | ) 130 | else: 131 | dataset_class = registry[dataset_name] 132 | base_dataset = dataset_class( 133 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 134 | ) 135 | if dataset_name == 'PETS': 136 | len_val = 1400 137 | elif dataset_name == 'STL10': 138 | len_val = 1600 139 | else: 140 | len_val = len(base_dataset.test_dataset) 141 | dataset = split_train_into_train_dev_cifar_mnist( 142 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples) 143 | return dataset.test_dataset, get_dataloader(dataset, split=split) 144 | 145 | elif split=='shadowtrain': 146 | if dataset_name=='EuroSAT': 147 | dataset_class = registry[dataset_name+"Val"] 148 | dataset = dataset_class( 149 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 150 | ) 151 | else: 152 | dataset_class = registry[dataset_name] 153 | base_dataset = dataset_class( 154 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 155 | ) 156 | if dataset_name == 'PETS': 157 | len_val = 1400 158 | elif dataset_name == 'STL10': 159 | len_val = 1600 160 | else: 161 | len_val = len(base_dataset.test_dataset) 162 | dataset = split_train_into_train_dev_cifar_mnist( 163 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples) 164 | return dataset.shadowtrain_dataset, get_dataloader(dataset, split=split) 165 | 166 | elif split=='shadowtest' or split=='shadowtest_shuffled': 167 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. 
Supported datasets: {list(registry.keys())}' 168 | dataset_class = registry[dataset_name] 169 | base_dataset = dataset_class( 170 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 171 | ) 172 | test_loader = torch.utils.data.DataLoader( 173 | base_dataset.test_dataset, 174 | batch_size=batch_size, 175 | num_workers=num_workers, 176 | shuffle=True 177 | ) 178 | return base_dataset.test_dataset, test_loader 179 | 180 | 181 | 182 | else: 183 | raise NotImplementedError(f'Unsupported split: {split}') -------------------------------------------------------------------------------- /src/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import abc 5 | import os 6 | from typing import Any, Callable, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | from torchvision.datasets import ImageFolder 13 | from torchvision.datasets.folder import default_loader as pil_loader 14 | 15 | 16 | # modified from: https://github.com/microsoft/torchgeo 17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): 18 | """Abstract base class for datasets lacking geospatial information. 19 | This base class is designed for datasets with pre-defined image chips. 20 | """ 21 | 22 | @abc.abstractmethod 23 | def __getitem__(self, index: int) -> Dict[str, Any]: 24 | """Return an index within the dataset. 25 | Args: 26 | index: index to return 27 | Returns: 28 | data and labels at that index 29 | Raises: 30 | IndexError: if index is out of range of the dataset 31 | """ 32 | 33 | @abc.abstractmethod 34 | def __len__(self) -> int: 35 | """Return the length of the dataset. 36 | Returns: 37 | length of the dataset 38 | """ 39 | 40 | def __str__(self) -> str: 41 | """Return the informal string representation of the object. 42 | Returns: 43 | informal string representation 44 | """ 45 | return f"""\ 46 | {self.__class__.__name__} Dataset 47 | type: VisionDataset 48 | size: {len(self)}""" 49 | 50 | 51 | class VisionClassificationDataset(VisionDataset, ImageFolder): 52 | """Abstract base class for classification datasets lacking geospatial information. 53 | This base class is designed for datasets with pre-defined image chips which 54 | are separated into separate folders per class. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | root: str, 60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 61 | loader: Optional[Callable[[str], Any]] = pil_loader, 62 | is_valid_file: Optional[Callable[[str], bool]] = None, 63 | ) -> None: 64 | """Initialize a new VisionClassificationDataset instance.
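Note: the wrapped ImageFolder is built with transform=None and target_transform=None, so preprocessing happens only through the transforms callable below, and __getitem__ returns (image, label, index) rather than the usual (image, label) pair.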
65 | Args: 66 | root: root directory where dataset can be found 67 | transforms: a function/transform that takes input sample and its target as 68 | entry and returns a transformed version 69 | loader: a callable function which takes as input a path to an image and 70 | returns a PIL Image or numpy array 71 | is_valid_file: A function that takes the path of an Image file and checks if 72 | the file is a valid file 73 | """ 74 | # When transform & target_transform are None, ImageFolder.__getitem__(index) 75 | # returns a PIL.Image and int for image and label, respectively 76 | super().__init__( 77 | root=root, 78 | transform=None, 79 | target_transform=None, 80 | loader=loader, 81 | is_valid_file=is_valid_file, 82 | ) 83 | 84 | # Must be set after calling super().__init__() 85 | self.transforms = transforms 86 | 87 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 88 | """Return an index within the dataset. 89 | Args: 90 | index: index to return 91 | Returns: 92 | data and label at that index 93 | """ 94 | image, label = self._load_image(index) 95 | 96 | if self.transforms is not None: 97 | return self.transforms(image), label, index 98 | 99 | return image, label, index 100 | 101 | def __len__(self) -> int: 102 | """Return the number of data points in the dataset. 103 | Returns: 104 | length of the dataset 105 | """ 106 | return len(self.imgs) 107 | 108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: 109 | """Load a single image and it's class label. 110 | Args: 111 | index: index to return 112 | Returns: 113 | the image 114 | the image class label 115 | """ 116 | img, label = ImageFolder.__getitem__(self, index) 117 | label = torch.tensor(label) 118 | return img, label 119 | 120 | 121 | class RESISC45Dataset(VisionClassificationDataset): 122 | """RESISC45 dataset. 123 | The `RESISC45 `_ 124 | dataset is a dataset for remote sensing image scene classification. 125 | Dataset features: 126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) 127 | * three spectral bands - RGB 128 | * 45 scene classes, 700 images per class 129 | * images extracted from Google Earth from over 100 countries 130 | * images conditions with high variability (resolution, weather, illumination) 131 | Dataset format: 132 | * images are three-channel jpgs 133 | Dataset classes: 134 | 0. airplane 135 | 1. airport 136 | 2. baseball_diamond 137 | 3. basketball_court 138 | 4. beach 139 | 5. bridge 140 | 6. chaparral 141 | 7. church 142 | 8. circular_farmland 143 | 9. cloud 144 | 10. commercial_area 145 | 11. dense_residential 146 | 12. desert 147 | 13. forest 148 | 14. freeway 149 | 15. golf_course 150 | 16. ground_track_field 151 | 17. harbor 152 | 18. industrial_area 153 | 19. intersection 154 | 20. island 155 | 21. lake 156 | 22. meadow 157 | 23. medium_residential 158 | 24. mobile_home_park 159 | 25. mountain 160 | 26. overpass 161 | 27. palace 162 | 28. parking_lot 163 | 29. railway 164 | 30. railway_station 165 | 31. rectangular_farmland 166 | 32. river 167 | 33. roundabout 168 | 34. runway 169 | 35. sea_ice 170 | 36. ship 171 | 37. snowberg 172 | 38. sparse_residential 173 | 39. stadium 174 | 40. storage_tank 175 | 41. tennis_court 176 | 42. terrace 177 | 43. thermal_power_station 178 | 44. 
wetland 179 | This dataset uses the train/val/test splits defined in the "In-domain representation 180 | learning for remote sensing" paper: 181 | * https://arxiv.org/abs/1911.06721 182 | If you use this dataset in your research, please cite the following paper: 183 | * https://doi.org/10.1109/jproc.2017.2675998 184 | """ 185 | 186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" 187 | # md5 = "d824acb73957502b00efd559fc6cfbbb" 188 | # filename = "NWPU-RESISC45.rar" 189 | directory = "resisc45/NWPU-RESISC45" 190 | 191 | splits = ["train", "val", "test"] 192 | split_urls = { 193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 196 | } 197 | split_md5s = { 198 | "train": "b5a4c05a37de15e4ca886696a85c403e", 199 | "val": "a0770cee4c5ca20b8c32bbd61e114805", 200 | "test": "3dda9e4988b47eb1de9f07993653eb08", 201 | } 202 | classes = [ 203 | "airplane", 204 | "airport", 205 | "baseball_diamond", 206 | "basketball_court", 207 | "beach", 208 | "bridge", 209 | "chaparral", 210 | "church", 211 | "circular_farmland", 212 | "cloud", 213 | "commercial_area", 214 | "dense_residential", 215 | "desert", 216 | "forest", 217 | "freeway", 218 | "golf_course", 219 | "ground_track_field", 220 | "harbor", 221 | "industrial_area", 222 | "intersection", 223 | "island", 224 | "lake", 225 | "meadow", 226 | "medium_residential", 227 | "mobile_home_park", 228 | "mountain", 229 | "overpass", 230 | "palace", 231 | "parking_lot", 232 | "railway", 233 | "railway_station", 234 | "rectangular_farmland", 235 | "river", 236 | "roundabout", 237 | "runway", 238 | "sea_ice", 239 | "ship", 240 | "snowberg", 241 | "sparse_residential", 242 | "stadium", 243 | "storage_tank", 244 | "tennis_court", 245 | "terrace", 246 | "thermal_power_station", 247 | "wetland", 248 | ] 249 | 250 | def __init__( 251 | self, 252 | root: str = "data", 253 | split: str = "train", 254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 255 | ) -> None: 256 | """Initialize a new RESISC45 dataset instance. 
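Note: the split lists (resisc45-train/val/test.txt) are always read from the hardcoded './data/resisc45' directory (self.root is overridden below), so the root argument only controls where the NWPU-RESISC45 image folders are resolved.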
257 | Args: 258 | root: root directory where dataset can be found 259 | split: one of "train", "val", or "test" 260 | transforms: a function/transform that takes input sample and its target as 261 | entry and returns a transformed version 262 | """ 263 | assert split in self.splits 264 | self.root = "./data" 265 | 266 | valid_fns = set() 267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: 268 | for fn in f: 269 | valid_fns.add(fn.strip()) 270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename( 271 | x) in valid_fns 272 | 273 | super().__init__( 274 | root=os.path.join(root, self.directory), 275 | transforms=transforms, 276 | is_valid_file=is_in_split, 277 | ) 278 | 279 | 280 | 281 | class RESISC45: 282 | def __init__(self, 283 | preprocess, 284 | location=os.path.expanduser('./data'), 285 | batch_size=32, 286 | num_workers=16): 287 | 288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) 289 | self.train_loader = torch.utils.data.DataLoader( 290 | self.train_dataset, 291 | shuffle=True, 292 | batch_size=batch_size, 293 | num_workers=num_workers, 294 | ) 295 | 296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) 297 | self.test_loader = torch.utils.data.DataLoader( 298 | self.test_dataset, 299 | batch_size=batch_size, 300 | num_workers=num_workers 301 | ) 302 | self.test_loader_shuffle = torch.utils.data.DataLoader( 303 | self.test_dataset, 304 | shuffle=True, 305 | batch_size=batch_size, 306 | num_workers=num_workers 307 | ) 308 | 309 | # class names have _ so split on this for better zero-shot head 310 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] 311 | -------------------------------------------------------------------------------- /src/datasets/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from PIL import Image 5 | import numpy as np 6 | 7 | class MySTL10(datasets.STL10): 8 | def __init__(self, root, download, split, transform): 9 | super().__init__(root=root, download=download, split=split, transform=transform) 10 | 11 | def __getitem__(self, index: int): 12 | """ 13 | Args: 14 | index (int): Index 15 | 16 | Returns: 17 | tuple: (image, target) where target is index of the target class. 
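Note: for the unlabeled split the target is None, and as with the other dataset wrappers in this repo the returned tuple is (image, target, index). The bare Optional annotation below is never evaluated at runtime, but a 'from typing import Optional' import would be needed to satisfy static type checkers.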
18 | """ 19 | target: Optional[int] 20 | if self.labels is not None: 21 | img, target = self.data[index], int(self.labels[index]) 22 | else: 23 | img, target = self.data[index], None 24 | 25 | # doing this so that it is consistent with all other datasets 26 | # to return a PIL Image 27 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 28 | 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | 32 | if self.target_transform is not None: 33 | target = self.target_transform(target) 34 | 35 | return img, target, index 36 | 37 | class STL10: 38 | def __init__(self, 39 | preprocess, 40 | location=os.path.expanduser('./data'), 41 | batch_size=128, 42 | num_workers=16): 43 | 44 | location = os.path.join(location, 'stl10') 45 | self.train_dataset = MySTL10( 46 | root=location, 47 | download=True, 48 | split='train', 49 | transform=preprocess 50 | ) 51 | 52 | self.train_loader = torch.utils.data.DataLoader( 53 | self.train_dataset, 54 | batch_size=batch_size, 55 | shuffle=True, 56 | num_workers=num_workers 57 | ) 58 | 59 | self.test_dataset = MySTL10( 60 | root=location, 61 | download=True, 62 | split='test', 63 | transform=preprocess 64 | ) 65 | 66 | self.test_loader = torch.utils.data.DataLoader( 67 | self.test_dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers 71 | ) 72 | 73 | self.test_loader_shuffle = torch.utils.data.DataLoader( 74 | self.test_dataset, 75 | shuffle=True, 76 | batch_size=batch_size, 77 | num_workers=num_workers 78 | ) 79 | 80 | self.classnames = self.train_dataset.classes -------------------------------------------------------------------------------- /src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class ImageFolderDataset(datasets.ImageFolder): 6 | def __init__(self, root, transform): 7 | super().__init__(root, transform) 8 | 9 | def __getitem__(self, index: int): 10 | path, target = self.samples[index] 11 | sample = self.loader(path) 12 | if self.transform is not None: 13 | sample = self.transform(sample) 14 | if self.target_transform is not None: 15 | target = self.target_transform(target) 16 | return sample, target, index 17 | 18 | class SUN397: 19 | def __init__(self, 20 | preprocess, 21 | location=os.path.expanduser('./data'), 22 | batch_size=32, 23 | num_workers=16): 24 | # Data loading code 25 | traindir = os.path.join(location, 'sun397', 'train') 26 | valdir = os.path.join(location, 'sun397', 'test') 27 | 28 | 29 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess) 30 | self.train_loader = torch.utils.data.DataLoader( 31 | self.train_dataset, 32 | shuffle=True, 33 | batch_size=batch_size, 34 | num_workers=num_workers, 35 | ) 36 | 37 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess) 38 | self.test_loader = torch.utils.data.DataLoader( 39 | self.test_dataset, 40 | batch_size=batch_size, 41 | num_workers=num_workers 42 | ) 43 | self.test_loader_shuffle = torch.utils.data.DataLoader( 44 | self.test_dataset, 45 | shuffle=True, 46 | batch_size=batch_size, 47 | num_workers=num_workers 48 | ) 49 | idx_to_class = dict((v, k) 50 | for k, v in self.train_dataset.class_to_idx.items()) 51 | self.classnames = [idx_to_class[i][3:].replace('_', ' ') for i in range(len(idx_to_class))] 52 | # print(self.classnames) 53 | print(len(self.classnames)) -------------------------------------------------------------------------------- /src/datasets/svhn.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | from PIL import Image 6 | 7 | class MyPyTorchSVHN(PyTorchSVHN): 8 | def __init__(self, root, download, split, transform): 9 | super().__init__(root=root, download=download, split=split, transform=transform) 10 | 11 | def __getitem__(self, index: int): 12 | """ 13 | Args: 14 | index (int): Index 15 | 16 | Returns: 17 | tuple: (image, target) where target is index of the target class. 18 | """ 19 | img, target = self.data[index], int(self.labels[index]) 20 | 21 | # doing this so that it is consistent with all other datasets 22 | # to return a PIL Image 23 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 24 | 25 | if self.transform is not None: 26 | img = self.transform(img) 27 | 28 | if self.target_transform is not None: 29 | target = self.target_transform(target) 30 | 31 | return img, target, index 32 | 33 | class SVHN: 34 | def __init__(self, 35 | preprocess, 36 | location=os.path.expanduser('./data'), 37 | batch_size=128, 38 | num_workers=16): 39 | 40 | # to fit with repo conventions for location 41 | modified_location = os.path.join(location, 'svhn') 42 | 43 | self.train_dataset = MyPyTorchSVHN( 44 | root=modified_location, 45 | download=True, 46 | split='train', 47 | transform=preprocess 48 | ) 49 | 50 | self.train_loader = torch.utils.data.DataLoader( 51 | self.train_dataset, 52 | batch_size=batch_size, 53 | shuffle=True, 54 | num_workers=num_workers 55 | ) 56 | 57 | self.test_dataset = MyPyTorchSVHN( 58 | root=modified_location, 59 | download=True, 60 | split='test', 61 | transform=preprocess 62 | ) 63 | 64 | self.test_loader = torch.utils.data.DataLoader( 65 | self.test_dataset, 66 | batch_size=batch_size, 67 | shuffle=False, 68 | num_workers=num_workers 69 | ) 70 | 71 | self.test_loader_shuffle = torch.utils.data.DataLoader( 72 | self.test_dataset, 73 | shuffle=True, 74 | batch_size=batch_size, 75 | num_workers=num_workers 76 | ) 77 | 78 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 79 | -------------------------------------------------------------------------------- /src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f'a photo of a {c}.', 3 | lambda c: f'a photo of the {c}.', 4 | lambda c: f'a photo of my {c}.', 5 | lambda c: f'i love my {c}!', 6 | lambda c: f'a photo of my dirty {c}.', 7 | lambda c: f'a photo of my clean {c}.', 8 | lambda c: f'a photo of my new {c}.', 9 | lambda c: f'a photo of my old {c}.', 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f'a photo of a {c}.', 14 | lambda c: f'a blurry photo of a {c}.', 15 | lambda c: f'a black and white photo of a {c}.', 16 | lambda c: f'a low contrast photo of a {c}.', 17 | lambda c: f'a high contrast photo of a {c}.', 18 | lambda c: f'a bad photo of a {c}.', 19 | lambda c: f'a good photo of a {c}.', 20 | lambda c: f'a photo of a small {c}.', 21 | lambda c: f'a photo of a big {c}.', 22 | lambda c: f'a photo of the {c}.', 23 | lambda c: f'a blurry photo of the {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a low contrast photo of the {c}.', 26 | lambda c: f'a high contrast photo of the {c}.', 27 | lambda c: f'a bad photo of the {c}.', 28 | lambda c: f'a good photo of the {c}.', 29 | lambda c: f'a photo of the small {c}.', 30 | lambda c: f'a photo of the big {c}.', 31 | ] 32 | 33 | 
cifar100_template = [ 34 | lambda c: f'a photo of a {c}.', 35 | lambda c: f'a blurry photo of a {c}.', 36 | lambda c: f'a black and white photo of a {c}.', 37 | lambda c: f'a low contrast photo of a {c}.', 38 | lambda c: f'a high contrast photo of a {c}.', 39 | lambda c: f'a bad photo of a {c}.', 40 | lambda c: f'a good photo of a {c}.', 41 | lambda c: f'a photo of a small {c}.', 42 | lambda c: f'a photo of a big {c}.', 43 | lambda c: f'a photo of the {c}.', 44 | lambda c: f'a blurry photo of the {c}.', 45 | lambda c: f'a black and white photo of the {c}.', 46 | lambda c: f'a low contrast photo of the {c}.', 47 | lambda c: f'a high contrast photo of the {c}.', 48 | lambda c: f'a bad photo of the {c}.', 49 | lambda c: f'a good photo of the {c}.', 50 | lambda c: f'a photo of the small {c}.', 51 | lambda c: f'a photo of the big {c}.', 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f'a photo of a {c} texture.', 56 | lambda c: f'a photo of a {c} pattern.', 57 | lambda c: f'a photo of a {c} thing.', 58 | lambda c: f'a photo of a {c} object.', 59 | lambda c: f'a photo of the {c} texture.', 60 | lambda c: f'a photo of the {c} pattern.', 61 | lambda c: f'a photo of the {c} thing.', 62 | lambda c: f'a photo of the {c} object.', 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f'a centered satellite photo of {c}.', 67 | lambda c: f'a centered satellite photo of a {c}.', 68 | lambda c: f'a centered satellite photo of the {c}.', 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f'a photo of {c}, a type of food.', 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f'a bad photo of a {c}.', 87 | lambda c: f'a photo of many {c}.', 88 | lambda c: f'a sculpture of a {c}.', 89 | lambda c: f'a photo of the hard to see {c}.', 90 | lambda c: f'a low resolution photo of the {c}.', 91 | lambda c: f'a rendering of a {c}.', 92 | lambda c: f'graffiti of a {c}.', 93 | lambda c: f'a bad photo of the {c}.', 94 | lambda c: f'a cropped photo of the {c}.', 95 | lambda c: f'a tattoo of a {c}.', 96 | lambda c: f'the embroidered {c}.', 97 | lambda c: f'a photo of a hard to see {c}.', 98 | lambda c: f'a bright photo of a {c}.', 99 | lambda c: f'a photo of a clean {c}.', 100 | lambda c: f'a photo of a dirty {c}.', 101 | lambda c: f'a dark photo of the {c}.', 102 | lambda c: f'a drawing of a {c}.', 103 | lambda c: f'a photo of my {c}.', 104 | lambda c: f'the plastic {c}.', 105 | lambda c: f'a photo of the cool {c}.', 106 | lambda c: f'a close-up photo of a {c}.', 107 | lambda c: f'a black and white photo of the {c}.', 108 | lambda c: f'a painting of the {c}.', 109 | lambda c: f'a painting of a {c}.', 110 | lambda c: f'a pixelated photo of the {c}.', 111 | lambda c: f'a sculpture of the {c}.', 112 | lambda c: f'a bright photo of the {c}.', 113 | lambda c: f'a cropped photo of a {c}.', 114 | lambda c: f'a plastic {c}.', 115 | lambda c: f'a photo of the dirty {c}.', 116 | lambda c: f'a jpeg corrupted photo of a {c}.', 117 | lambda c: f'a blurry photo of the {c}.', 118 | lambda c: f'a photo of the {c}.', 119 | lambda c: f'a good photo of the {c}.', 120 | lambda c: f'a rendering of the {c}.', 121 | lambda c: f'a {c} in a video game.', 122 | lambda c: f'a photo of one {c}.', 123 | lambda c: f'a doodle of a {c}.', 124 | 
lambda c: f'a close-up photo of the {c}.', 125 | lambda c: f'a photo of a {c}.', 126 | lambda c: f'the origami {c}.', 127 | lambda c: f'the {c} in a video game.', 128 | lambda c: f'a sketch of a {c}.', 129 | lambda c: f'a doodle of the {c}.', 130 | lambda c: f'a origami {c}.', 131 | lambda c: f'a low resolution photo of a {c}.', 132 | lambda c: f'the toy {c}.', 133 | lambda c: f'a rendition of the {c}.', 134 | lambda c: f'a photo of the clean {c}.', 135 | lambda c: f'a photo of a large {c}.', 136 | lambda c: f'a rendition of a {c}.', 137 | lambda c: f'a photo of a nice {c}.', 138 | lambda c: f'a photo of a weird {c}.', 139 | lambda c: f'a blurry photo of a {c}.', 140 | lambda c: f'a cartoon {c}.', 141 | lambda c: f'art of a {c}.', 142 | lambda c: f'a sketch of the {c}.', 143 | lambda c: f'a embroidered {c}.', 144 | lambda c: f'a pixelated photo of a {c}.', 145 | lambda c: f'itap of the {c}.', 146 | lambda c: f'a jpeg corrupted photo of the {c}.', 147 | lambda c: f'a good photo of a {c}.', 148 | lambda c: f'a plushie {c}.', 149 | lambda c: f'a photo of the nice {c}.', 150 | lambda c: f'a photo of the small {c}.', 151 | lambda c: f'a photo of the weird {c}.', 152 | lambda c: f'the cartoon {c}.', 153 | lambda c: f'art of the {c}.', 154 | lambda c: f'a drawing of the {c}.', 155 | lambda c: f'a photo of the large {c}.', 156 | lambda c: f'a black and white photo of a {c}.', 157 | lambda c: f'the plushie {c}.', 158 | lambda c: f'a dark photo of a {c}.', 159 | lambda c: f'itap of a {c}.', 160 | lambda c: f'graffiti of the {c}.', 161 | lambda c: f'a toy {c}.', 162 | lambda c: f'itap of my {c}.', 163 | lambda c: f'a photo of a cool {c}.', 164 | lambda c: f'a photo of a small {c}.', 165 | lambda c: f'a tattoo of the {c}.', 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f'satellite imagery of {c}.', 170 | lambda c: f'aerial imagery of {c}.', 171 | lambda c: f'satellite photo of {c}.', 172 | lambda c: f'aerial photo of {c}.', 173 | lambda c: f'satellite view of {c}.', 174 | lambda c: f'aerial view of {c}.', 175 | lambda c: f'satellite imagery of a {c}.', 176 | lambda c: f'aerial imagery of a {c}.', 177 | lambda c: f'satellite photo of a {c}.', 178 | lambda c: f'aerial photo of a {c}.', 179 | lambda c: f'satellite view of a {c}.', 180 | lambda c: f'aerial view of a {c}.', 181 | lambda c: f'satellite imagery of the {c}.', 182 | lambda c: f'aerial imagery of the {c}.', 183 | lambda c: f'satellite photo of the {c}.', 184 | lambda c: f'aerial photo of the {c}.', 185 | lambda c: f'satellite view of the {c}.', 186 | lambda c: f'aerial view of the {c}.', 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f'a photo of a {c}.', 191 | lambda c: f'a photo of the {c}.', 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f'a photo of a {c}.', 196 | lambda c: f'a photo of the {c}.', 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | pets_template = [ 204 | lambda c: f'a photo of a {c}, a type of pet.' 
205 | ] 206 | 207 | caltech101_template = [ 208 | lambda c: f'a photo of a {c}.', 209 | lambda c: f'a painting of a {c}.', 210 | lambda c: f'a plastic {c}.', 211 | lambda c: f'a sculpture of a {c}.', 212 | lambda c: f'a sketch of a {c}.', 213 | lambda c: f'a tattoo of a {c}.', 214 | lambda c: f'a toy {c}.', 215 | lambda c: f'a rendition of a {c}.', 216 | lambda c: f'a embroidered {c}.', 217 | lambda c: f'a cartoon {c}.', 218 | lambda c: f'a {c} in a video game.', 219 | lambda c: f'a plushie {c}.', 220 | lambda c: f'a origami {c}.', 221 | lambda c: f'art of a {c}.', 222 | lambda c: f'graffiti of a {c}.', 223 | lambda c: f'a drawing of a {c}.', 224 | lambda c: f'a doodle of a {c}.', 225 | lambda c: f'a photo of the {c}.', 226 | lambda c: f'a painting of the {c}.', 227 | lambda c: f'the plastic {c}.', 228 | lambda c: f'a sculpture of the {c}.', 229 | lambda c: f'a sketch of the {c}.', 230 | lambda c: f'a tattoo of the {c}.', 231 | lambda c: f'the toy {c}.', 232 | lambda c: f'a rendition of the {c}.', 233 | lambda c: f'the embroidered {c}.', 234 | lambda c: f'the cartoon {c}.', 235 | lambda c: f'the {c} in a video game.', 236 | lambda c: f'the plushie {c}.', 237 | lambda c: f'the origami {c}.', 238 | lambda c: f'art of the {c}.', 239 | lambda c: f'graffiti of the {c}.', 240 | lambda c: f'a drawing of the {c}.', 241 | lambda c: f'a doodle of the {c}.', 242 | ] 243 | 244 | flower_templates = [ 245 | lambda c: f'a photo of a {c}, a type of flower.', 246 | ] 247 | 248 | cub_templates = [ 249 | lambda c: f'a photo of a {c}, a type of bird.', 250 | ] 251 | 252 | fashion_mnist_template = [ 253 | lambda c: f'a photo of a {c}.', 254 | lambda c: f'a blurry photo of a {c}.', 255 | lambda c: f'a black and white photo of a {c}.', 256 | lambda c: f'a thumbnail of a {c}.', 257 | lambda c: f'a photo of the {c}.', 258 | lambda c: f'a blurry photo of the {c}.', 259 | lambda c: f'a black and white photo of the {c}.', 260 | lambda c: f'a thumbnail of the {c}.', 261 | ] 262 | 263 | dataset_to_template = { 264 | 'Cars': cars_template, 265 | 'CIFAR10': cifar10_template, 266 | 'CIFAR100': cifar100_template, 267 | 'DTD': dtd_template, 268 | 'EuroSAT': eurosat_template, 269 | 'Food101': food101_template, 270 | 'GTSRB': gtsrb_template, 271 | 'MNIST': mnist_template, 272 | 'ImageNet': imagenet_template, 273 | 'RESISC45': resisc45_template, 274 | 'STL10': stl10_template, 275 | 'SUN397': sun397_template, 276 | 'SVHN': svhn_template, 277 | 'PETS': pets_template, 278 | 'Caltech101': caltech101_template, 279 | 'Flowers': flower_templates, 280 | 'TIN': imagenet_template, 281 | 'ImageNet100': imagenet_template, 282 | 'ImageNet': imagenet_template, 283 | 'CUB200': cub_templates, 284 | 'FashionMNIST': fashion_mnist_template 285 | } 286 | 287 | 288 | def get_templates(dataset_name): 289 | if dataset_name.endswith('Val'): 290 | return get_templates(dataset_name.replace('Val', '')) 291 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' 292 | return dataset_to_template[dataset_name] -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | import torch 5 | import numpy as np 6 | import utils 7 | from src.datasets.common import get_dataloader, maybe_dictionarize 8 | from src.datasets.templates import get_templates 9 | from heads import get_classification_head, build_classification_head 10 | from modeling import ImageClassifier, 
ImageEncoder, ClassificationHead 11 | from src.datasets.registry import get_dataset_cifar_mnist 12 | import torchvision.utils as vutils 13 | from src.utils import * 14 | 15 | def eval_single_dataset(image_encoder, dataset_name, args, backdoor_info=None): 16 | print("") 17 | # build a classifier from the encoder and the dataset-specific head 18 | classification_head = get_classification_head(args, dataset_name) 19 | model = ImageClassifier(image_encoder, classification_head) 20 | model.eval() 21 | 22 | # evaluate on the held-out shadow-test split 23 | test_dataset, test_loader = get_dataset_cifar_mnist( 24 | dataset_name, 25 | 'shadowtest', 26 | model.val_preprocess, 27 | location=args.data_location, 28 | batch_size=args.batch_size 29 | ) 30 | normalizer = model.val_preprocess.transforms[-1] 31 | inv_normalizer = NormalizeInverse(normalizer.mean, normalizer.std) 32 | print("Evaluation Size:", len(test_dataset)) 33 | 34 | device = args.device 35 | 36 | with torch.no_grad(): 37 | top1, correct, n = 0., 0., 0. 38 | for i, data in enumerate(tqdm.tqdm(test_loader)): 39 | data = maybe_dictionarize(data) 40 | x = data['images'] 41 | y = data['labels'] 42 | 43 | x = x.cuda() 44 | y = y.cuda() 45 | logits = utils.get_logits(x, model) 46 | pred = logits.argmax(dim=1, keepdim=True).to(device) 47 | correct += pred.eq(y.view_as(pred)).sum().item() 48 | n += y.size(0) 49 | 50 | top1 = correct / n 51 | 52 | metrics = {'top1': top1} 53 | print(f'Accuracy: {100*top1:.2f}%') 54 | 55 | return metrics 56 | 57 | def eval_single_dataset_head(image_encoder, head, dataset_name, args): 58 | model = ImageClassifier(image_encoder, head) 59 | model.eval() 60 | test_dataset, test_loader = get_dataset_cifar_mnist(dataset_name, 'test', model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 61 | device = args.device 62 | 63 | with torch.no_grad(): 64 | top1, correct, n = 0., 0., 0. 65 | for i, data in enumerate(tqdm.tqdm(test_loader)): 66 | data = maybe_dictionarize(data) 67 | x = data['images'].to(device) 68 | y = data['labels'].to(device) 69 | logits = utils.get_logits(x, model) 70 | pred = logits.argmax(dim=1, keepdim=True).to(device) 71 | correct += pred.eq(y.view_as(pred)).sum().item() 72 | n += y.size(0) 73 | top1 = correct / n 74 | 75 | metrics = {'top1': top1} 76 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%') 77 | return metrics 78 | 79 | def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args): 80 | model = ImageClassifier(image_encoder, head) 81 | model.eval() 82 | test_dataset, test_loader = get_dataset_cifar_mnist(dataset_name, 'test', model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 83 | device = args.device 84 | 85 | with torch.no_grad(): 86 | top1, correct, n = 0., 0., 0. 87 | for i, data in enumerate(tqdm.tqdm(test_loader)): 88 | data = maybe_dictionarize(data) 89 | x = data['images'].to(device) 90 | y = data['labels'].to(device) 91 | logits = utils.get_logits(x, model) 92 | pred = logits.argmax(dim=1, keepdim=True).to(device) 93 | correct += pred.eq(y.view_as(pred)).sum().item() 94 | n += y.size(0) 95 | top1 = correct / n 96 | metrics = {'top1': top1} 97 | print(f'Done evaluating on {dataset_name}. 
Accuracy: {100 * top1:.2f}%') 98 | return metrics 99 | 100 | def evaluate(image_encoder, args, backdoor_info=None): 101 | if args.eval_datasets is None: 102 | return 103 | info = vars(args) 104 | for i, dataset_name in enumerate(args.eval_datasets): 105 | print('Evaluating on', dataset_name) 106 | 107 | results = eval_single_dataset(image_encoder, dataset_name, args, backdoor_info) 108 | 109 | for key, val in results.items(): 110 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key: 111 | print(f"{dataset_name} {key}: {val:.4f}") 112 | if backdoor_info is not None: 113 | info[dataset_name + '-B:' + key] = val # trigger 114 | else: 115 | info[dataset_name + ':' + key] = val # clean 116 | return info -------------------------------------------------------------------------------- /src/figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/comparison.png -------------------------------------------------------------------------------- /src/figures/exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/exp.png -------------------------------------------------------------------------------- /src/figures/figures.txt: -------------------------------------------------------------------------------- 1 | Some figures in the paper are saved here. 2 | -------------------------------------------------------------------------------- /src/figures/main_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/main_table.png -------------------------------------------------------------------------------- /src/figures/neulig_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/neulig_overview.png -------------------------------------------------------------------------------- /src/figures/neulig_train_pip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/neulig_train_pip.png -------------------------------------------------------------------------------- /src/finetune_clean.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | sys.path.append(os.path.abspath('.')) 5 | import torch 6 | from src.args import parse_arguments 7 | from src.datasets.common import get_dataloader, maybe_dictionarize 8 | from src.datasets.registry import get_dataset_cifar_mnist 9 | from src.eval import evaluate 10 | from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier 11 | from src.utils import cosine_lr, LabelSmoothing 12 | from src.heads import get_classification_head 13 | import src.datasets as datasets 14 | from torch.utils.data import Subset 15 | import pickle 16 | import numpy as np 17 | import copy 18 | import tqdm 19 | 20 | def save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, save_dir): 21 | os.makedirs(save_dir, exist_ok=True) 22 | with 
open(os.path.join(save_dir, 'train_dataset.pkl'), 'wb') as f: 23 | pickle.dump(train_dataset, f) 24 | 25 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'wb') as f: 26 | pickle.dump(test_dataset, f) 27 | 28 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'wb') as f: 29 | pickle.dump(shadowtrain_dataset, f) 30 | 31 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'wb') as f: 32 | pickle.dump(shadowtest_dataset, f) 33 | 34 | def load_datasets(save_dir): 35 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'rb') as f: 36 | train_dataset = pickle.load(f) 37 | 38 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'rb') as f: 39 | test_dataset = pickle.load(f) 40 | 41 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'rb') as f: 42 | shadowtrain_dataset = pickle.load(f) 43 | 44 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'rb') as f: 45 | shadowtest_dataset = pickle.load(f) 46 | 47 | 48 | return train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset 49 | def check_datasets_exist(save_dir): 50 | return (os.path.exists(os.path.join(save_dir, 'train_dataset.pkl')) and 51 | os.path.exists(os.path.join(save_dir, 'test_dataset.pkl')) and 52 | os.path.exists(os.path.join(save_dir, 'shadowtrain_dataset.pkl')) and 53 | os.path.exists(os.path.join(save_dir, 'shadowtest_dataset.pkl'))) 54 | 55 | def load_dataset_splits(save_dir): 56 | with open(os.path.join(save_dir, 'train_indices.pkl'), 'rb') as f: 57 | train_indices = pickle.load(f) 58 | with open(os.path.join(save_dir, 'test_indices.pkl'), 'rb') as f: 59 | test_indices = pickle.load(f) 60 | with open(os.path.join(save_dir, 'shadowtrain_indices.pkl'), 'rb') as f: 61 | shadowtrain_indices = pickle.load(f) 62 | with open(os.path.join(save_dir, 'shadowtest_indices.pkl'), 'rb') as f: 63 | shadowtest_indices = pickle.load(f) 64 | 65 | return train_indices, test_indices, shadowtrain_indices, shadowtest_indices 66 | 67 | def finetune(model, args): 68 | dataset = args.dataset 69 | preprocess_fn = model.train_preprocess 70 | 71 | print_every = 100 72 | dataset_save_dir = os.path.join("{}/{}/dataset_splits".format(args.save, args.dataset)) 73 | 74 | if check_datasets_exist(dataset_save_dir): 75 | print("Subsets already exist...") 76 | from torch.utils.data import DataLoader 77 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir) 78 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 79 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True) 80 | else: 81 | print("Subsets do not exist...") 82 | train_dataset, train_loader = get_dataset_cifar_mnist( 83 | dataset, 84 | 'train', 85 | preprocess_fn, 86 | location=args.data_location, 87 | batch_size=args.batch_size 88 | ) 89 | test_dataset, test_loader = get_dataset_cifar_mnist( 90 | dataset, 91 | 'test', 92 | preprocess_fn, 93 | location=args.data_location, 94 | batch_size=args.batch_size 95 | ) 96 | shadowtrain_dataset, shadowtrain_loader = get_dataset_cifar_mnist( 97 | dataset, 98 | 'shadowtrain', 99 | preprocess_fn, 100 | location=args.data_location, 101 | batch_size=args.batch_size 102 | ) 103 | shadowtest_dataset, shadowtest_loader = get_dataset_cifar_mnist( 104 | dataset, 105 | 'shadowtest', 106 | preprocess_fn, 107 | location=args.data_location, 108 | batch_size=args.batch_size 109 | ) 110 | save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, dataset_save_dir)
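# NOTE: `cosine_lr` is imported from src.utils (not shown in this file). Judging from how it
# is used below -- scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches),
# then scheduler(step) once per iteration -- a minimal sketch consistent with that interface
# could look like the following (an illustrative assumption, not the actual src/utils.py code):
#
#   import math
#
#   def cosine_lr(optimizer, base_lr, warmup_length, total_steps):
#       def _adjust(step):
#           if step < warmup_length:
#               # linear warmup from ~0 up to base_lr
#               lr = base_lr * (step + 1) / warmup_length
#           else:
#               # cosine decay from base_lr down to 0 over the remaining steps
#               progress = (step - warmup_length) / max(1, total_steps - warmup_length)
#               lr = 0.5 * (1 + math.cos(math.pi * progress)) * base_lr
#           for param_group in optimizer.param_groups:
#               param_group['lr'] = lr
#           return lr
#       return _adjust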
111 | 112 | num_batches = len(train_loader) 113 | print("train_length: {}, val_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset))) 114 | # save pre-trained model 115 | 116 | # dataset_dir = dataset + '_1epoch' 117 | # ckpdir = os.path.join(args.save, dataset_dir) 118 | 119 | ckpdir = os.path.join(args.save, dataset) 120 | 121 | if args.save is not None: 122 | os.makedirs(ckpdir, exist_ok=True) 123 | model_path = os.path.join(args.save, f'zeroshot.pt') 124 | if not os.path.exists(model_path): 125 | model.image_encoder.save(model_path) 126 | # evaluate pre-trained model 127 | print("Initial evaluation:") 128 | image_encoder = model.image_encoder 129 | args.eval_datasets = [dataset] 130 | evaluate(image_encoder, args) 131 | 132 | # test_loaders = [test_loader] 133 | # evaluate_single(model, test_loaders, args.device) 134 | 135 | # train model for target train set 136 | loss_fn = torch.nn.CrossEntropyLoss() 137 | params = [p for p in model.parameters() if p.requires_grad] 138 | optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) 139 | scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches) 140 | for epoch in range(args.epochs): 141 | model = model.cuda() 142 | model.train() 143 | for i, batch in enumerate(train_loader): 144 | start_time = time.time() 145 | step = i + epoch * num_batches 146 | scheduler(step) 147 | optimizer.zero_grad() 148 | 149 | batch = maybe_dictionarize(batch) 150 | inputs = batch['images'].to('cuda:0') 151 | labels = batch['labels'].to('cuda:0') 152 | data_time = time.time() - start_time 153 | 154 | logits = model(inputs) 155 | loss = loss_fn(logits, labels) 156 | loss.backward() 157 | torch.nn.utils.clip_grad_norm_(params, 1.0) 158 | optimizer.step() 159 | batch_time = time.time() - start_time 160 | 161 | if step % print_every == 0: 162 | percent_complete = 100 * i / len(train_loader) 163 | print( 164 | f"Target Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_loader)}]\t" 165 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True 166 | ) 167 | # evaluate target model 168 | image_encoder = model.image_encoder 169 | args.eval_datasets = [dataset] # eval dataset 170 | evaluate(image_encoder, args) 171 | 172 | # Save the finetuned model 173 | if args.save is not None: 174 | ft_path = os.path.join(ckpdir, 'finetuned.pt') 175 | image_encoder.save(ft_path) 176 | 177 | def finetune_dev(model_shadow, args): 178 | dataset = args.dataset 179 | preprocess_fn = model_shadow.train_preprocess 180 | 181 | print_every = 100 182 | dataset_save_dir = os.path.join("{}/{}/dataset_splits".format(args.save, args.dataset)) 183 | 184 | if check_datasets_exist(dataset_save_dir): 185 | print("Subsets already exist...") 186 | from torch.utils.data import DataLoader 187 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir) 188 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 189 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True) 190 | else: 191 | print("Subsets do not exist...") 192 | train_dataset, train_loader = get_dataset_cifar_mnist( 193 | dataset, 194 | 'train', 195 | preprocess_fn, 196 | location=args.data_location, 197 | batch_size=args.batch_size 198 | ) 199 | test_dataset, test_loader = get_dataset_cifar_mnist( 200 | dataset, 201 | 'test', 202 | 
preprocess_fn, 203 | location=args.data_location, 204 | batch_size=args.batch_size 205 | ) 206 | shadowtrain_dataset, shadowtrain_loader = get_dataset_cifar_mnist( 207 | dataset, 208 | 'shadowtrain', 209 | preprocess_fn, 210 | location=args.data_location, 211 | batch_size=args.batch_size 212 | ) 213 | shadowtest_dataset, shadowtest_loader = get_dataset_cifar_mnist( 214 | dataset, 215 | 'shadowtest', 216 | preprocess_fn, 217 | location=args.data_location, 218 | batch_size=args.batch_size 219 | ) 220 | save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, dataset_save_dir) 221 | 222 | num_batches = len(train_loader) 223 | print("train_length: {}, val_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset))) 224 | 225 | # train model for shadow train set 226 | model_shadow = model_shadow.to(args.device) 227 | loss_fn_shadow = torch.nn.CrossEntropyLoss() 228 | params_shadow = [p for p in model_shadow.parameters() if p.requires_grad] 229 | optimizer_shadow = torch.optim.AdamW(params_shadow, lr=args.lr, weight_decay=args.wd) 230 | scheduler_shadow = cosine_lr(optimizer_shadow, args.lr, args.warmup_length, args.epochs * num_batches) 231 | for epoch in range(args.epochs): 232 | model_shadow = model_shadow.cuda() 233 | model_shadow.train() 234 | for i, batch in enumerate(shadowtrain_loader): 235 | start_time = time.time() 236 | step = i + epoch * num_batches 237 | scheduler_shadow(step) 238 | optimizer_shadow.zero_grad() 239 | 240 | batch = maybe_dictionarize(batch) 241 | inputs = batch['images'].to('cuda:0') 242 | labels = batch['labels'].to('cuda:0') 243 | data_time = time.time() - start_time 244 | 245 | logits = model_shadow(inputs) 246 | loss = loss_fn_shadow(logits, labels) 247 | loss.backward() 248 | torch.nn.utils.clip_grad_norm_(params_shadow, 1.0) 249 | optimizer_shadow.step() 250 | batch_time = time.time() - start_time 251 | 252 | if step % print_every == 0: 253 | percent_complete = 100 * i / len(shadowtrain_loader) 254 | print( 255 | f"Shadow Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(shadowtrain_loader)}]\t" 256 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True 257 | ) 258 | 259 | # evaluate shadow model 260 | image_encoder_shadow = model_shadow.image_encoder 261 | args.eval_datasets = [dataset] # eval dataset 262 | evaluate(image_encoder_shadow, args) 263 | 264 | ckpdir = os.path.join(args.save, dataset) 265 | if args.save is not None: 266 | dev_ft_path = os.path.join(ckpdir, 'finetuned_dev.pt') 267 | image_encoder_shadow.save(dev_ft_path) 268 | 269 | if __name__ == '__main__': 270 | data_location = "./data" 271 | models = ['RN50', 'ViT-B-32', 'ViT-L-14'] 272 | datasets = ['CIFAR10', 'MNIST', 'GTSRB', 'RESISC45', 'CIFAR100', 'SVHN', 'STL10'] 273 | 274 | epochs = { 275 | 'GTSRB': 11, 276 | 'MNIST': 5, 277 | 'RESISC45': 15, 278 | 'SVHN': 4, 279 | 'STL10': 50, 280 | 'CIFAR100': 5, 281 | 'CIFAR10': 5, 282 | } 283 | 284 | for model_name in models: 285 | for dataset in datasets: 286 | print('='*100) 287 | print(f'Finetuning {model_name} on {dataset}') 288 | print('='*100) 289 | args = parse_arguments() 290 | 291 | args.lr = 1e-5 292 | args.epochs = epochs[dataset] 293 | args.data_location = data_location 294 | args.dataset = dataset 295 | args.batch_size = 32 296 | 297 | args.model = model_name 298 | args.save = f'./checkpoints/{args.model}' 299 | args.cache_dir = '' 300 | args.openclip_cachedir = 
'./open_clip' 301 | image_encoder = ImageEncoder(args, keep_lang=False) 302 | classification_head = get_classification_head(args, dataset) 303 | model = ImageClassifier(image_encoder, classification_head) 304 | model.freeze_head() 305 | model_shadow = copy.deepcopy(model) 306 | 307 | finetune(model, args) 308 | finetune_dev(model_shadow, args) 309 | -------------------------------------------------------------------------------- /src/heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | import open_clip 5 | from src.datasets.templates import get_templates 6 | from src.datasets.registry import get_dataset_classnames 7 | from src.modeling import ClassificationHead, ImageEncoder 8 | 9 | def build_classification_head(model, dataset_name, template, data_location, device): 10 | template = get_templates(dataset_name)  # re-resolved from the dataset name; the template argument is effectively unused 11 | 12 | logit_scale = model.logit_scale 13 | classnames = get_dataset_classnames( 14 | dataset_name, 15 | None, 16 | location=data_location 17 | ) 18 | model.eval() 19 | model.to(device) 20 | 21 | print('Building classification head.') 22 | with torch.no_grad(): 23 | zeroshot_weights = [] 24 | for classname in tqdm(classnames): 25 | texts = [] 26 | for t in template: 27 | texts.append(t(classname)) 28 | texts = open_clip.tokenize(texts).to(device) # tokenize 29 | embeddings = model.encode_text(texts) # embed with text encoder 30 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 31 | 32 | embeddings = embeddings.mean(dim=0, keepdim=True) 33 | embeddings /= embeddings.norm() 34 | 35 | zeroshot_weights.append(embeddings) 36 | 37 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 38 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 39 | 40 | zeroshot_weights *= logit_scale.exp() 41 | 42 | zeroshot_weights = zeroshot_weights.squeeze().float() 43 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 44 | 45 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 46 | 47 | return classification_head 48 | 49 | 50 | def get_classification_head(args, dataset): 51 | filename = os.path.join(args.save, f'head_{dataset}.pt') 52 | if os.path.exists(filename): 53 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 54 | return ClassificationHead.load(filename) 55 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 56 | model = ImageEncoder(args, keep_lang=True).model 57 | template = get_templates(dataset) 58 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device) 59 | os.makedirs(args.save, exist_ok=True) 60 | classification_head.save(filename) 61 | return classification_head 62 | 63 | def get_classification_head_dev(args, model, dataset, flag): 64 | if flag == 'shadow': 65 | filename = os.path.join(args.save, f'head_{dataset}_shadow.pt') 66 | if os.path.exists(filename): 67 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 68 | return ClassificationHead.load(filename) 69 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 70 | model_todo = model.model 71 | template = get_templates(dataset) 72 | classification_head = build_classification_head(model_todo, dataset, template, args.data_location, args.device) 73 | os.makedirs(args.save, exist_ok=True) 74 | classification_head.save(filename) 75 | return classification_head
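# Note: the 'target' branch below repeats the 'shadow' branch above line for
# line; only the head_{dataset}_{flag}.pt filename suffix differs, so a single
# path parameterized by flag would behave identically.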
76 | if flag == 'target': 77 | filename = os.path.join(args.save, f'head_{dataset}_target.pt') 78 | if os.path.exists(filename): 79 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 80 | return ClassificationHead.load(filename) 81 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 82 | model_todo = model.model 83 | template = get_templates(dataset) 84 | classification_head = build_classification_head(model_todo, dataset, template, args.data_location, args.device) 85 | os.makedirs(args.save, exist_ok=True) 86 | classification_head.save(filename) 87 | return classification_head -------------------------------------------------------------------------------- /src/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import open_clip 4 | 5 | import utils 6 | import math 7 | 8 | 9 | class ImageEncoder(torch.nn.Module): 10 | def __init__(self, args, keep_lang=False): 11 | super().__init__() 12 | 13 | print(f'Creating {args.model} with pretrained OpenAI weights.') 14 | 15 | self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms( 16 | args.model, pretrained='openai', cache_dir=args.openclip_cachedir) # pretrained=None 17 | 18 | self.cache_dir = args.cache_dir 19 | 20 | if not keep_lang and hasattr(self.model, 'transformer'): 21 | delattr(self.model, 'transformer') 22 | 23 | # self._initialize_weights() 24 | 25 | def _initialize_weights(self): 26 | for module in self.model.modules(): 27 | if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): 28 | torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) 29 | if module.bias is not None: 30 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) 31 | bound = 1 / math.sqrt(fan_in) 32 | torch.nn.init.uniform_(module.bias, -bound, bound) 33 | 34 | def forward(self, images): 35 | assert self.model is not None 36 | return self.model.encode_image(images) 37 | 38 | def __call__(self, inputs): 39 | return self.forward(inputs) 40 | 41 | def save(self, filename): 42 | print(f'Saving image encoder to {filename}') 43 | utils.torch_save(self, filename) 44 | 45 | @classmethod 46 | def load(cls, model_name, filename): 47 | print(f'Loading image encoder from {filename}') 48 | # torch_save stores the whole module, so load it back directly (the previous cls.load(model_name, state_dict) call recursed endlessly) 49 | return utils.torch_load(filename) 50 | 51 | def load_from_state_dict(self, model_name, state_dict): 52 | print("start loading state dict from {}".format(state_dict)) 53 | self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms( 54 | model_name, pretrained='openai', device='cpu') 55 | # model.load_from_state_dict(state_dict, strict=False) 56 | checkpoint = torch.load(state_dict, map_location=torch.device('cpu')) 57 | self.model.visual.load_state_dict(checkpoint) 58 | delattr(self.model, 'transformer') 59 | print("successfully loaded state dict!") 60 | 61 | class ClassificationHead(torch.nn.Linear): 62 | def __init__(self, normalize, weights, biases=None): 63 | output_size, input_size = weights.shape 64 | super().__init__(input_size, output_size) 65 | self.normalize = normalize 66 | if weights is not None: 67 | self.weight = torch.nn.Parameter(weights.clone()) 68 | if biases is not None: 69 | self.bias = torch.nn.Parameter(biases.clone()) 70 | else: 71 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias)) 72 | 73 | def forward(self, inputs): 74 | if self.normalize: 75 | inputs = inputs / inputs.norm(dim=-1, keepdim=True) 76 | return super().forward(inputs)
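# Note: with normalize=True the head scores unit-norm image features against
# the class embeddings, i.e. cosine-similarity logits; the temperature is
# already folded into the weights via logit_scale.exp() in
# build_classification_head (heads.py).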
77 | 78 | def __call__(self, inputs): 79 | return self.forward(inputs) 80 | 81 | def save(self, filename): 82 | print(f'Saving classification head to {filename}') 83 | utils.torch_save(self, filename) 84 | 85 | @classmethod 86 | def load(cls, filename): 87 | print(f'Loading classification head from {filename}') 88 | return utils.torch_load(filename) 89 | 90 | 91 | class ImageClassifier(torch.nn.Module): 92 | def __init__(self, image_encoder, classification_head): 93 | super().__init__() 94 | self.image_encoder = image_encoder 95 | self.classification_head = classification_head 96 | if self.image_encoder is not None: 97 | if hasattr(self.image_encoder, 'train_preprocess'): 98 | self.train_preprocess = self.image_encoder.train_preprocess 99 | self.val_preprocess = self.image_encoder.val_preprocess 100 | elif hasattr(self.image_encoder.model, 'train_preprocess'): 101 | self.train_preprocess = self.image_encoder.model.train_preprocess 102 | self.val_preprocess = self.image_encoder.model.val_preprocess 103 | 104 | def freeze_head(self): 105 | self.classification_head.weight.requires_grad_(False) 106 | self.classification_head.bias.requires_grad_(False) 107 | 108 | def forward(self, inputs): 109 | features = self.image_encoder(inputs) 110 | outputs = self.classification_head(features) 111 | return outputs 112 | 113 | def __call__(self, inputs): 114 | return self.forward(inputs) 115 | 116 | def save(self, filename): 117 | print(f'Saving image classifier to {filename}') 118 | utils.torch_save(self, filename) 119 | 120 | @classmethod 121 | def load(cls, filename): 122 | print(f'Loading image classifier from {filename}') 123 | return utils.torch_load(filename) 124 | 125 | class ImageClassifier_debug(torch.nn.Module): 126 | def __init__(self, image_encoder, image_encoder2, classification_head): 127 | super().__init__() 128 | self.image_encoder = image_encoder 129 | self.image_encoder2 = image_encoder2 130 | self.classification_head = classification_head 131 | if self.image_encoder is not None: 132 | self.train_preprocess = self.image_encoder.train_preprocess 133 | self.val_preprocess = self.image_encoder.val_preprocess 134 | 135 | def freeze_head(self): 136 | self.classification_head.weight.requires_grad_(False) 137 | self.classification_head.bias.requires_grad_(False) 138 | 139 | def forward(self, inputs): 140 | features = self.image_encoder(inputs) 141 | features2 = self.image_encoder2(inputs) 142 | outputs = self.classification_head(features + features2) 143 | return outputs 144 | 145 | def __call__(self, inputs): 146 | return self.forward(inputs) 147 | 148 | def save(self, filename): 149 | print(f'Saving image classifier to {filename}') 150 | utils.torch_save(self, filename) 151 | 152 | @classmethod 153 | def load(cls, filename): 154 | print(f'Loading image classifier from {filename}') 155 | return utils.torch_load(filename) 156 | 157 | class MultiHeadImageClassifier(torch.nn.Module): 158 | def __init__(self, image_encoder, classification_heads): 159 | super().__init__() 160 | self.image_encoder = image_encoder 161 | self.classification_heads = torch.nn.ModuleList(classification_heads) 162 | if self.image_encoder is not None: 163 | self.train_preprocess = self.image_encoder.train_preprocess 164 | self.val_preprocess = self.image_encoder.val_preprocess 165 | 166 | def freeze_head(self): 167 | for idx in range(len(self.classification_heads)): 168 |
self.classification_heads[idx].weight.requires_grad_(False) 169 | self.classification_heads[idx].bias.requires_grad_(False) 170 | 171 | def forward(self, inputs, head_idx): 172 | features = self.image_encoder(inputs) 173 | outputs = self.classification_heads[head_idx](features) 174 | return outputs 175 | 176 | def __call__(self, inputs, head_idx): 177 | return self.forward(inputs, head_idx) 178 | 179 | def save(self, filename): 180 | print(f'Saving image classifier to {filename}') 181 | utils.torch_save(self, filename) 182 | 183 | @classmethod 184 | def load(cls, filename): 185 | print(f'Loading image classifier from {filename}') 186 | return utils.torch_load(filename) 187 | -------------------------------------------------------------------------------- /src/neulig_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import sys 5 | import tqdm 6 | sys.path.append('.') 7 | sys.path.append('./src') 8 | from src.modeling import ImageEncoder 9 | from task_vectors import TaskVector 10 | # from eval import eval_single_dataset 11 | from args import parse_arguments 12 | from utils import * 13 | import torchvision.transforms as transforms 14 | from PIL import Image 15 | # (duplicate "import time" removed; already imported above) 16 | import torchvision.utils as vutils 17 | # from src.datasets.registry import get_dataset 18 | from src.heads import get_classification_head 19 | import torch 20 | from collections import Counter 21 | import torch.nn.functional as F 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | from src.datasets.common import get_dataloader, maybe_dictionarize 25 | import timm 26 | from itertools import cycle 27 | from modeling import ImageClassifier, ImageEncoder, ClassificationHead 28 | from open_clip import create_model_and_transforms 29 | from torch.utils.data import DataLoader, TensorDataset 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | 32 | def merge_ckps(fusion_model, sample_weights):  # weighted sum of flat checkpoints: theta_merged = sum_j w_j * theta_j 33 | flat_ft = torch.vstack([state_dict_to_vector(check, []).to('cpu') for check in fusion_model.ckpts]).to('cpu') 34 | tv_flat_checks = flat_ft 35 | final_ck = None 36 | for j in range(fusion_model.num_models): 37 | weighted_value = sample_weights[0, j].to('cpu') * tv_flat_checks[j] 38 | if final_ck is None: 39 | final_ck = weighted_value 40 | else: 41 | final_ck += weighted_value 42 | final_ck = final_ck.to(device) 43 | return final_ck 44 | 45 | 46 | def save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, save_dir): 47 | os.makedirs(save_dir, exist_ok=True) 48 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'wb') as f: 49 | pickle.dump(train_dataset, f) 50 | 51 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'wb') as f: 52 | pickle.dump(test_dataset, f) 53 | 54 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'wb') as f: 55 | pickle.dump(shadowtrain_dataset, f) 56 | 57 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'wb') as f: 58 | pickle.dump(shadowtest_dataset, f) 59 | 60 | def load_datasets(save_dir): 61 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'rb') as f: 62 | train_dataset = pickle.load(f) 63 | 64 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'rb') as f: 65 | test_dataset = pickle.load(f) 66 | 67 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'rb') as f: 68 | shadowtrain_dataset = pickle.load(f) 69 | 70 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'rb')
as f: 71 | shadowtest_dataset = pickle.load(f) 72 | 73 | return train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset 74 | def check_datasets_exist(save_dir): 75 | return (os.path.exists(os.path.join(save_dir, 'train_dataset.pkl')) and 76 | os.path.exists(os.path.join(save_dir, 'test_dataset.pkl')) and 77 | os.path.exists(os.path.join(save_dir, 'shadowtrain_dataset.pkl')) and 78 | os.path.exists(os.path.join(save_dir, 'shadowtest_dataset.pkl')) 79 | ) 80 | def load_dataset_splits(save_dir): 81 | with open(os.path.join(save_dir, 'train_indices.pkl'), 'rb') as f: 82 | train_indices = pickle.load(f) 83 | with open(os.path.join(save_dir, 'test_indices.pkl'), 'rb') as f: 84 | test_indices = pickle.load(f) 85 | with open(os.path.join(save_dir, 'shadowtrain_indices.pkl'), 'rb') as f: 86 | shadowtrain_indices = pickle.load(f) 87 | with open(os.path.join(save_dir, 'shadowtest_indices.pkl'), 'rb') as f: 88 | shadowtest_indices = pickle.load(f) 89 | 90 | return train_indices, test_indices, shadowtrain_indices, shadowtest_indices 91 | 92 | 93 | def evaluate_ori(fusion_model, test_loaders, criterion, device): 94 | fusion_model.eval() 95 | total_loss = 0.0 96 | merged_total_loss = 0.0 97 | total_correct = [] 98 | total_samples = [] 99 | merged_total_correct = [] 100 | 101 | with torch.no_grad(): 102 | for loader_idx, test_loader in enumerate(test_loaders): 103 | cur_correct = 0 104 | cur_samples = 0 105 | merged_cur_correct = 0 106 | for i, data in enumerate(tqdm.tqdm(test_loader)): 107 | data = maybe_dictionarize(data) 108 | inputs = data['images'].to(device) 109 | labels = data['labels'].to(device) 110 | 111 | outputs, _ = fusion_model(inputs, dataset_index=loader_idx) 112 | 113 | model_outputs = [] 114 | for model in fusion_model.models:  # (index dropped; the enumerate counter shadowed the batch index i above) 115 | model.eval() 116 | with torch.no_grad(): 117 | output = model(inputs) 118 | model_outputs.append(output) 119 | weighting_model = fusion_model.get_weighting_model() 120 | stacked_outputs = torch.cat(model_outputs, dim=1) 121 | merge_weights = weighting_model(stacked_outputs) 122 | 123 | merged_checks = merge_ckps(fusion_model, merge_weights) 124 | merged_state_dict = vector_to_state_dict(merged_checks, ptm_check, remove_keys=[])  # ptm_check / image_encoder are module-level globals defined below 125 | image_encoder.load_state_dict(merged_state_dict, strict=False) 126 | image_encoder.to(device) 127 | merged_model = ImageClassifier(image_encoder, fusion_model.prediction_heads[loader_idx]) 128 | merged_outputs = merged_model(inputs) 129 | loss = criterion(outputs, labels) 130 | total_loss += loss.item() 131 | 132 | merged_loss = criterion(merged_outputs, labels) 133 | merged_total_loss += merged_loss.item() 134 | 135 | cur_samples += labels.size(0) 136 | 137 | _, predicted = torch.max(outputs.data, 1) 138 | cur_correct += (predicted == labels).sum().item() 139 | 140 | _, merged_predicted = torch.max(merged_outputs.data, 1) 141 | merged_cur_correct += (merged_predicted == labels).sum().item() 142 | 143 | total_samples.append(cur_samples) 144 | 145 | total_correct.append(cur_correct) 146 | merged_total_correct.append(merged_cur_correct) 147 | 148 | accuracies = [100.0 * total_correct[i] / total_samples[i] for i in range(len(total_samples))] 149 | print("accuracy per task: ", accuracies) 150 | merged_accuracies = [100.0 * merged_total_correct[i] / total_samples[i] for i in range(len(total_samples))] 151 | print("merged_accuracy per task: ", merged_accuracies) 152 | avg_accuracy = sum(accuracies) / len(accuracies) 153 | avg_loss = total_loss / sum(total_samples) 154 | 155 | merged_avg_accuracy = sum(merged_accuracies) / len(merged_accuracies) 156 | merged_avg_loss = merged_total_loss / sum(total_samples) 157 | 158 | return avg_loss, avg_accuracy, merged_avg_loss, merged_avg_accuracy 159 |
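# Note: WeightingModel below maps the concatenated per-model features, of
# shape (batch, input_dim * num_models), through one linear layer and a
# softmax to produce one weight per collaborating model; the same weights
# drive both the output ensembling in FusionModel.forward and the checkpoint
# merging in merge_ckps.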
160 | class WeightingModel(nn.Module): 161 | def __init__(self, input_dim=512, num_models=6): 162 | super(WeightingModel, self).__init__() 163 | self.num_models = num_models 164 | self.fc = nn.Linear(input_dim * num_models, num_models) 165 | def forward(self, x): 166 | 167 | logits = self.fc(x) 168 | weights = F.softmax(logits, dim=1) 169 | return weights 170 | 171 | 172 | class FusionModel(nn.Module): 173 | def __init__(self, ckpts, models, prediction_heads, input_dim=1024): # ViT-L-14: 768, RN50: 1024, ViT-B-32: 512 174 | super(FusionModel, self).__init__() 175 | self.models = models 176 | self.prediction_heads = prediction_heads 177 | self.num_models = len(models) 178 | self.weighting_model = WeightingModel(input_dim=input_dim, num_models=self.num_models) 179 | 180 | self.ckpts = ckpts 181 | self.flat_ft = torch.vstack([state_dict_to_vector(check, []).to('cpu') for check in self.ckpts]).to('cpu') 182 | self.mean_ft = torch.mean(self.flat_ft, dim=0) 183 | self.diff_ft = self.flat_ft - self.mean_ft.unsqueeze(0) # ksi 184 | self.sum_ft = torch.sum(self.diff_ft, dim=1).to(device) 185 | 186 | mean = torch.mean(self.sum_ft) 187 | std = torch.std(self.sum_ft) 188 | self.sum_ft = (self.sum_ft - mean) / std 189 | 190 | def forward(self, inputs, dataset_index): 191 | model_outputs = [] 192 | self.weighting_model.train() 193 | for model in self.models: 194 | model.eval() 195 | with torch.no_grad(): 196 | output = model(inputs) 197 | model_outputs.append(output) 198 | 199 | stacked_outputs = torch.cat(model_outputs, dim=1) 200 | 201 | weights = self.weighting_model(stacked_outputs) 202 | reg_loss = torch.matmul(weights, self.sum_ft) 203 | 204 | # tensor_sum = torch.sum(weights)  # unused; each softmax row already sums to 1 205 | 206 | weighted_sum = 0 207 | for i in range(self.num_models): 208 | weighted_output = model_outputs[i] * weights[:, i].unsqueeze(1) 209 | weighted_sum += weighted_output 210 | final_output = self.prediction_heads[dataset_index](weighted_sum) 211 | return final_output, reg_loss 212 | 213 | def get_weighting_model(self): 214 | return self.weighting_model 215 | 216 | args = parse_arguments() 217 | args.save = './checkpoints/{}'.format(args.model) 218 | 219 | exam_datasets = ['GTSRB', 'CIFAR100', 'RESISC45', 'CIFAR10', 'MNIST', 'STL10', 'SVHN'] 220 | num_classes = [43, 100, 45, 10, 10, 10, 10] 221 | use_merged_model = True 222 | 223 | classification_heads = [get_classification_head(args, dataset_name).to(device) for dataset_name in exam_datasets] 224 | 225 | import itertools 226 | exam_datasets_list = [list(comb) for comb in itertools.combinations(exam_datasets, args.num_co_models)] 227 | num_classes_list = [list(comb) for comb in itertools.combinations(num_classes, args.num_co_models)] 228 | classification_heads_list = [list(comb) for comb in itertools.combinations(classification_heads, args.num_co_models)] 229 | 230 | for mm in range(len(exam_datasets_list)): 231 | exam_datasets = exam_datasets_list[mm] 232 | num_classes = num_classes_list[mm] 233 | classification_heads = classification_heads_list[mm] 234 | 235 | # args.save = os.path.join(args.ckpt_dir, args.model)  # dead assignment, immediately overridden below 236 | args.save = './checkpoints/{}'.format(args.model) 237 | pretrained_checkpoint = os.path.join(args.save, 'zeroshot.pt') 238 | image_encoder = torch.load(pretrained_checkpoint) 239 | image_encoder_shadow = torch.load(pretrained_checkpoint)
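# Sketch (hypothetical helper, not referenced by the original code):
# merge_ckps above reduces the stacked flat checkpoints with a Python loop;
# the same weighted sum, theta_merged = sum_j w_j * theta_j, is a single
# matrix product. Assumes flat_checks of shape (num_models, num_params) and
# sample_weights of shape (1, num_models), as produced by the weighting model.
def merge_ckps_vectorized(flat_checks, sample_weights):
    # (1, num_models) @ (num_models, num_params) -> (num_params,)
    return torch.matmul(sample_weights.cpu(), flat_checks.cpu()).squeeze(0).to(device)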
240 | ptm_check = torch.load(pretrained_checkpoint).state_dict() 241 | 242 | from tm_utils import * 243 | ft_checks, ft_checks_shadow = [], [] 244 | ft_archs, ft_archs_shadow = [], [] 245 | 246 | for dataset_name in exam_datasets: 247 | ckpt_name = os.path.join(args.save, dataset_name, 'finetuned.pt') 248 | ckpt_name_shadow = os.path.join(args.save, dataset_name, 'finetuned_dev.pt') 249 | ft_archs.append(torch.load(ckpt_name).to(device)) 250 | ft_archs_shadow.append(torch.load(ckpt_name_shadow).to(device)) 251 | ft_checks.append(torch.load(ckpt_name).state_dict()) 252 | ft_checks_shadow.append(torch.load(ckpt_name_shadow).state_dict()) 253 | print(ckpt_name) 254 | print(ckpt_name_shadow) 255 | 256 | if args.model == 'RN50': 257 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=1024) 258 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=1024) 259 | elif args.model == 'ViT-B-32': 260 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=512) 261 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=512) 262 | elif args.model == 'ViT-L-14': 263 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=768) 264 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=768) 265 | test_loaders, train_loaders, shadowtrain_loaders, shadowtest_loaders, adv_test_loaders = [], [], [], [], [] 266 | 267 | for num_ld in range(len(exam_datasets)): 268 | dataset_save_dir = os.path.join("{}/{}/dataset_splits".format(args.save, exam_datasets[num_ld])) 269 | print("cur_process_dataset: ", dataset_save_dir) 270 | if check_datasets_exist(dataset_save_dir): 271 | print("Subsets already exist...") 272 | from torch.utils.data import DataLoader 273 | 274 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir) 275 | 276 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 277 | 278 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True) 279 | 280 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True) 281 | 282 | shadowtest_loader = DataLoader(shadowtest_dataset, batch_size=args.batch_size, shuffle=True) 283 | 284 | test_loaders.append(test_loader) 285 | train_loaders.append(train_loader) 286 | 287 | print("dataset: {}, train_length: {}, test_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(exam_datasets[num_ld], len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset))) 288 | 289 | fusion_model = fusion_model.to(device) 290 | fusion_model_shadow = fusion_model_shadow.to(device) 291 | 292 | optimizer = optim.Adam(fusion_model.weighting_model.parameters(), lr=0.001) 293 | optimizer_shadow = optim.Adam(fusion_model_shadow.weighting_model.parameters(), lr=0.001) 294 | 295 | criterion = nn.CrossEntropyLoss() 296 | criterion_reg = nn.MSELoss() 297 | 298 | print("#########################################################") 299 | print("###############Portland Training Begins##################") 300 | print("#########################################################") 301 | avg_loss, accuracy, merged_avg_loss, merged_accuracy = evaluate_ori(fusion_model, test_loaders, criterion, device) 302 | print(f"Initial Evaluation - Avg Loss: {avg_loss:.4f}, Merged Avg Loss: {merged_avg_loss:.4f}, Ensembling Accuracy: {accuracy:.2f}%, Merging Accuracy: {merged_accuracy:.2f}%")
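# Sketch (hypothetical helper mirroring the training loop below): the loop
# supports two alignment objectives, selected by --alignment_type. 'sup' is
# ordinary cross-entropy against the labels; 'semi' needs no labels and
# minimizes the mean Shannon entropy H(p) = -sum_c p_c log p_c of the softmax
# predictions, pushing each sample toward a confident class.
def alignment_loss(outputs, labels, alignment_type):
    if alignment_type == 'sup':
        return F.cross_entropy(outputs, labels)
    probs = F.softmax(outputs, dim=1)  # 'semi': entropy minimization
    return -torch.mean(torch.sum(probs * torch.log(probs + 1e-6), dim=1))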
303 | best_accuracy = 0.0 304 | for glb_ep in range(args.global_epoch): 305 | fusion_model.train() 306 | loaders_cycle = [cycle(loader) for loader in train_loaders] 307 | total_batches = min(len(loader) for loader in train_loaders) 308 | 309 | for batch_idx in range(total_batches): 310 | for loader_idx, loader in enumerate(loaders_cycle): 311 | data = next(loader) 312 | 313 | data = maybe_dictionarize(data) 314 | inputs = data['images'].to(device) 315 | labels = data['labels'].to(device) 316 | 317 | outputs, reg_loss = fusion_model(inputs, dataset_index=loader_idx) 318 | 319 | target = torch.zeros_like(reg_loss).to(device) 320 | loss_reg = criterion_reg(reg_loss, target) / args.scaling 321 | 322 | if args.alignment_type == 'sup': 323 | loss_ce = criterion(outputs, labels) 324 | elif args.alignment_type == 'semi': # semi-supervised (entropy minimization) 325 | probs = F.softmax(outputs, dim=1) 326 | loss_ce = -torch.mean(torch.sum(probs * torch.log(probs + 1e-6), dim=1)) 327 | 328 | loss = loss_ce + loss_reg 329 | 330 | print(f"Epoch: {glb_ep}, Current Dataset Index: {loader_idx}, Batch: {batch_idx + 1}/{total_batches}, Loss: {loss.item():.4f}") 331 | 332 | optimizer.zero_grad() 333 | loss.backward() 334 | optimizer.step() 335 | 336 | if (glb_ep + 1) % 10 == 0: 337 | avg_loss, accuracy, merged_avg_loss, merged_accuracy = evaluate_ori( 338 | fusion_model, test_loaders, criterion, device 339 | ) 340 | print( 341 | f"Epoch [{glb_ep + 1}/{args.global_epoch}] Evaluation - " 342 | f"Ensembling Avg Loss: {avg_loss:.4f}, Merging Avg Loss: {merged_avg_loss:.4f}, " 343 | f"Ensembling Accuracy: {accuracy:.2f}%, Merging Accuracy: {merged_accuracy:.2f}%" 344 | ) 345 | 346 | if merged_accuracy > best_accuracy: 347 | best_accuracy = merged_accuracy 348 | print(f"New best model found with accuracy: {best_accuracy:.2f}%") 349 | 350 | print("#########################################################") 351 | print("################Portland Training Ends###################") 352 | print("#########################################################") 353 | 354 | -------------------------------------------------------------------------------- /src/pgbar.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | import sys 3 | import time 4 | 5 | _, term_width = os.popen('stty size', 'r').read().split() 6 | term_width = int(term_width) 7 | TOTAL_BAR_LENGTH = 65. 8 | last_time = time.time() 9 | begin_time = last_time 10 | 11 | 12 | def progress_bar(current, total, msg=None): 13 | global last_time, begin_time 14 | if current == 0: 15 | begin_time = time.time() # Reset for new bar. 16 | 17 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 19 | 20 | sys.stdout.write(' [') 21 | for i in range(cur_len): 22 | sys.stdout.write('=') 23 | sys.stdout.write('>') 24 | for i in range(rest_len): 25 | sys.stdout.write('.') 26 | sys.stdout.write(']') 27 | 28 | cur_time = time.time() 29 | step_time = cur_time - last_time 30 | last_time = cur_time 31 | tot_time = cur_time - begin_time 32 | 33 | L = [] 34 | L.append(' Step: %s' % format_time(step_time)) 35 | L.append(' | Tot: %s' % format_time(tot_time)) 36 | if msg: 37 | L.append(' | ' + msg) 38 | 39 | msg = ''.join(L) 40 | sys.stdout.write(msg) 41 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 42 | sys.stdout.write(' ') 43 | 44 | # Go back to the center of the bar.
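# Note: the loop below emits backspace characters to walk the cursor from the
# right edge back to roughly the middle of the drawn bar, where the
# current/total counter is printed over the padding spaces.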
45 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 46 | sys.stdout.write('\b') 47 | sys.stdout.write(' %d/%d ' % (current + 1, total)) 48 | 49 | if current < total - 1: 50 | sys.stdout.write('\r') 51 | else: 52 | sys.stdout.write('\n') 53 | sys.stdout.flush() 54 | 55 | 56 | def format_time(seconds): 57 | days = int(seconds / 3600 / 24) 58 | seconds = seconds - days * 3600 * 24 59 | hours = int(seconds / 3600) 60 | seconds = seconds - hours * 3600 61 | minutes = int(seconds / 60) 62 | seconds = seconds - minutes * 60 63 | secondsf = int(seconds) 64 | seconds = seconds - secondsf 65 | millis = int(seconds * 1000) 66 | 67 | f = '' 68 | i = 1 69 | if days > 0: 70 | f += str(days) + 'D' 71 | i += 1 72 | if hours > 0 and i <= 2: 73 | f += str(hours) + 'h' 74 | i += 1 75 | if minutes > 0 and i <= 2: 76 | f += str(minutes) + 'm' 77 | i += 1 78 | if secondsf > 0 and i <= 2: 79 | f += str(secondsf) + 's' 80 | i += 1 81 | if millis > 0 and i <= 2: 82 | f += str(millis) + 'ms' 83 | i += 1 84 | if f == '': 85 | f = '0ms' 86 | return f -------------------------------------------------------------------------------- /src/task_vectors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TaskVector(): 5 | def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None): 6 | """Initializes the task vector from a pretrained and a finetuned checkpoint. 7 | 8 | This can either be done by passing the paths of the two checkpoints (one for the 9 | pretrained model, and another for the finetuned model), or by directly passing in 10 | the task vector state dict. 11 | """ 12 | if vector is not None: 13 | self.vector = vector 14 | else: 15 | print(pretrained_checkpoint, finetuned_checkpoint) 16 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None 17 | with torch.no_grad(): 18 | print('TaskVector:' + finetuned_checkpoint) 19 | pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict() 20 | finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict() 21 | self.vector = {} 22 | for key in pretrained_state_dict: 23 | if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]: 24 | continue 25 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key] 26 | print(len(self.vector)) 27 | 28 | def __add__(self, other): 29 | """Add two task vectors together.""" 30 | with torch.no_grad(): 31 | new_vector = {} 32 | for key in self.vector: 33 | if key not in other.vector: 34 | print(f'Warning, key {key} is not present in both task vectors.') 35 | continue 36 | new_vector[key] = self.vector[key] + other.vector[key] 37 | return TaskVector(vector=new_vector) 38 | 39 | def __radd__(self, other): 40 | if other is None or isinstance(other, int): 41 | return self 42 | return self.__add__(other) 43 | 44 | def __neg__(self): 45 | """Negate a task vector.""" 46 | with torch.no_grad(): 47 | new_vector = {} 48 | for key in self.vector: 49 | new_vector[key] = - self.vector[key] 50 | return TaskVector(vector=new_vector) 51 | 52 | def weightmerging(self, taskvectors, coefficients): 53 | with torch.no_grad(): 54 | new_vector = {} 55 | for key in taskvectors[0].vector: 56 | new_vector[key] = sum(coefficients[k] * taskvectors[k].vector[key] for k in range(len(taskvectors))) 57 | return TaskVector(vector=new_vector) 58 | 59 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): 60 | """Apply a task vector to a pretrained model.""" 61 | with torch.no_grad(): 62 | pretrained_model = torch.load(pretrained_checkpoint)
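# Note: apply_to realizes task arithmetic, theta_new = theta_pre +
# scaling_coef * tau, key by key; keys absent from the task vector are skipped
# and keep their pretrained values thanks to strict=False below.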
63 | new_state_dict = {} 64 | pretrained_state_dict = pretrained_model.state_dict() 65 | for key in pretrained_state_dict: 66 | if key not in self.vector: 67 | print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector') 68 | continue 69 | new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key] 70 | pretrained_model.load_state_dict(new_state_dict, strict=False) 71 | return pretrained_model 72 | 73 | -------------------------------------------------------------------------------- /src/tm_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os, copy 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import re 7 | from collections import OrderedDict 8 | import torch.nn.functional as F 9 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 10 | 11 | ## Model conversion utils 12 | def state_dict_to_vector(state_dict, remove_keys=[]): 13 | shared_state_dict = copy.deepcopy(state_dict) 14 | for key in remove_keys: 15 | if key in shared_state_dict: 16 | del shared_state_dict[key] 17 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items())) 18 | return torch.nn.utils.parameters_to_vector( 19 | [value.reshape(-1) for key, value in sorted_shared_state_dict.items()] 20 | ) 21 | 22 | 23 | def vector_to_state_dict(vector, state_dict, remove_keys=[]): 24 | # create a reference dict to define the order of the vector 25 | reference_dict = copy.deepcopy(state_dict) 26 | for key in remove_keys: 27 | if key in reference_dict: 28 | del reference_dict[key] 29 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items())) 30 | 31 | # create a shared state dict using the reference dict 32 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values()) 33 | 34 | # add back the encoder and decoder embedding weights. 35 | if "transformer.shared.weight" in sorted_reference_dict: 36 | for key in remove_keys: 37 | sorted_reference_dict[key] = sorted_reference_dict[ 38 | "transformer.shared.weight" 39 | ] 40 | return sorted_reference_dict 41 | 42 | 43 | def add_ptm_to_tv(tv_dict, ptm_dict): 44 | assert set(tv_dict.keys()) == set( 45 | ptm_dict.keys() 46 | ), "Differing parameter names in models." 47 | final_dict = copy.deepcopy(tv_dict) 48 | for k, v in ptm_dict.items(): 49 | final_dict[k] = tv_dict[k] + v 50 | return final_dict 51 | 52 | 53 | def check_parameterNamesMatch(checkpoints): 54 | parameter_names = set(checkpoints[0].keys()) 55 | 56 | if len(checkpoints) >= 2: 57 | # raise ValueError("Number of models is less than 2.") 58 | for checkpoint in checkpoints[1:]: 59 | current_parameterNames = set(checkpoint.keys()) 60 | if current_parameterNames != parameter_names: 61 | raise ValueError( 62 | "Differing parameter names in models.
" 63 | f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}" 64 | ) 65 | 66 | def check_state_dicts_equal(state_dict1, state_dict2): 67 | if set(state_dict1.keys()) != set(state_dict2.keys()): 68 | return False 69 | 70 | for key in state_dict1.keys(): 71 | if not torch.equal(state_dict1[key], state_dict2[key]): 72 | return False 73 | 74 | return True 75 | 76 | 77 | 78 | ## TIES MERGING UTILS 79 | 80 | def topk_values_mask(M, K=0.7, return_mask=False): 81 | if K > 1: 82 | K /= 100 83 | 84 | original_shape = M.shape 85 | if M.dim() == 1: 86 | M = M.unsqueeze(0) 87 | 88 | n, d = M.shape 89 | k = int(d * K) 90 | k = d - k # Keep top k elements instead of bottom k elements 91 | 92 | # Find the k-th smallest element by magnitude for each row 93 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 94 | # Create a mask tensor with True for the top k elements in each row 95 | mask = M.abs() >= kth_values 96 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 97 | 98 | if return_mask: 99 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 100 | return M * final_mask, final_mask.float().mean(dim=1) 101 | 102 | 103 | def resolve_zero_signs(sign_to_mult, method="majority"): 104 | majority_sign = torch.sign(sign_to_mult.sum()) 105 | 106 | if method == "majority": 107 | sign_to_mult[sign_to_mult == 0] = majority_sign 108 | elif method == "minority": 109 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign 110 | return sign_to_mult 111 | 112 | 113 | def resolve_sign(Tensor): 114 | sign_to_mult = torch.sign(Tensor.sum(dim=0)) 115 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority") 116 | return sign_to_mult 117 | 118 | 119 | def disjoint_merge(Tensor, merge_func, sign_to_mult): 120 | merge_func = merge_func.split("-")[-1] 121 | 122 | # If sign is provided then we select the corresponding entries and aggregate. 123 | if sign_to_mult is not None: 124 | rows_to_keep = torch.where( 125 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0 126 | ) 127 | selected_entries = Tensor * rows_to_keep 128 | # Else we select all non-zero entries and aggregate. 129 | else: 130 | rows_to_keep = Tensor != 0 131 | selected_entries = Tensor * rows_to_keep 132 | 133 | if merge_func == "mean": 134 | non_zero_counts = (selected_entries != 0).sum(dim=0).float() 135 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1) 136 | elif merge_func == "sum": 137 | disjoint_aggs = torch.sum(selected_entries, dim=0) 138 | elif merge_func == "max": 139 | disjoint_aggs = selected_entries.abs().max(dim=0)[0] 140 | disjoint_aggs *= sign_to_mult 141 | else: 142 | raise ValueError(f"Merge method {merge_func} is not defined.") 143 | 144 | return disjoint_aggs 145 | 146 | 147 | def ties_merging( 148 | flat_task_checks, 149 | reset_thresh=None, 150 | merge_func="", 151 | ): 152 | all_checks = flat_task_checks.clone() 153 | updated_checks, *_ = topk_values_mask( 154 | all_checks, K=reset_thresh, return_mask=False 155 | ) 156 | print(f"RESOLVING SIGN") 157 | final_signs = resolve_sign(updated_checks) 158 | assert final_signs is not None 159 | 160 | print(f"Disjoint AGGREGATION: {merge_func}") 161 | merged_tv = disjoint_merge(updated_checks, merge_func, final_signs) 162 | 163 | return merged_tv 164 | 165 | def disjoint_merge_split(Tensor, merge_func, sign_to_mult): 166 | merge_func = merge_func.split("-")[-1] 167 | 168 | # If sign is provided then we select the corresponding entries and aggregate. 
169 | if sign_to_mult is not None: 170 | rows_to_keep = torch.where( 171 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0 172 | ) 173 | selected_entries = Tensor * rows_to_keep 174 | # Else we select all non-zero entries and aggregate. 175 | else: 176 | rows_to_keep = Tensor != 0 177 | selected_entries = Tensor * rows_to_keep 178 | 179 | if merge_func == "sum": 180 | disjoint_aggs = torch.sum(selected_entries, dim=0) 181 | else: 182 | raise ValueError(f"Merge method {merge_func} is not defined.") 183 | 184 | return selected_entries, disjoint_aggs 185 | 186 | 187 | def ties_merging_split( 188 | flat_task_checks, 189 | reset_thresh=None, 190 | merge_func="", 191 | ): 192 | all_checks = flat_task_checks.clone() 193 | updated_checks, *_ = topk_values_mask( 194 | all_checks, K=reset_thresh, return_mask=False 195 | ) 196 | print(f"RESOLVING SIGN") 197 | final_signs = resolve_sign(updated_checks) 198 | assert final_signs is not None 199 | 200 | print(f"Disjoint AGGREGATION: {merge_func}") 201 | selected_entries, merged_tv = disjoint_merge_split(updated_checks, merge_func, final_signs) 202 | 203 | return selected_entries, merged_tv 204 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import math 5 | import numpy as np 6 | import torchvision 7 | 8 | class NormalizeInverse(torchvision.transforms.Normalize): 9 | def __init__(self, mean, std): 10 | mean = torch.as_tensor(mean) 11 | std = torch.as_tensor(std) 12 | std_inv = 1 / (std + 1e-7) 13 | mean_inv = -mean * std_inv 14 | super().__init__(mean=mean_inv, std=std_inv) 15 | 16 | def __call__(self, tensor): 17 | return super().__call__(tensor.clone()) 18 | 19 | def corner_mask_generation(patch=None, image_size=(3, 224, 224)): 20 | applied_patch = np.zeros(image_size) 21 | x_location = image_size[1]-patch.shape[1] 22 | y_location = image_size[2]-patch.shape[2] 23 | applied_patch[:, x_location:x_location + patch.shape[1], y_location:y_location + patch.shape[2]] = patch 24 | mask = applied_patch.copy() 25 | mask[mask != 0] = 1.0 26 | return applied_patch, mask, x_location, y_location 27 | 28 | def assign_learning_rate(param_group, new_lr): 29 | param_group["lr"] = new_lr 30 | 31 | 32 | def _warmup_lr(base_lr, warmup_length, step): 33 | return base_lr * (step + 1) / warmup_length 34 | 35 | 36 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 37 | if not isinstance(base_lrs, list): 38 | base_lrs = [base_lrs for _ in optimizer.param_groups] 39 | assert len(base_lrs) == len(optimizer.param_groups) 40 | def _lr_adjuster(step): 41 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs): 42 | if step < warmup_length: 43 | lr = _warmup_lr(base_lr, warmup_length, step) 44 | else: 45 | e = step - warmup_length 46 | es = steps - warmup_length 47 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 48 | assign_learning_rate(param_group, lr) 49 | return _lr_adjuster 50 | 51 | 52 | def accuracy(output, target, topk=(1,)): 53 | pred = output.topk(max(topk), 1, True, True)[1].t() 54 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 55 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 56 | 57 | 58 | def torch_load_old(save_path, device=None): 59 | with open(save_path, 'rb') as f: 60 | classifier = pickle.load(f) 61 | if device is not None: 62 | classifier = classifier.to(device) 63 | return classifier 64 | 65 
| 66 | def torch_save(model, save_path): 67 | if os.path.dirname(save_path) != '': 68 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 69 | torch.save(model.cpu(), save_path) 70 | 71 | 72 | def torch_load(save_path, device=None): 73 | model = torch.load(save_path) 74 | if device is not None: 75 | model = model.to(device) 76 | return model 77 | 78 | 79 | 80 | def get_logits(inputs, classifier): 81 | assert callable(classifier) 82 | if hasattr(classifier, 'to'): 83 | classifier = classifier.to(inputs.device) 84 | return classifier(inputs) 85 | 86 | 87 | def get_probs(inputs, classifier): 88 | if hasattr(classifier, 'predict_proba'): 89 | probs = classifier.predict_proba(inputs.detach().cpu().numpy()) 90 | return torch.from_numpy(probs) 91 | logits = get_logits(inputs, classifier) 92 | return logits.softmax(dim=1) 93 | 94 | 95 | class LabelSmoothing(torch.nn.Module): 96 | def __init__(self, smoothing=0.0): 97 | super(LabelSmoothing, self).__init__() 98 | self.confidence = 1.0 - smoothing 99 | self.smoothing = smoothing 100 | 101 | def forward(self, x, target): 102 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 103 | 104 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 105 | nll_loss = nll_loss.squeeze(1) 106 | smooth_loss = -logprobs.mean(dim=-1) 107 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 108 | return loss.mean() 109 | --------------------------------------------------------------------------------
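A minimal usage sketch (not part of the repository) of how the checkpoints written by `finetune_clean.py` compose with `task_vectors.py`. The paths follow the `./checkpoints/<model>/...` layout used above; the dataset pair and `scaling_coef` value are illustrative assumptions.

```python
# Hypothetical usage sketch; paths and values are assumptions, not repo defaults.
from task_vectors import TaskVector

zeroshot = './checkpoints/RN50/zeroshot.pt'
tv_mnist = TaskVector(zeroshot, './checkpoints/RN50/MNIST/finetuned.pt')
tv_svhn = TaskVector(zeroshot, './checkpoints/RN50/SVHN/finetuned.pt')

# Task arithmetic: sum the per-task vectors, then add the result back onto
# the pretrained weights with a scaling coefficient.
merged_model = (tv_mnist + tv_svhn).apply_to(zeroshot, scaling_coef=0.5)
```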