├── README.md
├── data
│   └── data_links.txt
└── src
    ├── args.py
    ├── datasets
    │   ├── cars.py
    │   ├── cifar10.py
    │   ├── cifar100.py
    │   ├── common.py
    │   ├── dtd.py
    │   ├── eurosat.py
    │   ├── flowers.py
    │   ├── gtsrb.py
    │   ├── imagenet.py
    │   ├── imagenet100.py
    │   ├── mnist.py
    │   ├── pets.py
    │   ├── registry.py
    │   ├── resisc45.py
    │   ├── stl10.py
    │   ├── sun397.py
    │   ├── svhn.py
    │   └── templates.py
    ├── eval.py
    ├── figures
    │   ├── comparison.png
    │   ├── exp.png
    │   ├── figures.txt
    │   ├── main_table.png
    │   ├── neulig_overview.png
    │   └── neulig_train_pip.png
    ├── finetune_clean.py
    ├── heads.py
    ├── modeling.py
    ├── neulig_main.py
    ├── pgbar.py
    ├── task_vectors.py
    ├── tm_utils.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Level Collaboration in Model Merging
2 |
3 | Qi Li, Runpeng Yu, Xinchao Wang†
4 |
5 | xML-Lab, National University of Singapore
6 |
7 | †corresponding author
8 |
21 | ------------------
22 | TL;DR (1) - Achieve performance consistency between merging and ensembling in a unified framework.
23 |
24 | TL;DR (2) - Provide theoretical support for achieving this performance consistency.
25 |
26 |
27 | ## Graphical Abstract
28 |
29 |
30 |
31 |
32 | Figure 1. An illustration of Portland, which consists of a linear layer followed by a softmax function.
33 |
34 |
35 |
36 | Figure 2. The training process of Portland.
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |

46 |
47 |
48 | Figure 3. A toy experiment to verify theoretical feasibility. In this experiment, we merge two models fine-tuned on different datasets. Marker shapes represent different methods, while colors indicate different experimental groups, with each group using a distinct combination of datasets. In total, 10 groups of experiments are conducted (represented by 10 different colors). Hollow markers indicate each method's average result across these 10 groups.
49 |
50 |
51 |
52 |

53 |
54 |
55 | Table 1. The asterisk indicates that the condition is only partially satisfied. For Simple-Averaging, the theoretical discussion is limited to the relationship between the performance of merging two models and that of ensembling them. Furthermore, although both Simple-Averaging and Task-Arithmetic can be applied to CNN-based models, their performance is suboptimal. In the case of Diverse-Origin Models, all previous methods yield performance close to random guessing, whereas our conclusions remain applicable.
56 |
57 |
58 |
59 |

60 |
61 |
62 | Table 2. Results of various methods across multiple datasets, including the merging performance, the ensembling performance, and the performance gap for both CLIP-RN50 and CLIP-ViT-B/32.
63 |
64 | ## Installation & Preparation
65 |
66 | 1. Clone the repo and prepare the virtual environment:
67 |
68 | ```
69 | git clone https://github.com/LiQiiiii/Neural-Ligand.git
70 | cd Neural-Ligand
71 | conda create -n neulig python=3.8.10
72 | conda activate neulig
73 | ```
74 |
84 | The code has been tested with torch 2.0.0 and torchvision 0.15.1.
85 |
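The repo does not pin its remaining dependencies; the following is a minimal sketch inferred from the imports in `src/` (OpenCLIP caching in `args.py`, `scipy` for Stanford Cars, `tqdm` in `datasets/common.py`), not an official requirements list:

```
pip install torch==2.0.0 torchvision==0.15.1 open_clip_torch scipy tqdm
```
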
86 | 2. Prepare the datasets and models. Download links for the datasets used in the paper can be found in `./data/data_links.txt`; save the datasets in the `./data` folder. Then run:
87 |
88 | ```
89 | python ./src/finetune_clean.py
90 | ```
91 |
92 | to obtain the models used for training and evaluation.
93 |
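Assuming `finetune_clean.py` consumes the shared argument parser in `src/args.py` (the script itself is not shown here), the backbone and target dataset can be selected with the flags defined there, e.g.:

```
python ./src/finetune_clean.py --model ViT-B-32 --train-dataset MNIST --epochs 10 --lr 0.001
```

With the defaults in `src/args.py` (`--ckpt-dir`), checkpoints are written under `./checkpoints`.
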
94 |
95 |
96 | ## Training & Evaluation
97 |
98 | ```
99 | python ./src/neulig_main.py --num_co_models 2 --global_epoch 1000 --alignment_type sup --model RN50
100 | ```
101 |
102 | where `--num_co_models` sets the number of collaborating models, `--alignment_type` selects the alignment term (`sup` for supervised, `semi` for semi-supervised), and `--model` selects the backbone (`RN50`, `ViT-B-32`, or `ViT-L-14`).
103 |
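For example, a semi-supervised run with three collaborating ViT-B-32 models uses the same flags with different values; all remaining options fall back to the defaults in `src/args.py`:

```
python ./src/neulig_main.py --num_co_models 3 --global_epoch 1000 --alignment_type semi --model ViT-B-32
```
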
104 | ## Citation
105 |
106 | If you find our work interesting or helpful, please cite it as follows:
107 |
108 | ```
109 | @article{li2025multi,
110 | title={Multi-Level Collaboration in Model Merging},
111 | author={Li, Qi and Yu, Runpeng and Wang, Xinchao},
112 | journal={arXiv preprint arXiv:2503.01268},
113 | year={2025}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/data/data_links.txt:
--------------------------------------------------------------------------------
1 |
2 | CIFAR10, CIFAR100, MNIST, GTSRB, SVHN can be automatically downloaded via torchvision.
3 |
4 | # RESISC45
5 | https://huggingface.co/datasets/timm/resisc45
6 |
7 |
8 | # STL10
9 | https://ai.stanford.edu/~acoates/stl10
10 |
11 |
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | def parse_arguments():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument(
7 | "--data-location",
8 | type=str,
9 | default=os.path.expanduser('./data'),
10 | help="The root directory for the datasets.",
11 | )
12 | parser.add_argument(
13 | "--eval-datasets",
14 | default=None,
15 | type=lambda x: x.split(","),
16 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. "
17 | )
18 | parser.add_argument(
19 | "--train-dataset",
20 | default=None,
21 | type=lambda x: x.split(","),
22 | help="Which dataset(s) to patch on.",
23 | )
24 | parser.add_argument(
25 | "--exp_name",
26 | type=str,
27 | default=None,
28 | help="Name of the experiment, for organization purposes only."
29 | )
30 | parser.add_argument(
31 | "--results-db",
32 | type=str,
33 | default=None,
34 | help="Where to store the results, else does not store",
35 | )
36 | parser.add_argument(
37 | "--batch-size",
38 | type=int,
39 | default=128,
40 | )
41 | parser.add_argument(
42 | "--lr",
43 | type=float,
44 | default=0.001,
45 | help="Learning rate."
46 | )
47 | parser.add_argument(
48 | "--wd",
49 | type=float,
50 | default=0.1,
51 | help="Weight decay"
52 | )
53 | parser.add_argument(
54 | "--ls",
55 | type=float,
56 | default=0.0,
57 | help="Label smoothing."
58 | )
59 | parser.add_argument(
60 | "--warmup_length",
61 | type=int,
62 | default=500,
63 | )
64 | parser.add_argument(
65 | "--epochs",
66 | type=int,
67 | default=10,
68 | )
69 | parser.add_argument(
70 | "--load",
71 | type=lambda x: x.split(","),
72 | default=None,
73 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
74 | )
75 | parser.add_argument(
76 | "--save",
77 | type=str,
78 | default=None,
79 | help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
80 | )
81 | parser.add_argument(
82 | "--cache-dir",
83 | type=str,
84 | default=None,
85 | help="Directory for caching features and encoder",
86 | )
87 | parser.add_argument(
88 | "--openclip-cachedir",
89 | type=str,
90 | default='./open_clip',
91 | help='Directory for caching models from OpenCLIP'
92 | )
93 |
94 | parser.add_argument(
95 | "--ckpt-dir",
96 | type=str,
97 | default='./checkpoints',
98 | )
99 | parser.add_argument(
100 | "--logs-dir",
101 | type=str,
102 | default='./logs/',
103 | )
104 | parser.add_argument(
105 | "--suffix",
106 | type=str,
107 | default='Val',
108 | )
109 | parser.add_argument(
110 | "--ada_name",
111 | type=str,
112 | default='lambda.pt',
113 | )
114 | parser.add_argument(
115 | "--scaling-coef-",
116 | type=float,
117 | default=0.3,
118 | help="Label smoothing."
119 | )
120 | parser.add_argument(
121 | "--model",
122 | type=str,
123 | default='RN50',
124 | help="The type of model (e.g. RN50, ViT-B-32, ViT-L-14).",
125 | )
126 | parser.add_argument(
127 | "--num_co_models",
128 | type=int,
129 | default=2,
130 | help="number of collaborating models."
131 | )
132 | parser.add_argument(
133 | "--global_epoch",
134 | type=int,
135 | default=1000,
136 | help="number of global epochs."
137 | )
138 | parser.add_argument(
139 | "--scaling",
140 | type=float,
141 | default=100.0,
142 | help="Scaling parameter."
143 | )
144 |
145 | parser.add_argument(
146 | "--alignment_type",
147 | type=str,
148 | default='sup',
149 | help="sup for supervised and semi for semisupervised."
150 | )
151 |
152 | parsed_args = parser.parse_args()
153 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"
154 |
155 | if parsed_args.load is not None and len(parsed_args.load) == 1:
156 | parsed_args.load = parsed_args.load[0]
157 | return parsed_args
158 |
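159 | # Usage sketch (added note, not in the original file): entry points in this repo
160 | # are assumed to build their configuration via parse_arguments(), e.g.:
161 | #
162 | #     from args import parse_arguments
163 | #     args = parse_arguments()
164 | #     print(args.model, args.num_co_models, args.device)  # e.g. 'RN50', 2, 'cuda'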
--------------------------------------------------------------------------------
/src/datasets/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | import pathlib
5 | from typing import Callable, Optional, Any, Tuple
6 | from PIL import Image
7 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 | class PytorchStanfordCars(VisionDataset):
11 | """`Stanford Cars `_ Dataset
12 |
13 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is
14 | split into 8,144 training images and 8,041 testing images, where each class
15 | has been split roughly in a 50-50 split
16 |
17 | .. note::
18 |
19 | This class needs `scipy <https://scipy.org>`_ to load target files from `.mat` format.
20 |
21 | Args:
22 | root (string): Root directory of dataset
23 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
24 | transform (callable, optional): A function/transform that takes in an PIL image
25 | and returns a transformed version. E.g, ``transforms.RandomCrop``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If True, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again."""
31 |
32 | def __init__(
33 | self,
34 | root: str,
35 | split: str = "train",
36 | transform: Optional[Callable] = None,
37 | target_transform: Optional[Callable] = None,
38 | download: bool = False,
39 | ) -> None:
40 |
41 | try:
42 | import scipy.io as sio
43 | except ImportError:
44 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
45 |
46 | super().__init__(root, transform=transform, target_transform=target_transform)
47 |
48 | self._split = verify_str_arg(split, "split", ("train", "test"))
49 | self._base_folder = pathlib.Path('./data') / "stanford_cars"  # note: hard-coded; ignores the root argument
50 | devkit = self._base_folder / "devkit"
51 |
52 | if self._split == "train":
53 | self._annotations_mat_path = devkit / "cars_train_annos.mat"
54 | self._images_base_path = self._base_folder / "cars_train"
55 | else:
56 | self._annotations_mat_path = devkit / "cars_test_annos_withlabels.mat"
57 | self._images_base_path = self._base_folder / "cars_test"
58 |
59 | # if download:
60 | #     self.download()  # note: auto-download is disabled here; place the data under ./data manually
61 |
62 | if not self._check_exists():
63 | raise RuntimeError("Dataset not found. You can use download=True to download it")
64 |
65 | self._samples = [
66 | (
67 | str(self._images_base_path / annotation["fname"]),
68 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1
69 | )
70 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
71 | ]
72 |
73 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
74 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
75 |
76 | def __len__(self) -> int:
77 | return len(self._samples)
78 |
79 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
80 | """Returns pil_image and class_id for given index"""
81 | image_path, target = self._samples[idx]
82 | pil_image = Image.open(image_path).convert("RGB")
83 |
84 | if self.transform is not None:
85 | pil_image = self.transform(pil_image)
86 | if self.target_transform is not None:
87 | target = self.target_transform(target)
88 | return pil_image, target, idx
89 |
90 |
91 | def download(self) -> None:
92 | if self._check_exists():
93 | return
94 |
95 | download_and_extract_archive(
96 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
97 | download_root=str(self._base_folder),
98 | md5="c3b158d763b6e2245038c8ad08e45376",
99 | )
100 | if self._split == "train":
101 | download_and_extract_archive(
102 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
103 | download_root=str(self._base_folder),
104 | md5="065e5b463ae28d29e77c1b4b166cfe61",
105 | )
106 | else:
107 | download_and_extract_archive(
108 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
109 | download_root=str(self._base_folder),
110 | md5="4ce7ebf6a94d07f1952d94dd34c4d501",
111 | )
112 | download_url(
113 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
114 | root=str(self._base_folder),
115 | md5="b0a2b23655a3edd16d84508592a98d10",
116 | )
117 |
118 | def _check_exists(self) -> bool:
120 | if not (self._base_folder / "devkit").is_dir():
121 | return False
122 |
123 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
124 |
125 |
126 | class Cars:
127 | def __init__(self,
128 | preprocess,
129 | location=os.path.expanduser('./data'),
130 | batch_size=32,
131 | num_workers=16):
132 | # Data loading code
133 |
134 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=True)
135 | self.train_loader = torch.utils.data.DataLoader(
136 | self.train_dataset,
137 | shuffle=True,
138 | batch_size=batch_size,
139 | num_workers=num_workers,
140 | )
141 |
142 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=True)
143 | self.test_loader = torch.utils.data.DataLoader(
144 | self.test_dataset,
145 | batch_size=batch_size,
146 | num_workers=num_workers
147 | )
148 | self.test_loader_shuffle = torch.utils.data.DataLoader(
149 | self.test_dataset,
150 | shuffle=True,
151 | batch_size=batch_size,
152 | num_workers=num_workers
153 | )
154 | idx_to_class = dict((v, k)
155 | for k, v in self.train_dataset.class_to_idx.items())
156 | self.classnames = [idx_to_class[i].replace(
157 | '_', ' ') for i in range(len(idx_to_class))]
158 |
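159 | # Usage sketch (added note): wrap a CLIP-style preprocess transform and iterate;
160 | # note that __getitem__ returns a third element, the sample index.
161 | #
162 | #     dataset = Cars(preprocess, location='./data', batch_size=32)
163 | #     images, labels, indices = next(iter(dataset.train_loader))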
--------------------------------------------------------------------------------
/src/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import torch
4 | import numpy as np
5 | import torchvision
6 | from torchvision import transforms
7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
8 | from torchvision.datasets import VisionDataset
9 | from PIL import Image
10 |
11 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
12 |
13 | class MyPyTorchCIFAR10(PyTorchCIFAR10):
14 | def __init__(self, root, download, train, transform):
15 | super().__init__(root=root, download=download, train=train, transform=transform)
16 |
17 | def __getitem__(self, index: int):
18 | """
19 | Args:
20 | index (int): Index
21 |
22 | Returns:
23 | tuple: (image, target, index) where target is index of the target class.
24 | """
25 | img, target = self.data[index], self.targets[index]
26 |
27 | # doing this so that it is consistent with all other datasets
28 | # to return a PIL Image
29 | img = Image.fromarray(img)
30 |
31 | if self.transform is not None:
32 | img = self.transform(img)
33 |
34 | if self.target_transform is not None:
35 | target = self.target_transform(target)
36 |
37 | return img, target, index
38 |
39 | class CIFAR10:
40 | def __init__(self, preprocess,
41 | location=os.path.expanduser('./data'),
42 | batch_size=128,
43 | num_workers=16):
44 |
45 |
46 | self.train_dataset = MyPyTorchCIFAR10(
47 | root=location, download=True, train=True, transform=preprocess
48 | )
49 |
50 | self.train_loader = torch.utils.data.DataLoader(
51 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
52 | )
53 |
54 | self.test_dataset = MyPyTorchCIFAR10(
55 | root=location, download=True, train=False, transform=preprocess
56 | )
57 |
58 | self.test_loader = torch.utils.data.DataLoader(
59 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
60 | )
61 |
62 | self.test_loader_shuffle = torch.utils.data.DataLoader(
63 | self.test_dataset,
64 | shuffle=True,
65 | batch_size=batch_size,
66 | num_workers=num_workers
67 | )
68 |
69 | self.classnames = self.test_dataset.classes
70 |
71 | def convert(x):
72 | if isinstance(x, np.ndarray):
73 | return torchvision.transforms.functional.to_pil_image(x)
74 | return x
75 |
76 | class BasicVisionDataset(VisionDataset):
77 | def __init__(self, images, targets, transform=None, target_transform=None):
78 | if transform is not None:
79 | transform.transforms.insert(0, convert)
80 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
81 | assert len(images) == len(targets)
82 |
83 | self.images = images
84 | self.targets = targets
85 |
86 | def __getitem__(self, index):
87 | return self.transform(self.images[index]), self.targets[index]
88 |
89 | def __len__(self):
90 | return len(self.targets)
91 |
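92 | # Caveat (added note): BasicVisionDataset mutates the transform passed to it by
93 | # inserting `convert` at position 0 of transform.transforms, so reusing one
94 | # transforms.Compose object across several datasets accumulates converters.
95 | # Pass a fresh transform per dataset, e.g.:
96 | #
97 | #     ds = BasicVisionDataset(images, targets, transform=copy.deepcopy(preprocess))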
--------------------------------------------------------------------------------
/src/datasets/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
4 | from PIL import Image
5 |
6 | class MyPyTorchCIFAR100(PyTorchCIFAR100):
7 | def __init__(self, root, download, train, transform):
8 | super().__init__(root=root, download=download, train=train, transform=transform)
9 |
10 | def __getitem__(self, index: int):
11 | """
12 | Args:
13 | index (int): Index
14 |
15 | Returns:
16 | tuple: (image, target, index) where target is index of the target class.
17 | """
18 | img, target = self.data[index], self.targets[index]
19 |
20 | # doing this so that it is consistent with all other datasets
21 | # to return a PIL Image
22 | img = Image.fromarray(img)
23 |
24 | if self.transform is not None:
25 | img = self.transform(img)
26 |
27 | if self.target_transform is not None:
28 | target = self.target_transform(target)
29 |
30 | return img, target, index
31 |
32 | class CIFAR100:
33 | def __init__(self,
34 | preprocess,
35 | location=os.path.expanduser('./data'),
36 | batch_size=128,
37 | num_workers=16):
38 |
39 | self.train_dataset = MyPyTorchCIFAR100(
40 | root=location, download=True, train=True, transform=preprocess
41 | )
42 |
43 | self.train_loader = torch.utils.data.DataLoader(
44 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
45 | )
46 |
47 | self.test_dataset = MyPyTorchCIFAR100(
48 | root=location, download=True, train=False, transform=preprocess
49 | )
50 |
51 | self.test_loader = torch.utils.data.DataLoader(
52 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
53 | )
54 |
55 | self.test_loader_shuffle = torch.utils.data.DataLoader(
56 | self.test_dataset,
57 | shuffle=True,
58 | batch_size=batch_size,
59 | num_workers=num_workers
60 | )
61 |
62 | self.classnames = self.test_dataset.classes
63 |
64 |
65 |
--------------------------------------------------------------------------------
/src/datasets/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import glob
5 | import collections
6 | import random
7 | import numpy as np
8 | from tqdm import tqdm
9 | import torchvision.datasets as datasets
10 | from torch.utils.data import Dataset, DataLoader, Sampler
11 |
12 | class SubsetSampler(Sampler):
13 | def __init__(self, indices):
14 | self.indices = indices
15 |
16 | def __iter__(self):
17 | return (i for i in self.indices)
18 |
19 | def __len__(self):
20 | return len(self.indices)
21 |
22 | class ImageFolderWithPaths(datasets.ImageFolder):
23 | def __init__(self, path, transform, flip_label_prob=0.0):
24 | super().__init__(path, transform)
25 | self.flip_label_prob = flip_label_prob
26 | if self.flip_label_prob > 0:
27 | print(f'Flipping labels with probability {self.flip_label_prob}')
28 | num_classes = len(self.classes)
29 | for i in range(len(self.samples)):
30 | if random.random() < self.flip_label_prob:
31 | new_label = random.randint(0, num_classes-1)
32 | self.samples[i] = (
33 | self.samples[i][0],
34 | new_label
35 | )
36 |
37 | def __getitem__(self, index):
38 | image, label = super(ImageFolderWithPaths, self).__getitem__(index)
39 | return {
40 | 'images': image,
41 | 'labels': label,
42 | 'image_paths': self.samples[index][0]
43 | }
44 |
45 | def maybe_dictionarize(batch):
46 | if isinstance(batch, dict):
47 | return batch
48 |
49 | if len(batch) == 2:
50 | batch = {'images': batch[0], 'labels': batch[1]}
51 | elif len(batch) == 3:
52 | batch = {'images': batch[0], 'labels': batch[1], 'indices': batch[2]}
53 | elif len(batch) == 4:
54 | batch = {'images': batch[0], 'labels': batch[1], 'indices': batch[2], 'metadata': batch[3]}
55 | else:
56 | raise ValueError(f'Unexpected number of elements: {len(batch)}')
57 |
58 | return batch
59 |
60 | def get_features_helper(image_encoder, dataloader, device):
61 | all_data = collections.defaultdict(list)
62 |
63 | image_encoder = image_encoder.to(device)
64 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
65 | image_encoder.eval()
66 |
67 | with torch.no_grad():
68 | for batch in tqdm(dataloader):
69 | batch = maybe_dictionarize(batch)
70 | features = image_encoder(batch['images'].cuda())
71 |
72 | all_data['features'].append(features.cpu())
73 |
74 | for key, val in batch.items():
75 | if key == 'images':
76 | continue
77 | if hasattr(val, 'cpu'):
78 | val = val.cpu()
79 | all_data[key].append(val)
80 | else:
81 | all_data[key].extend(val)
82 |
83 | for key, val in all_data.items():
84 | if torch.is_tensor(val[0]):
85 | all_data[key] = torch.cat(val).numpy()
86 |
87 | return all_data
88 |
89 | def get_features(is_train, image_encoder, dataset, device):
90 | split = 'train' if is_train else 'val'
91 | dname = type(dataset).__name__
92 | if image_encoder.cache_dir is not None:
93 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
94 | cached_files = glob.glob(f'{cache_dir}/*')
95 | if image_encoder.cache_dir is not None and len(cached_files) > 0:
96 | print(f'Getting features from {cache_dir}')
97 | data = {}
98 | for cached_file in cached_files:
99 | name = os.path.splitext(os.path.basename(cached_file))[0]
100 | data[name] = torch.load(cached_file)
101 | else:
102 | print('Did not find cached features. Building from scratch.')
103 | loader = dataset.train_loader if is_train else dataset.test_loader
104 | data = get_features_helper(image_encoder, loader, device)
105 | if image_encoder.cache_dir is None:
106 | print('Not caching because no cache directory was passed.')
107 | else:
108 | os.makedirs(cache_dir, exist_ok=True)
109 | print(f'Caching data at {cache_dir}')
110 | for name, val in data.items():
111 | torch.save(val, f'{cache_dir}/{name}.pt')
112 | return data
113 |
114 | class FeatureDataset(Dataset):
115 | def __init__(self, is_train, image_encoder, dataset, device):
116 | self.data = get_features(is_train, image_encoder, dataset, device)
117 |
118 | def __len__(self):
119 | return len(self.data['features'])
120 |
121 | def __getitem__(self, idx):
122 | data = {k: v[idx] for k, v in self.data.items()}
123 | data['features'] = torch.from_numpy(data['features']).float()
124 | return data
125 |
126 | def get_dataloader(dataset, split):
127 | if split=='train':
128 | dataloader = dataset.train_loader
129 | elif split=='test':
130 | dataloader = dataset.test_loader
131 | elif split=='test_shuffled':
132 | dataloader = dataset.test_loader_shuffle
133 | elif split=='dev':
134 | dataloader = dataset.test_loader_shuffle
135 | elif split=='shadowtrain':
136 | dataloader = dataset.shadowtrain_loader
137 | else:
138 | raise ValueError(f'Unknown split: {split}')
139 | return dataloader
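140 |
141 | # Usage sketch (added note): fetch the shuffled test loader from a dataset
142 | # wrapper, then normalize the batch layout with maybe_dictionarize():
143 | #
144 | #     loader = get_dataloader(dataset, split='test_shuffled')
145 | #     batch = maybe_dictionarize(next(iter(loader)))  # keys: 'images', 'labels', 'indices'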
--------------------------------------------------------------------------------
/src/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class ImageFolderDataset(datasets.ImageFolder):
6 | def __init__(self, root, transform):
7 | super().__init__(root, transform)
8 |
9 | def __getitem__(self, index: int):
10 | path, target = self.samples[index]
11 | sample = self.loader(path)
12 | if self.transform is not None:
13 | sample = self.transform(sample)
14 | if self.target_transform is not None:
15 | target = self.target_transform(target)
16 | return sample, target, index
17 |
18 | class DTD:
19 | def __init__(self,
20 | preprocess,
21 | location=os.path.expanduser('./data'),
22 | batch_size=32,
23 | num_workers=16):
24 | # Data loading code
25 | location = './data'  # note: hard-coded; overrides the location argument
26 | traindir = os.path.join(location, 'dtd', 'train')
27 | valdir = os.path.join(location, 'dtd', 'test')
28 |
29 | self.train_dataset = ImageFolderDataset(
30 | traindir, transform=preprocess)
31 | self.train_loader = torch.utils.data.DataLoader(
32 | self.train_dataset,
33 | shuffle=True,
34 | batch_size=batch_size,
35 | num_workers=num_workers,
36 | )
37 |
38 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess)
39 | self.test_loader = torch.utils.data.DataLoader(
40 | self.test_dataset,
41 | batch_size=batch_size,
42 | num_workers=num_workers
43 | )
44 | self.test_loader_shuffle = torch.utils.data.DataLoader(
45 | self.test_dataset,
46 | shuffle=True,
47 | batch_size=batch_size,
48 | num_workers=num_workers
49 | )
50 | idx_to_class = dict((v, k)
51 | for k, v in self.train_dataset.class_to_idx.items())
52 | self.classnames = [idx_to_class[i].replace(
53 | '_', ' ') for i in range(len(idx_to_class))]
--------------------------------------------------------------------------------
/src/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | import re
5 | import numpy as np
6 |
7 | def pretify_classname(classname):
8 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
9 | l = [i.lower() for i in l]
10 | out = ' '.join(l)
11 | if out.endswith('al'):
12 | return out + ' area'
13 | return out
14 |
15 | class ImageFolderDataset(datasets.ImageFolder):
16 | def __init__(self, root, transform):
17 | super().__init__(root, transform)
18 | self.indices = np.arange(len(self.samples))
19 |
20 | def __getitem__(self, index: int):
21 | path, target = self.samples[index]
22 | sample = self.loader(path)
23 | if self.transform is not None:
24 | sample = self.transform(sample)
25 | if self.target_transform is not None:
26 | target = self.target_transform(target)
27 | return sample, target, index
28 |
29 | class EuroSATBase:
30 | def __init__(self,
31 | preprocess,
32 | test_split,
33 | location='./data',
34 | batch_size=32,
35 | num_workers=16):
36 | # Data loading code
37 | location = './data'  # note: hard-coded; overrides the location argument
38 | traindir = os.path.join(location, 'EuroSAT_splits', 'train')
39 | testdir = os.path.join(location, 'EuroSAT_splits', test_split)
40 |
41 |
42 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess)
43 | self.train_loader = torch.utils.data.DataLoader(
44 | self.train_dataset,
45 | shuffle=True,
46 | batch_size=batch_size,
47 | num_workers=num_workers,
48 | )
49 |
50 | self.test_dataset = ImageFolderDataset(testdir, transform=preprocess)
51 | self.test_loader = torch.utils.data.DataLoader(
52 | self.test_dataset,
53 | batch_size=batch_size,
54 | num_workers=num_workers
55 | )
56 | self.test_loader_shuffle = torch.utils.data.DataLoader(
57 | self.test_dataset,
58 | shuffle=True,
59 | batch_size=batch_size,
60 | num_workers=num_workers
61 | )
62 | idx_to_class = dict((v, k)
63 | for k, v in self.train_dataset.class_to_idx.items())
64 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
65 | self.classnames = [pretify_classname(c) for c in self.classnames]
66 | ours_to_open_ai = {
67 | 'annual crop': 'annual crop land',
68 | 'forest': 'forest',
69 | 'herbaceous vegetation': 'brushland or shrubland',
70 | 'highway': 'highway or road',
71 | 'industrial area': 'industrial buildings or commercial buildings',
72 | 'pasture': 'pasture land',
73 | 'permanent crop': 'permanent crop land',
74 | 'residential area': 'residential buildings or homes or apartments',
75 | 'river': 'river',
76 | 'sea lake': 'lake or sea',
77 | }
78 | for i in range(len(self.classnames)):
79 | self.classnames[i] = ours_to_open_ai[self.classnames[i]]
80 |
81 |
82 | class EuroSAT(EuroSATBase):
83 | def __init__(self,
84 | preprocess,
85 | location='~/datasets',
86 | batch_size=32,
87 | num_workers=16):
88 | super().__init__(preprocess, 'test', location, batch_size, num_workers)
89 |
90 |
91 | class EuroSATVal(EuroSATBase):
92 | def __init__(self,
93 | preprocess,
94 | location='~/datasets',
95 | batch_size=32,
96 | num_workers=16):
97 | super().__init__(preprocess, 'val', location, batch_size, num_workers)
--------------------------------------------------------------------------------
/src/datasets/flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | import re
5 | import numpy as np
6 |
7 | class ImageFolderDataset(datasets.ImageFolder):
8 | def __init__(self, root, transform):
9 | super().__init__(root, transform)
10 | self.indices = np.arange(len(self.samples))
11 |
12 | def __getitem__(self, index: int):
13 | path, target = self.samples[index]
14 | sample = self.loader(path)
15 | if self.transform is not None:
16 | sample = self.transform(sample)
17 | if self.target_transform is not None:
18 | target = self.target_transform(target)
19 | return sample, target, index
20 |
21 | class FlowersBase:
22 | def __init__(self,
23 | preprocess,
24 | test_split,
25 | location='./data',
26 | batch_size=32,
27 | num_workers=16):
28 | # Data loading code
29 | location = './data'  # note: hard-coded; overrides the location argument
30 | traindir = os.path.join(location, 'flowers', 'train')
31 | testdir = os.path.join(location, 'flowers', test_split)
32 |
33 |
34 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess)
35 | self.train_loader = torch.utils.data.DataLoader(
36 | self.train_dataset,
37 | shuffle=True,
38 | batch_size=batch_size,
39 | num_workers=num_workers,
40 | )
41 |
42 | self.test_dataset = ImageFolderDataset(testdir, transform=preprocess)
43 | self.test_loader = torch.utils.data.DataLoader(
44 | self.test_dataset,
45 | batch_size=batch_size,
46 | num_workers=num_workers
47 | )
48 | self.test_loader_shuffle = torch.utils.data.DataLoader(
49 | self.test_dataset,
50 | shuffle=True,
51 | batch_size=batch_size,
52 | num_workers=num_workers
53 | )
54 |
55 | idx_to_class = dict((v, k)
56 | for k, v in self.train_dataset.class_to_idx.items())
57 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
58 |
59 |
60 | class Flowers(FlowersBase):
61 | def __init__(self,
62 | preprocess,
63 | location='~/datasets',
64 | batch_size=32,
65 | num_workers=16):
66 | super().__init__(preprocess, 'test', location, batch_size, num_workers)
67 |
68 |
69 | class FlowersVal(FlowersBase):
70 | def __init__(self,
71 | preprocess,
72 | location='~/datasets',
73 | batch_size=32,
74 | num_workers=16):
75 | super().__init__(preprocess, 'val', location, batch_size, num_workers)
76 |
--------------------------------------------------------------------------------
/src/datasets/gtsrb.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import pathlib
4 | from typing import Any, Callable, Dict, List, Optional, Tuple
5 |
6 | import numpy as np
7 | import PIL
8 | import torch
9 | from torchvision.datasets.folder import make_dataset
10 | from torchvision.datasets.utils import (download_and_extract_archive,
11 | verify_str_arg)
12 | from torchvision.datasets.vision import VisionDataset
13 |
14 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
15 | """Finds the class folders in a dataset.
16 |
17 | See :class:`DatasetFolder` for details.
18 | """
19 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
20 | if not classes:
21 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
22 |
23 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
24 | return classes, class_to_idx
25 |
26 | class PyTorchGTSRB(VisionDataset):
27 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset.
28 |
29 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.
30 |
31 | Args:
32 | root (string): Root directory of the dataset.
33 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
34 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
35 | version. E.g, ``transforms.RandomCrop``.
36 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
37 | download (bool, optional): If True, downloads the dataset from the internet and
38 | puts it in root directory. If dataset is already downloaded, it is not
39 | downloaded again.
40 | """
41 |
42 | def __init__(
43 | self,
44 | root: str,
45 | split: str = "train",
46 | transform: Optional[Callable] = None,
47 | target_transform: Optional[Callable] = None,
48 | download: bool = False,
49 | ) -> None:
50 |
51 | super().__init__(root, transform=transform, target_transform=target_transform)
52 |
53 | self._split = verify_str_arg(split, "split", ("train", "test"))
54 | self._base_folder = pathlib.Path("./data") / "gtsrb"
55 | self._target_folder = (
56 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
57 | )
58 |
59 | if download:
60 | self.download()
61 |
62 | if not self._check_exists():
63 | raise RuntimeError("Dataset not found. You can use download=True to download it")
64 |
65 | if self._split == "train":
66 | _, class_to_idx = find_classes(str(self._target_folder))
67 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
68 | else:
69 | with open(self._base_folder / "GT-final_test.csv") as csv_file:
70 | samples = [
71 | (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
72 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
73 | ]
74 |
75 | self._samples = samples
76 | self.transform = transform
77 | self.target_transform = target_transform
78 |
79 | def __len__(self) -> int:
80 | return len(self._samples)
81 |
82 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
83 |
84 | path, target = self._samples[index]
85 | sample = PIL.Image.open(path).convert("RGB")
86 |
87 | if self.transform is not None:
88 | sample = self.transform(sample)
89 |
90 | if self.target_transform is not None:
91 | target = self.target_transform(target)
92 |
93 | return sample, target, index
94 |
95 |
96 | def _check_exists(self) -> bool:
97 | return self._target_folder.is_dir()
98 |
99 | def download(self) -> None:
100 | if self._check_exists():
101 | return
102 |
103 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
104 |
105 | if self._split == "train":
106 | download_and_extract_archive(
107 | f"{base_url}GTSRB-Training_fixed.zip",
108 | download_root=str(self._base_folder),
109 | md5="513f3c79a4c5141765e10e952eaa2478",
110 | )
111 | else:
112 | download_and_extract_archive(
113 | f"{base_url}GTSRB_Final_Test_Images.zip",
114 | download_root=str(self._base_folder),
115 | md5="c7e4e6327067d32654124b0fe9e82185",
116 | )
117 | download_and_extract_archive(
118 | f"{base_url}GTSRB_Final_Test_GT.zip",
119 | download_root=str(self._base_folder),
120 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
121 | )
122 |
123 |
124 | class GTSRB:
125 | def __init__(self,
126 | preprocess,
127 | location=os.path.expanduser('./data'),
128 | batch_size=128,
129 | num_workers=16):
130 |
131 | # to fit with repo conventions for location
132 | self.train_dataset = PyTorchGTSRB(
133 | root=location,
134 | download=True,
135 | split='train',
136 | transform=preprocess
137 | )
138 |
139 | self.train_loader = torch.utils.data.DataLoader(
140 | self.train_dataset,
141 | batch_size=batch_size,
142 | shuffle=True,
143 | num_workers=num_workers
144 | )
145 |
146 | self.test_dataset = PyTorchGTSRB(
147 | root=location,
148 | download=True,
149 | split='test',
150 | transform=preprocess
151 | )
152 |
153 | self.test_loader = torch.utils.data.DataLoader(
154 | self.test_dataset,
155 | batch_size=batch_size,
156 | shuffle=False,
157 | num_workers=num_workers
158 | )
159 |
160 | self.test_loader_shuffle = torch.utils.data.DataLoader(
161 | self.test_dataset,
162 | shuffle=True,
163 | batch_size=batch_size,
164 | num_workers=num_workers
165 | )
166 |
167 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
168 | self.classnames = [
169 | 'red and white circle 20 kph speed limit',
170 | 'red and white circle 30 kph speed limit',
171 | 'red and white circle 50 kph speed limit',
172 | 'red and white circle 60 kph speed limit',
173 | 'red and white circle 70 kph speed limit',
174 | 'red and white circle 80 kph speed limit',
175 | 'end / de-restriction of 80 kph speed limit',
176 | 'red and white circle 100 kph speed limit',
177 | 'red and white circle 120 kph speed limit',
178 | 'red and white circle red car and black car no passing',
179 | 'red and white circle red truck and black car no passing',
180 | 'red and white triangle road intersection warning',
181 | 'white and yellow diamond priority road',
182 | 'red and white upside down triangle yield right-of-way',
183 | 'stop',
184 | 'empty red and white circle',
185 | 'red and white circle no truck entry',
186 | 'red circle with white horizonal stripe no entry',
187 | 'red and white triangle with exclamation mark warning',
188 | 'red and white triangle with black left curve approaching warning',
189 | 'red and white triangle with black right curve approaching warning',
190 | 'red and white triangle with black double curve approaching warning',
191 | 'red and white triangle rough / bumpy road warning',
192 | 'red and white triangle car skidding / slipping warning',
193 | 'red and white triangle with merging / narrow lanes warning',
194 | 'red and white triangle with person digging / construction / road work warning',
195 | 'red and white triangle with traffic light approaching warning',
196 | 'red and white triangle with person walking warning',
197 | 'red and white triangle with child and person walking warning',
198 | 'red and white triangle with bicyle warning',
199 | 'red and white triangle with snowflake / ice warning',
200 | 'red and white triangle with deer warning',
201 | 'white circle with gray strike bar no speed limit',
202 | 'blue circle with white right turn arrow mandatory',
203 | 'blue circle with white left turn arrow mandatory',
204 | 'blue circle with white forward arrow mandatory',
205 | 'blue circle with white forward or right turn arrow mandatory',
206 | 'blue circle with white forward or left turn arrow mandatory',
207 | 'blue circle with white keep right arrow mandatory',
208 | 'blue circle with white keep left arrow mandatory',
209 | 'blue circle with white arrows indicating a traffic circle',
210 | 'white circle with gray strike bar indicating no passing for cars has ended',
211 | 'white circle with gray strike bar indicating no passing for trucks has ended',
212 | ]
213 |
--------------------------------------------------------------------------------
/src/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from .common import ImageFolderWithPaths, SubsetSampler
5 | import numpy as np
6 |
7 | def get_imagenet_classnames():
8 | imagenet_classnames = [
9 | "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
10 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
11 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
12 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
13 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
14 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
15 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
16 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
17 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
18 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
19 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
20 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
21 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
22 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
23 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
24 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
25 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
26 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
27 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
28 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
29 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
30 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
31 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
32 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
33 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
34 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
35 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
36 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
37 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
38 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
39 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
40 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
41 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
42 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
43 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
44 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
45 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
46 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
47 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
48 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
49 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
50 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
51 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
52 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
53 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
54 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
55 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
56 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
57 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
58 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
59 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
60 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
61 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
62 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
63 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
64 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
65 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
66 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
67 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
68 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
69 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
70 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
71 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
72 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
73 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
74 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
75 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
76 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
77 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
78 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
79 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
80 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
81 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
82 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
83 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
84 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
85 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
86 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
87 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
88 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
89 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
90 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
91 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
92 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
93 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
94 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
95 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
96 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
97 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
98 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
99 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
100 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
101 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
102 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
103 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
104 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
105 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
106 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
107 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
108 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
109 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
110 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
111 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
112 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
113 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
114 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
115 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
116 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
117 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
118 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
119 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
120 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
121 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
122 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
123 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
124 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
125 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
126 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
127 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
128 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
129 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
130 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
131 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
132 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
133 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
134 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
135 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
136 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
137 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
138 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
139 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
140 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
141 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
142 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
143 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
144 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
145 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
146 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
147 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
148 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
149 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
150 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
151 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
152 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
153 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
154 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
155 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
156 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
157 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
158 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
159 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
160 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
161 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
162 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
163 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
164 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
165 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
166 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
167 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
168 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
169 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
170 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
171 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
172 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
173 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
174 | ]
175 | return imagenet_classnames
176 |
177 | class ImageNet:
178 | def __init__(self,
179 | preprocess,
180 | location=os.path.expanduser('./data'),
181 | batch_size=32,
182 | num_workers=32):
183 | self.preprocess = preprocess
184 | self.location = location
185 | self.batch_size = batch_size
186 | self.num_workers = num_workers
187 | self.classnames = get_imagenet_classnames()
188 |
189 | self.populate_train()
190 | self.populate_test()
191 |
192 | def populate_train(self):
193 | traindir = os.path.join(self.location, self.name(), 'train')
194 | self.train_dataset = ImageFolderWithPaths(
195 | traindir,
196 | transform=self.preprocess)
197 | sampler = self.get_train_sampler()
198 | kwargs = {'shuffle' : True} if sampler is None else {}
199 | self.train_loader = torch.utils.data.DataLoader(
200 | self.train_dataset,
201 | sampler=sampler,
202 | batch_size=self.batch_size,
203 | num_workers=self.num_workers,
204 | **kwargs,
205 | )
206 |
207 | def populate_test(self):
208 | self.test_dataset = self.get_test_dataset()
209 | self.test_loader = torch.utils.data.DataLoader(
210 | self.test_dataset,
211 | batch_size=self.batch_size,
212 | num_workers=self.num_workers,
213 | sampler=self.get_test_sampler()
214 | )
215 | self.test_loader_shuffle = torch.utils.data.DataLoader(
216 | self.test_dataset,
217 | shuffle=True,
218 | batch_size=self.batch_size,
219 | num_workers=self.num_workers,
220 | sampler=self.get_test_sampler()
221 | )
222 |
223 | def get_test_path(self):
224 | test_path = os.path.join(self.location, self.name(), 'val_in_folder')
225 | if not os.path.exists(test_path):
226 | test_path = os.path.join(self.location, self.name(), 'val')
227 | return test_path
228 |
229 | def get_train_sampler(self):
230 | return None
231 |
232 | def get_test_sampler(self):
233 | return None
234 |
235 | def get_test_dataset(self):
236 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
237 |
238 | def name(self):
239 | return 'imagenet'
240 |
241 | class ImageNetTrain(ImageNet):
242 |
243 |     def get_test_dataset(self):
244 |         pass  # train-only variant: no test split defined
245 |
246 | class ImageNetK(ImageNet):
247 |
248 | def get_train_sampler(self):
249 | idxs = np.zeros(len(self.train_dataset.targets))
250 | target_array = np.array(self.train_dataset.targets)
251 | for c in range(1000):
252 | m = target_array == c
253 | n = len(idxs[m])
254 | arr = np.zeros(n)
255 | arr[:self.k()] = 1
256 | np.random.shuffle(arr)
257 | idxs[m] = arr
258 |
259 | idxs = idxs.astype('int')
260 | sampler = SubsetSampler(np.where(idxs)[0])
261 | return sampler
--------------------------------------------------------------------------------
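Note that `ImageNetK.get_train_sampler` calls `self.k()`, which `ImageNetK` itself never defines; concrete per-class-budget subclasses are expected to supply it. A minimal sketch of such subclasses (the `gen_imagenet_k` factory below is illustrative, not part of the repo):

```python
# Illustrative only: concrete k-shot variants of the ImageNetK class above.
# k() returns how many training images to keep for each of the 1000 classes.
def gen_imagenet_k(k):
    class _ImageNetK(ImageNetK):
        def k(self):
            return k  # closes over the integer passed to the factory
    _ImageNetK.__name__ = f'ImageNet{k}'
    return _ImageNetK

ImageNet1 = gen_imagenet_k(1)     # one training image per class
ImageNet32 = gen_imagenet_k(32)   # 32 training images per class
```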
/src/datasets/imagenet100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class ImageFolderDataset(datasets.ImageFolder):
6 | def __init__(self, root, transform):
7 | super().__init__(root, transform)
8 |
9 | def __getitem__(self, index: int):
10 | path, target = self.samples[index]
11 | sample = self.loader(path)
12 | if self.transform is not None:
13 | sample = self.transform(sample)
14 | if self.target_transform is not None:
15 | target = self.target_transform(target)
16 | return sample, target, index
17 |
18 | class ImageNet100:
19 | def __init__(self,
20 | preprocess,
21 | location=os.path.expanduser('./data'),
22 | batch_size=32,
23 | num_workers=16):
24 | # Data loading code
25 |         location = os.path.expanduser(location)  # honor the caller-supplied data root
26 | traindir = os.path.join(location, 'ImageNet100', 'train')
27 | valdir = os.path.join(location, 'ImageNet100', 'val')
28 |
29 | self.train_dataset = ImageFolderDataset(
30 | traindir, transform=preprocess)
31 | self.train_loader = torch.utils.data.DataLoader(
32 | self.train_dataset,
33 | shuffle=True,
34 | batch_size=batch_size,
35 | num_workers=num_workers,
36 | )
37 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess)
38 | self.test_loader = torch.utils.data.DataLoader(
39 | self.test_dataset,
40 | batch_size=batch_size,
41 | num_workers=num_workers
42 | )
43 | self.test_loader_shuffle = torch.utils.data.DataLoader(
44 | self.test_dataset,
45 | shuffle=True,
46 | batch_size=batch_size,
47 | num_workers=num_workers
48 | )
49 | idx_to_class = dict((v, k)
50 | for k, v in self.train_dataset.class_to_idx.items())
51 | self.classnames = [idx_to_class[i].replace(
52 | '_', ' ') for i in range(len(idx_to_class))]
--------------------------------------------------------------------------------
/src/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | from PIL import Image
5 |
6 | class MyMNIST(datasets.MNIST):
7 | def __init__(self, root, download, train, transform):
8 | super().__init__(root=root, download=download, train=train, transform=transform)
9 |
10 | def __getitem__(self, index: int):
11 | """
12 | Args:
13 | index (int): Index
14 |
15 | Returns:
16 |             tuple: (image, target, index) where target is the index of the target class.
17 | """
18 | img, target = self.data[index], int(self.targets[index])
19 |
20 | # doing this so that it is consistent with all other datasets
21 | # to return a PIL Image
22 | img = Image.fromarray(img.numpy(), mode="L")
23 |
24 | if self.transform is not None:
25 | img = self.transform(img)
26 |
27 | if self.target_transform is not None:
28 | target = self.target_transform(target)
29 |
30 | return img, target, index
31 |
32 | class MNIST:
33 | def __init__(self,
34 | preprocess,
35 | location=os.path.expanduser('./data'),
36 | batch_size=128,
37 | num_workers=16):
38 |
39 |
40 | self.train_dataset = MyMNIST(
41 | root=location,
42 | download=True,
43 | train=True,
44 | transform=preprocess
45 | )
46 |
47 | self.train_loader = torch.utils.data.DataLoader(
48 | self.train_dataset,
49 | batch_size=batch_size,
50 | shuffle=True,
51 | num_workers=num_workers
52 | )
53 |
54 | self.test_dataset = MyMNIST(
55 | root=location,
56 | download=True,
57 | train=False,
58 | transform=preprocess
59 | )
60 |
61 | self.test_loader = torch.utils.data.DataLoader(
62 | self.test_dataset,
63 | batch_size=batch_size,
64 | shuffle=False,
65 | num_workers=num_workers
66 | )
67 |
68 | self.test_loader_shuffle = torch.utils.data.DataLoader(
69 | self.test_dataset,
70 | shuffle=True,
71 | batch_size=batch_size,
72 | num_workers=num_workers
73 | )
74 |
75 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
--------------------------------------------------------------------------------
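Like the other wrapped datasets in this folder, `MyMNIST.__getitem__` returns the sample index as a third element alongside the usual `(image, target)` pair, so downstream code can identify individual samples. A minimal sketch of consuming a loader accordingly (the `ToTensor` transform is a stand-in for the CLIP preprocess):

```python
import torchvision.transforms as T
from src.datasets.mnist import MNIST

preprocess = T.ToTensor()  # stand-in for the CLIP preprocessing transform
dataset = MNIST(preprocess, location='./data', batch_size=128)

# each batch is (images, targets, indices)
images, targets, indices = next(iter(dataset.train_loader))
print(images.shape, targets.shape, indices.shape)
```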
/src/datasets/pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class ImageFolderDataset(datasets.ImageFolder):
6 | def __init__(self, root, transform):
7 | super().__init__(root, transform)
8 |
9 | def __getitem__(self, index: int):
10 | path, target = self.samples[index]
11 | sample = self.loader(path)
12 | if self.transform is not None:
13 | sample = self.transform(sample)
14 | if self.target_transform is not None:
15 | target = self.target_transform(target)
16 | return sample, target, index
17 |
18 | class PETS:
19 | def __init__(self,
20 | preprocess,
21 | location=os.path.expanduser('./data'),
22 | batch_size=32,
23 | num_workers=16):
24 | # Data loading code
25 |         location = os.path.expanduser(location)  # honor the caller-supplied data root
26 | traindir = os.path.join(location, 'pets', 'train')
27 | valdir = os.path.join(location, 'pets', 'test')
28 |
29 | self.train_dataset = ImageFolderDataset(
30 | traindir, transform=preprocess)
31 | self.train_loader = torch.utils.data.DataLoader(
32 | self.train_dataset,
33 | shuffle=True,
34 | batch_size=batch_size,
35 | num_workers=num_workers,
36 | )
37 |
38 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess)
39 | self.test_loader = torch.utils.data.DataLoader(
40 | self.test_dataset,
41 | batch_size=batch_size,
42 | num_workers=num_workers
43 | )
44 | self.test_loader_shuffle = torch.utils.data.DataLoader(
45 | self.test_dataset,
46 | shuffle=True,
47 | batch_size=batch_size,
48 | num_workers=num_workers
49 | )
50 | idx_to_class = dict((v, k)
51 | for k, v in self.train_dataset.class_to_idx.items())
52 | self.classnames = [idx_to_class[i].replace(
53 | '_', ' ') for i in range(len(idx_to_class))]
--------------------------------------------------------------------------------
/src/datasets/registry.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 | import random
4 | import torch
5 | import copy
6 | from torch.utils.data.dataset import random_split
7 | from src.datasets.cars import Cars
8 | from src.datasets.cifar10 import CIFAR10
9 | from src.datasets.cifar100 import CIFAR100
10 | from src.datasets.dtd import DTD
11 | from src.datasets.eurosat import EuroSAT, EuroSATVal
12 | from src.datasets.gtsrb import GTSRB
13 | from src.datasets.imagenet import ImageNet
14 | from src.datasets.mnist import MNIST
15 | from src.datasets.resisc45 import RESISC45
16 | from src.datasets.stl10 import STL10
17 | from src.datasets.svhn import SVHN
18 | from src.datasets.sun397 import SUN397
19 | from src.datasets.pets import PETS
20 | from src.datasets.flowers import Flowers, FlowersVal
21 | from src.datasets.imagenet100 import ImageNet100
22 | from src.datasets.common import get_dataloader, maybe_dictionarize
23 | registry = {
24 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
25 | }
26 |
27 | class GenericDataset(object):
28 | def __init__(self):
29 | self.train_dataset = None
30 | self.train_loader = None
31 | self.test_dataset = None
32 | self.test_loader = None
33 | self.classnames = None
34 |
35 |
36 | def split_train_into_train_dev_cifar_mnist(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, test_length, max_val_samples=None, seed=0):
37 | assert val_fraction > 0. and val_fraction < 1.
38 | total_size = len(dataset.train_dataset)
39 | val_size = test_length # shadow train = shadow test
40 | if max_val_samples is not None:
41 | val_size = min(val_size, max_val_samples)
42 | train_size = total_size - val_size
43 | target_train_size = int(train_size/2)
44 | target_test_size = train_size - target_train_size
45 | assert val_size > 0
46 | assert train_size > 0
47 | lengths = [target_train_size, target_test_size, val_size]
48 | print(lengths)
49 | trainset, valset, shadowset = random_split(
50 | dataset.train_dataset,
51 | lengths,
52 | generator=torch.Generator().manual_seed(seed) # same split
53 | )
54 |
55 |     # dynamically build a lightweight dataset class to hold the new splits
56 |     new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
57 | new_dataset = new_dataset_class()
58 | new_dataset.train_dataset = trainset
59 | new_dataset.train_loader = torch.utils.data.DataLoader(
60 | new_dataset.train_dataset,
61 | shuffle=True,
62 | batch_size=batch_size,
63 | num_workers=num_workers,
64 | )
65 | new_dataset.test_dataset = valset
66 | new_dataset.test_loader = torch.utils.data.DataLoader(
67 | new_dataset.test_dataset,
68 | batch_size=batch_size,
69 | num_workers=num_workers
70 | )
71 | new_dataset.test_loader_shuffle = torch.utils.data.DataLoader(
72 | new_dataset.test_dataset,
73 | batch_size=batch_size,
74 | num_workers=num_workers,
75 | shuffle=True
76 | )
77 |
78 | new_dataset.shadowtrain_dataset = shadowset
79 | new_dataset.shadowtrain_loader = torch.utils.data.DataLoader(
80 | new_dataset.shadowtrain_dataset,
81 | shuffle=True,
82 | batch_size=batch_size,
83 | num_workers=num_workers,
84 | )
85 |
86 | new_dataset.classnames = copy.copy(dataset.classnames)
87 | return new_dataset
88 |
89 | def get_dataset_classnames(dataset_name, preprocess, location, batch_size=128, num_workers=16):
90 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
91 | dataset_class = registry[dataset_name]
92 | dataset = dataset_class(
93 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
94 | )
95 | return dataset.classnames
96 |
97 |
98 | def get_dataset_cifar_mnist(dataset_name, split, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.4, max_val_samples=500000):
99 | print(location)
100 | # if dataset_name == 'MNIST':
101 | # val_fraction = 0.5
102 | if split=='train':
103 | if dataset_name=='EuroSAT':
104 | dataset_class = registry[dataset_name+"Val"]
105 |             base_dataset = dataset_class(
106 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
107 | )
108 | else:
109 | dataset_class = registry[dataset_name]
110 | base_dataset = dataset_class(
111 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
112 | )
113 | # print("base_dataset: ", len(base_dataset.test_dataset))
114 | if dataset_name == 'PETS':
115 | len_val = 1400
116 | elif dataset_name == 'STL10':
117 | len_val = 1600
118 | else:
119 | len_val = len(base_dataset.test_dataset)
120 | dataset = split_train_into_train_dev_cifar_mnist(
121 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples)
122 | return dataset.train_dataset, get_dataloader(dataset, split=split)
123 |
124 | elif split=='test':
125 | if dataset_name=='EuroSAT':
126 | dataset_class = registry[dataset_name+"Val"]
127 |             base_dataset = dataset_class(
128 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
129 | )
130 | else:
131 | dataset_class = registry[dataset_name]
132 | base_dataset = dataset_class(
133 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
134 | )
135 | if dataset_name == 'PETS':
136 | len_val = 1400
137 | elif dataset_name == 'STL10':
138 | len_val = 1600
139 | else:
140 | len_val = len(base_dataset.test_dataset)
141 | dataset = split_train_into_train_dev_cifar_mnist(
142 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples)
143 | return dataset.test_dataset, get_dataloader(dataset, split=split)
144 |
145 | elif split=='shadowtrain':
146 | if dataset_name=='EuroSAT':
147 | dataset_class = registry[dataset_name+"Val"]
148 |             base_dataset = dataset_class(
149 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
150 | )
151 | else:
152 | dataset_class = registry[dataset_name]
153 | base_dataset = dataset_class(
154 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
155 | )
156 | if dataset_name == 'PETS':
157 | len_val = 1400
158 | elif dataset_name == 'STL10':
159 | len_val = 1600
160 | else:
161 | len_val = len(base_dataset.test_dataset)
162 | dataset = split_train_into_train_dev_cifar_mnist(
163 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, len_val, max_val_samples)
164 | return dataset.shadowtrain_dataset, get_dataloader(dataset, split=split)
165 |
166 | elif split=='shadowtest' or split=='shadowtest_shuffled':
167 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
168 | dataset_class = registry[dataset_name]
169 | base_dataset = dataset_class(
170 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
171 | )
172 | test_loader = torch.utils.data.DataLoader(
173 | base_dataset.test_dataset,
174 | batch_size=batch_size,
175 | num_workers=num_workers,
176 | shuffle=True
177 | )
178 | return base_dataset.test_dataset, test_loader
179 |
180 |
181 |
182 |     else:
183 |         raise NotImplementedError(f'Unknown split: {split}')
--------------------------------------------------------------------------------
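To make the split logic above concrete: `split_train_into_train_dev_cifar_mnist` first reserves a shadow split sized like the real test set, then halves the remainder into a target-train and a target-test (dev) split. A minimal usage sketch with illustrative CIFAR10 numbers (`preprocess` stands in for the CLIP transform):

```python
import torchvision.transforms as T
from src.datasets.registry import get_dataset_cifar_mnist

preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])

# For CIFAR10 (50,000 train images, test_length = 10,000):
#   shadow split = 10,000 (sized like the test set),
#   remaining 40,000 halved into target train/dev of 20,000 each.
train_set, train_loader = get_dataset_cifar_mnist(
    'CIFAR10', 'train', preprocess, location='./data', batch_size=128)
shadow_set, shadow_loader = get_dataset_cifar_mnist(
    'CIFAR10', 'shadowtrain', preprocess, location='./data', batch_size=128)
```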
/src/datasets/resisc45.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import abc
5 | from typing import Any, Callable, Dict, Optional, Tuple
6 | 
7 | import numpy as np
8 | from torch import Tensor
9 | from torch.utils.data import Dataset
10 | from torchvision.datasets import ImageFolder
11 | from torchvision.datasets.folder import default_loader as pil_loader
12 | 
13 | 
14 |
15 |
16 | # modified from: https://github.com/microsoft/torchgeo
17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
18 | """Abstract base class for datasets lacking geospatial information.
19 | This base class is designed for datasets with pre-defined image chips.
20 | """
21 |
22 | @abc.abstractmethod
23 | def __getitem__(self, index: int) -> Dict[str, Any]:
24 | """Return an index within the dataset.
25 | Args:
26 | index: index to return
27 | Returns:
28 | data and labels at that index
29 | Raises:
30 | IndexError: if index is out of range of the dataset
31 | """
32 |
33 | @abc.abstractmethod
34 | def __len__(self) -> int:
35 | """Return the length of the dataset.
36 | Returns:
37 | length of the dataset
38 | """
39 |
40 | def __str__(self) -> str:
41 | """Return the informal string representation of the object.
42 | Returns:
43 | informal string representation
44 | """
45 | return f"""\
46 | {self.__class__.__name__} Dataset
47 | type: VisionDataset
48 | size: {len(self)}"""
49 |
50 |
51 | class VisionClassificationDataset(VisionDataset, ImageFolder):
52 | """Abstract base class for classification datasets lacking geospatial information.
53 | This base class is designed for datasets with pre-defined image chips which
54 | are separated into separate folders per class.
55 | """
56 |
57 | def __init__(
58 | self,
59 | root: str,
60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
61 | loader: Optional[Callable[[str], Any]] = pil_loader,
62 | is_valid_file: Optional[Callable[[str], bool]] = None,
63 | ) -> None:
64 | """Initialize a new VisionClassificationDataset instance.
65 | Args:
66 | root: root directory where dataset can be found
67 | transforms: a function/transform that takes input sample and its target as
68 | entry and returns a transformed version
69 | loader: a callable function which takes as input a path to an image and
70 | returns a PIL Image or numpy array
71 | is_valid_file: A function that takes the path of an Image file and checks if
72 | the file is a valid file
73 | """
74 | # When transform & target_transform are None, ImageFolder.__getitem__(index)
75 | # returns a PIL.Image and int for image and label, respectively
76 | super().__init__(
77 | root=root,
78 | transform=None,
79 | target_transform=None,
80 | loader=loader,
81 | is_valid_file=is_valid_file,
82 | )
83 |
84 | # Must be set after calling super().__init__()
85 | self.transforms = transforms
86 |
87 | def __getitem__(self, index: int) -> Dict[str, Tensor]:
88 | """Return an index within the dataset.
89 | Args:
90 | index: index to return
91 | Returns:
92 | data and label at that index
93 | """
94 | image, label = self._load_image(index)
95 |
96 | if self.transforms is not None:
97 | return self.transforms(image), label, index
98 |
99 | return image, label, index
100 |
101 | def __len__(self) -> int:
102 | """Return the number of data points in the dataset.
103 | Returns:
104 | length of the dataset
105 | """
106 | return len(self.imgs)
107 |
108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
109 | """Load a single image and it's class label.
110 | Args:
111 | index: index to return
112 | Returns:
113 | the image
114 | the image class label
115 | """
116 | img, label = ImageFolder.__getitem__(self, index)
117 | label = torch.tensor(label)
118 | return img, label
119 |
120 |
121 | class RESISC45Dataset(VisionClassificationDataset):
122 | """RESISC45 dataset.
123 |     The RESISC45 dataset is a benchmark for remote sensing image
124 |     scene classification.
125 | Dataset features:
126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
127 | * three spectral bands - RGB
128 | * 45 scene classes, 700 images per class
129 | * images extracted from Google Earth from over 100 countries
130 | * images conditions with high variability (resolution, weather, illumination)
131 | Dataset format:
132 | * images are three-channel jpgs
133 | Dataset classes:
134 | 0. airplane
135 | 1. airport
136 | 2. baseball_diamond
137 | 3. basketball_court
138 | 4. beach
139 | 5. bridge
140 | 6. chaparral
141 | 7. church
142 | 8. circular_farmland
143 | 9. cloud
144 | 10. commercial_area
145 | 11. dense_residential
146 | 12. desert
147 | 13. forest
148 | 14. freeway
149 | 15. golf_course
150 | 16. ground_track_field
151 | 17. harbor
152 | 18. industrial_area
153 | 19. intersection
154 | 20. island
155 | 21. lake
156 | 22. meadow
157 | 23. medium_residential
158 | 24. mobile_home_park
159 | 25. mountain
160 | 26. overpass
161 | 27. palace
162 | 28. parking_lot
163 | 29. railway
164 | 30. railway_station
165 | 31. rectangular_farmland
166 | 32. river
167 | 33. roundabout
168 | 34. runway
169 | 35. sea_ice
170 | 36. ship
171 | 37. snowberg
172 | 38. sparse_residential
173 | 39. stadium
174 | 40. storage_tank
175 | 41. tennis_court
176 | 42. terrace
177 | 43. thermal_power_station
178 | 44. wetland
179 | This dataset uses the train/val/test splits defined in the "In-domain representation
180 | learning for remote sensing" paper:
181 | * https://arxiv.org/abs/1911.06721
182 | If you use this dataset in your research, please cite the following paper:
183 | * https://doi.org/10.1109/jproc.2017.2675998
184 | """
185 |
186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
187 | # md5 = "d824acb73957502b00efd559fc6cfbbb"
188 | # filename = "NWPU-RESISC45.rar"
189 | directory = "resisc45/NWPU-RESISC45"
190 |
191 | splits = ["train", "val", "test"]
192 | split_urls = {
193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501
194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501
195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501
196 | }
197 | split_md5s = {
198 | "train": "b5a4c05a37de15e4ca886696a85c403e",
199 | "val": "a0770cee4c5ca20b8c32bbd61e114805",
200 | "test": "3dda9e4988b47eb1de9f07993653eb08",
201 | }
202 | classes = [
203 | "airplane",
204 | "airport",
205 | "baseball_diamond",
206 | "basketball_court",
207 | "beach",
208 | "bridge",
209 | "chaparral",
210 | "church",
211 | "circular_farmland",
212 | "cloud",
213 | "commercial_area",
214 | "dense_residential",
215 | "desert",
216 | "forest",
217 | "freeway",
218 | "golf_course",
219 | "ground_track_field",
220 | "harbor",
221 | "industrial_area",
222 | "intersection",
223 | "island",
224 | "lake",
225 | "meadow",
226 | "medium_residential",
227 | "mobile_home_park",
228 | "mountain",
229 | "overpass",
230 | "palace",
231 | "parking_lot",
232 | "railway",
233 | "railway_station",
234 | "rectangular_farmland",
235 | "river",
236 | "roundabout",
237 | "runway",
238 | "sea_ice",
239 | "ship",
240 | "snowberg",
241 | "sparse_residential",
242 | "stadium",
243 | "storage_tank",
244 | "tennis_court",
245 | "terrace",
246 | "thermal_power_station",
247 | "wetland",
248 | ]
249 |
250 | def __init__(
251 | self,
252 | root: str = "data",
253 | split: str = "train",
254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
255 | ) -> None:
256 | """Initialize a new RESISC45 dataset instance.
257 | Args:
258 | root: root directory where dataset can be found
259 | split: one of "train", "val", or "test"
260 | transforms: a function/transform that takes input sample and its target as
261 | entry and returns a transformed version
262 | """
263 | assert split in self.splits
264 |         self.root = root
265 |
266 | valid_fns = set()
267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f:
268 | for fn in f:
269 | valid_fns.add(fn.strip())
270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename(
271 | x) in valid_fns
272 |
273 | super().__init__(
274 | root=os.path.join(root, self.directory),
275 | transforms=transforms,
276 | is_valid_file=is_in_split,
277 | )
278 |
279 |
280 |
281 | class RESISC45:
282 | def __init__(self,
283 | preprocess,
284 | location=os.path.expanduser('./data'),
285 | batch_size=32,
286 | num_workers=16):
287 |
288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess)
289 | self.train_loader = torch.utils.data.DataLoader(
290 | self.train_dataset,
291 | shuffle=True,
292 | batch_size=batch_size,
293 | num_workers=num_workers,
294 | )
295 |
296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess)
297 | self.test_loader = torch.utils.data.DataLoader(
298 | self.test_dataset,
299 | batch_size=batch_size,
300 | num_workers=num_workers
301 | )
302 | self.test_loader_shuffle = torch.utils.data.DataLoader(
303 | self.test_dataset,
304 | shuffle=True,
305 | batch_size=batch_size,
306 | num_workers=num_workers
307 | )
308 |
309 | # class names have _ so split on this for better zero-shot head
310 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]
311 |
--------------------------------------------------------------------------------
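For orientation, the loader expects the split lists and the image folder side by side under the data root. A minimal sketch of the layout and usage (inferred from the constructor paths above):

```python
# Expected on-disk layout, inferred from the paths used above:
#   ./data/resisc45/resisc45-train.txt   # split lists (see split_urls)
#   ./data/resisc45/resisc45-val.txt
#   ./data/resisc45/resisc45-test.txt
#   ./data/resisc45/NWPU-RESISC45/<class_name>/<image>.jpg
from src.datasets.resisc45 import RESISC45Dataset

ds = RESISC45Dataset(root='./data', split='val')  # no transform applied
image, label, index = ds[0]   # PIL image, scalar label tensor, int index
```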
/src/datasets/stl10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | from PIL import Image
5 | import numpy as np
6 | from typing import Optional
7 | class MySTL10(datasets.STL10):
8 | def __init__(self, root, download, split, transform):
9 | super().__init__(root=root, download=download, split=split, transform=transform)
10 |
11 | def __getitem__(self, index: int):
12 | """
13 | Args:
14 | index (int): Index
15 |
16 | Returns:
17 |             tuple: (image, target, index) where target is the index of the target class.
18 | """
19 | target: Optional[int]
20 | if self.labels is not None:
21 | img, target = self.data[index], int(self.labels[index])
22 | else:
23 | img, target = self.data[index], None
24 |
25 | # doing this so that it is consistent with all other datasets
26 | # to return a PIL Image
27 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
28 |
29 | if self.transform is not None:
30 | img = self.transform(img)
31 |
32 | if self.target_transform is not None:
33 | target = self.target_transform(target)
34 |
35 | return img, target, index
36 |
37 | class STL10:
38 | def __init__(self,
39 | preprocess,
40 | location=os.path.expanduser('./data'),
41 | batch_size=128,
42 | num_workers=16):
43 |
44 | location = os.path.join(location, 'stl10')
45 | self.train_dataset = MySTL10(
46 | root=location,
47 | download=True,
48 | split='train',
49 | transform=preprocess
50 | )
51 |
52 | self.train_loader = torch.utils.data.DataLoader(
53 | self.train_dataset,
54 | batch_size=batch_size,
55 | shuffle=True,
56 | num_workers=num_workers
57 | )
58 |
59 | self.test_dataset = MySTL10(
60 | root=location,
61 | download=True,
62 | split='test',
63 | transform=preprocess
64 | )
65 |
66 | self.test_loader = torch.utils.data.DataLoader(
67 | self.test_dataset,
68 | batch_size=batch_size,
69 | shuffle=False,
70 | num_workers=num_workers
71 | )
72 |
73 | self.test_loader_shuffle = torch.utils.data.DataLoader(
74 | self.test_dataset,
75 | shuffle=True,
76 | batch_size=batch_size,
77 | num_workers=num_workers
78 | )
79 |
80 | self.classnames = self.train_dataset.classes
--------------------------------------------------------------------------------
/src/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class ImageFolderDataset(datasets.ImageFolder):
6 | def __init__(self, root, transform):
7 | super().__init__(root, transform)
8 |
9 | def __getitem__(self, index: int):
10 | path, target = self.samples[index]
11 | sample = self.loader(path)
12 | if self.transform is not None:
13 | sample = self.transform(sample)
14 | if self.target_transform is not None:
15 | target = self.target_transform(target)
16 | return sample, target, index
17 |
18 | class SUN397:
19 | def __init__(self,
20 | preprocess,
21 | location=os.path.expanduser('./data'),
22 | batch_size=32,
23 | num_workers=16):
24 | # Data loading code
25 | traindir = os.path.join(location, 'sun397', 'train')
26 | valdir = os.path.join(location, 'sun397', 'test')
27 |
28 |
29 | self.train_dataset = ImageFolderDataset(traindir, transform=preprocess)
30 | self.train_loader = torch.utils.data.DataLoader(
31 | self.train_dataset,
32 | shuffle=True,
33 | batch_size=batch_size,
34 | num_workers=num_workers,
35 | )
36 |
37 | self.test_dataset = ImageFolderDataset(valdir, transform=preprocess)
38 | self.test_loader = torch.utils.data.DataLoader(
39 | self.test_dataset,
40 | batch_size=batch_size,
41 | num_workers=num_workers
42 | )
43 | self.test_loader_shuffle = torch.utils.data.DataLoader(
44 | self.test_dataset,
45 | shuffle=True,
46 | batch_size=batch_size,
47 | num_workers=num_workers
48 | )
49 | idx_to_class = dict((v, k)
50 | for k, v in self.train_dataset.class_to_idx.items())
51 | self.classnames = [idx_to_class[i][3:].replace('_', ' ') for i in range(len(idx_to_class))]
52 | # print(self.classnames)
53 | print(len(self.classnames))
--------------------------------------------------------------------------------
/src/datasets/svhn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.datasets import SVHN as PyTorchSVHN
4 | import numpy as np
5 | from PIL import Image
6 |
7 | class MyPyTorchSVHN(PyTorchSVHN):
8 | def __init__(self, root, download, split, transform):
9 | super().__init__(root=root, download=download, split=split, transform=transform)
10 |
11 | def __getitem__(self, index: int):
12 | """
13 | Args:
14 | index (int): Index
15 |
16 | Returns:
17 |             tuple: (image, target, index) where target is the index of the target class.
18 | """
19 | img, target = self.data[index], int(self.labels[index])
20 |
21 | # doing this so that it is consistent with all other datasets
22 | # to return a PIL Image
23 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
24 |
25 | if self.transform is not None:
26 | img = self.transform(img)
27 |
28 | if self.target_transform is not None:
29 | target = self.target_transform(target)
30 |
31 | return img, target, index
32 |
33 | class SVHN:
34 | def __init__(self,
35 | preprocess,
36 | location=os.path.expanduser('./data'),
37 | batch_size=128,
38 | num_workers=16):
39 |
40 | # to fit with repo conventions for location
41 | modified_location = os.path.join(location, 'svhn')
42 |
43 | self.train_dataset = MyPyTorchSVHN(
44 | root=modified_location,
45 | download=True,
46 | split='train',
47 | transform=preprocess
48 | )
49 |
50 | self.train_loader = torch.utils.data.DataLoader(
51 | self.train_dataset,
52 | batch_size=batch_size,
53 | shuffle=True,
54 | num_workers=num_workers
55 | )
56 |
57 | self.test_dataset = MyPyTorchSVHN(
58 | root=modified_location,
59 | download=True,
60 | split='test',
61 | transform=preprocess
62 | )
63 |
64 | self.test_loader = torch.utils.data.DataLoader(
65 | self.test_dataset,
66 | batch_size=batch_size,
67 | shuffle=False,
68 | num_workers=num_workers
69 | )
70 |
71 | self.test_loader_shuffle = torch.utils.data.DataLoader(
72 | self.test_dataset,
73 | shuffle=True,
74 | batch_size=batch_size,
75 | num_workers=num_workers
76 | )
77 |
78 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
79 |
--------------------------------------------------------------------------------
/src/datasets/templates.py:
--------------------------------------------------------------------------------
1 | cars_template = [
2 | lambda c: f'a photo of a {c}.',
3 | lambda c: f'a photo of the {c}.',
4 | lambda c: f'a photo of my {c}.',
5 | lambda c: f'i love my {c}!',
6 | lambda c: f'a photo of my dirty {c}.',
7 | lambda c: f'a photo of my clean {c}.',
8 | lambda c: f'a photo of my new {c}.',
9 | lambda c: f'a photo of my old {c}.',
10 | ]
11 |
12 | cifar10_template = [
13 | lambda c: f'a photo of a {c}.',
14 | lambda c: f'a blurry photo of a {c}.',
15 | lambda c: f'a black and white photo of a {c}.',
16 | lambda c: f'a low contrast photo of a {c}.',
17 | lambda c: f'a high contrast photo of a {c}.',
18 | lambda c: f'a bad photo of a {c}.',
19 | lambda c: f'a good photo of a {c}.',
20 | lambda c: f'a photo of a small {c}.',
21 | lambda c: f'a photo of a big {c}.',
22 | lambda c: f'a photo of the {c}.',
23 | lambda c: f'a blurry photo of the {c}.',
24 | lambda c: f'a black and white photo of the {c}.',
25 | lambda c: f'a low contrast photo of the {c}.',
26 | lambda c: f'a high contrast photo of the {c}.',
27 | lambda c: f'a bad photo of the {c}.',
28 | lambda c: f'a good photo of the {c}.',
29 | lambda c: f'a photo of the small {c}.',
30 | lambda c: f'a photo of the big {c}.',
31 | ]
32 |
33 | cifar100_template = [
34 | lambda c: f'a photo of a {c}.',
35 | lambda c: f'a blurry photo of a {c}.',
36 | lambda c: f'a black and white photo of a {c}.',
37 | lambda c: f'a low contrast photo of a {c}.',
38 | lambda c: f'a high contrast photo of a {c}.',
39 | lambda c: f'a bad photo of a {c}.',
40 | lambda c: f'a good photo of a {c}.',
41 | lambda c: f'a photo of a small {c}.',
42 | lambda c: f'a photo of a big {c}.',
43 | lambda c: f'a photo of the {c}.',
44 | lambda c: f'a blurry photo of the {c}.',
45 | lambda c: f'a black and white photo of the {c}.',
46 | lambda c: f'a low contrast photo of the {c}.',
47 | lambda c: f'a high contrast photo of the {c}.',
48 | lambda c: f'a bad photo of the {c}.',
49 | lambda c: f'a good photo of the {c}.',
50 | lambda c: f'a photo of the small {c}.',
51 | lambda c: f'a photo of the big {c}.',
52 | ]
53 |
54 | dtd_template = [
55 | lambda c: f'a photo of a {c} texture.',
56 | lambda c: f'a photo of a {c} pattern.',
57 | lambda c: f'a photo of a {c} thing.',
58 | lambda c: f'a photo of a {c} object.',
59 | lambda c: f'a photo of the {c} texture.',
60 | lambda c: f'a photo of the {c} pattern.',
61 | lambda c: f'a photo of the {c} thing.',
62 | lambda c: f'a photo of the {c} object.',
63 | ]
64 |
65 | eurosat_template = [
66 | lambda c: f'a centered satellite photo of {c}.',
67 | lambda c: f'a centered satellite photo of a {c}.',
68 | lambda c: f'a centered satellite photo of the {c}.',
69 | ]
70 |
71 | food101_template = [
72 | lambda c: f'a photo of {c}, a type of food.',
73 | ]
74 |
75 | gtsrb_template = [
76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.',
77 | lambda c: f'a centered photo of a "{c}" traffic sign.',
78 | lambda c: f'a close up photo of a "{c}" traffic sign.',
79 | ]
80 |
81 | mnist_template = [
82 | lambda c: f'a photo of the number: "{c}".',
83 | ]
84 |
85 | imagenet_template = [
86 | lambda c: f'a bad photo of a {c}.',
87 | lambda c: f'a photo of many {c}.',
88 | lambda c: f'a sculpture of a {c}.',
89 | lambda c: f'a photo of the hard to see {c}.',
90 | lambda c: f'a low resolution photo of the {c}.',
91 | lambda c: f'a rendering of a {c}.',
92 | lambda c: f'graffiti of a {c}.',
93 | lambda c: f'a bad photo of the {c}.',
94 | lambda c: f'a cropped photo of the {c}.',
95 | lambda c: f'a tattoo of a {c}.',
96 | lambda c: f'the embroidered {c}.',
97 | lambda c: f'a photo of a hard to see {c}.',
98 | lambda c: f'a bright photo of a {c}.',
99 | lambda c: f'a photo of a clean {c}.',
100 | lambda c: f'a photo of a dirty {c}.',
101 | lambda c: f'a dark photo of the {c}.',
102 | lambda c: f'a drawing of a {c}.',
103 | lambda c: f'a photo of my {c}.',
104 | lambda c: f'the plastic {c}.',
105 | lambda c: f'a photo of the cool {c}.',
106 | lambda c: f'a close-up photo of a {c}.',
107 | lambda c: f'a black and white photo of the {c}.',
108 | lambda c: f'a painting of the {c}.',
109 | lambda c: f'a painting of a {c}.',
110 | lambda c: f'a pixelated photo of the {c}.',
111 | lambda c: f'a sculpture of the {c}.',
112 | lambda c: f'a bright photo of the {c}.',
113 | lambda c: f'a cropped photo of a {c}.',
114 | lambda c: f'a plastic {c}.',
115 | lambda c: f'a photo of the dirty {c}.',
116 | lambda c: f'a jpeg corrupted photo of a {c}.',
117 | lambda c: f'a blurry photo of the {c}.',
118 | lambda c: f'a photo of the {c}.',
119 | lambda c: f'a good photo of the {c}.',
120 | lambda c: f'a rendering of the {c}.',
121 | lambda c: f'a {c} in a video game.',
122 | lambda c: f'a photo of one {c}.',
123 | lambda c: f'a doodle of a {c}.',
124 | lambda c: f'a close-up photo of the {c}.',
125 | lambda c: f'a photo of a {c}.',
126 | lambda c: f'the origami {c}.',
127 | lambda c: f'the {c} in a video game.',
128 | lambda c: f'a sketch of a {c}.',
129 | lambda c: f'a doodle of the {c}.',
130 | lambda c: f'a origami {c}.',
131 | lambda c: f'a low resolution photo of a {c}.',
132 | lambda c: f'the toy {c}.',
133 | lambda c: f'a rendition of the {c}.',
134 | lambda c: f'a photo of the clean {c}.',
135 | lambda c: f'a photo of a large {c}.',
136 | lambda c: f'a rendition of a {c}.',
137 | lambda c: f'a photo of a nice {c}.',
138 | lambda c: f'a photo of a weird {c}.',
139 | lambda c: f'a blurry photo of a {c}.',
140 | lambda c: f'a cartoon {c}.',
141 | lambda c: f'art of a {c}.',
142 | lambda c: f'a sketch of the {c}.',
143 | lambda c: f'a embroidered {c}.',
144 | lambda c: f'a pixelated photo of a {c}.',
145 | lambda c: f'itap of the {c}.',
146 | lambda c: f'a jpeg corrupted photo of the {c}.',
147 | lambda c: f'a good photo of a {c}.',
148 | lambda c: f'a plushie {c}.',
149 | lambda c: f'a photo of the nice {c}.',
150 | lambda c: f'a photo of the small {c}.',
151 | lambda c: f'a photo of the weird {c}.',
152 | lambda c: f'the cartoon {c}.',
153 | lambda c: f'art of the {c}.',
154 | lambda c: f'a drawing of the {c}.',
155 | lambda c: f'a photo of the large {c}.',
156 | lambda c: f'a black and white photo of a {c}.',
157 | lambda c: f'the plushie {c}.',
158 | lambda c: f'a dark photo of a {c}.',
159 | lambda c: f'itap of a {c}.',
160 | lambda c: f'graffiti of the {c}.',
161 | lambda c: f'a toy {c}.',
162 | lambda c: f'itap of my {c}.',
163 | lambda c: f'a photo of a cool {c}.',
164 | lambda c: f'a photo of a small {c}.',
165 | lambda c: f'a tattoo of the {c}.',
166 | ]
167 |
168 | resisc45_template = [
169 | lambda c: f'satellite imagery of {c}.',
170 | lambda c: f'aerial imagery of {c}.',
171 | lambda c: f'satellite photo of {c}.',
172 | lambda c: f'aerial photo of {c}.',
173 | lambda c: f'satellite view of {c}.',
174 | lambda c: f'aerial view of {c}.',
175 | lambda c: f'satellite imagery of a {c}.',
176 | lambda c: f'aerial imagery of a {c}.',
177 | lambda c: f'satellite photo of a {c}.',
178 | lambda c: f'aerial photo of a {c}.',
179 | lambda c: f'satellite view of a {c}.',
180 | lambda c: f'aerial view of a {c}.',
181 | lambda c: f'satellite imagery of the {c}.',
182 | lambda c: f'aerial imagery of the {c}.',
183 | lambda c: f'satellite photo of the {c}.',
184 | lambda c: f'aerial photo of the {c}.',
185 | lambda c: f'satellite view of the {c}.',
186 | lambda c: f'aerial view of the {c}.',
187 | ]
188 |
189 | stl10_template = [
190 | lambda c: f'a photo of a {c}.',
191 | lambda c: f'a photo of the {c}.',
192 | ]
193 |
194 | sun397_template = [
195 | lambda c: f'a photo of a {c}.',
196 | lambda c: f'a photo of the {c}.',
197 | ]
198 |
199 | svhn_template = [
200 | lambda c: f'a photo of the number: "{c}".',
201 | ]
202 |
203 | pets_template = [
204 | lambda c: f'a photo of a {c}, a type of pet.'
205 | ]
206 |
207 | caltech101_template = [
208 | lambda c: f'a photo of a {c}.',
209 | lambda c: f'a painting of a {c}.',
210 | lambda c: f'a plastic {c}.',
211 | lambda c: f'a sculpture of a {c}.',
212 | lambda c: f'a sketch of a {c}.',
213 | lambda c: f'a tattoo of a {c}.',
214 | lambda c: f'a toy {c}.',
215 | lambda c: f'a rendition of a {c}.',
216 | lambda c: f'a embroidered {c}.',
217 | lambda c: f'a cartoon {c}.',
218 | lambda c: f'a {c} in a video game.',
219 | lambda c: f'a plushie {c}.',
220 | lambda c: f'a origami {c}.',
221 | lambda c: f'art of a {c}.',
222 | lambda c: f'graffiti of a {c}.',
223 | lambda c: f'a drawing of a {c}.',
224 | lambda c: f'a doodle of a {c}.',
225 | lambda c: f'a photo of the {c}.',
226 | lambda c: f'a painting of the {c}.',
227 | lambda c: f'the plastic {c}.',
228 | lambda c: f'a sculpture of the {c}.',
229 | lambda c: f'a sketch of the {c}.',
230 | lambda c: f'a tattoo of the {c}.',
231 | lambda c: f'the toy {c}.',
232 | lambda c: f'a rendition of the {c}.',
233 | lambda c: f'the embroidered {c}.',
234 | lambda c: f'the cartoon {c}.',
235 | lambda c: f'the {c} in a video game.',
236 | lambda c: f'the plushie {c}.',
237 | lambda c: f'the origami {c}.',
238 | lambda c: f'art of the {c}.',
239 | lambda c: f'graffiti of the {c}.',
240 | lambda c: f'a drawing of the {c}.',
241 | lambda c: f'a doodle of the {c}.',
242 | ]
243 |
244 | flower_templates = [
245 | lambda c: f'a photo of a {c}, a type of flower.',
246 | ]
247 |
248 | cub_templates = [
249 | lambda c: f'a photo of a {c}, a type of bird.',
250 | ]
251 |
252 | fashion_mnist_template = [
253 | lambda c: f'a photo of a {c}.',
254 | lambda c: f'a blurry photo of a {c}.',
255 | lambda c: f'a black and white photo of a {c}.',
256 | lambda c: f'a thumbnail of a {c}.',
257 | lambda c: f'a photo of the {c}.',
258 | lambda c: f'a blurry photo of the {c}.',
259 | lambda c: f'a black and white photo of the {c}.',
260 | lambda c: f'a thumbnail of the {c}.',
261 | ]
262 |
263 | dataset_to_template = {
264 | 'Cars': cars_template,
265 | 'CIFAR10': cifar10_template,
266 | 'CIFAR100': cifar100_template,
267 | 'DTD': dtd_template,
268 | 'EuroSAT': eurosat_template,
269 | 'Food101': food101_template,
270 | 'GTSRB': gtsrb_template,
271 | 'MNIST': mnist_template,
272 | 'ImageNet': imagenet_template,
273 | 'RESISC45': resisc45_template,
274 | 'STL10': stl10_template,
275 | 'SUN397': sun397_template,
276 | 'SVHN': svhn_template,
277 | 'PETS': pets_template,
278 | 'Caltech101': caltech101_template,
279 | 'Flowers': flower_templates,
280 | 'TIN': imagenet_template,
281 |     'ImageNet100': imagenet_template,
282 |     'CUB200': cub_templates,
283 |     'FashionMNIST': fashion_mnist_template
284 | }
285 | 
286 | 
287 | def get_templates(dataset_name):
288 |     if dataset_name.endswith('Val'):
289 |         return get_templates(dataset_name.replace('Val', ''))
290 |     assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}'
291 |     return dataset_to_template[dataset_name]
--------------------------------------------------------------------------------
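A quick sketch of how these templates are consumed: `get_templates` returns a list of prompt-building callables that are applied to each classname when constructing the zero-shot classification head.

```python
from src.datasets.templates import get_templates

templates = get_templates('EuroSAT')
prompts = [t('forest') for t in templates]
# ['a centered satellite photo of forest.',
#  'a centered satellite photo of a forest.',
#  'a centered satellite photo of the forest.']
```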
/src/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 | import torch
5 | import numpy as np
6 | from src import utils
7 | from src.datasets.common import get_dataloader, maybe_dictionarize
8 | from src.datasets.templates import get_templates
9 | from heads import get_classification_head, build_classification_head
10 | from modeling import ImageClassifier, ImageEncoder, ClassificationHead
11 | from src.datasets.registry import get_dataset_cifar_mnist
12 | import torchvision.utils as vutils
13 | from src.utils import *
14 |
15 | def eval_single_dataset(image_encoder, dataset_name, args, backdoor_info=None):
16 | print("")
17 |     # assemble the classifier: image encoder + dataset-specific head
18 | classification_head = get_classification_head(args, dataset_name)
19 | model = ImageClassifier(image_encoder, classification_head)
20 | model.eval()
21 |
22 |     # evaluate on the shadow-test split
23 | test_dataset, test_loader = get_dataset_cifar_mnist(
24 | dataset_name,
25 | 'shadowtest',
26 | model.val_preprocess,
27 | location=args.data_location,
28 | batch_size=args.batch_size
29 | )
30 | normalizer = model.val_preprocess.transforms[-1]
31 | inv_normalizer = NormalizeInverse(normalizer.mean, normalizer.std)
32 | print("Evaluation Size:", len(test_dataset))
33 |
34 | device = args.device
35 |
36 | with torch.no_grad():
37 | top1, correct, n = 0., 0., 0.
38 | for i, data in enumerate(tqdm.tqdm(test_loader)):
39 | data = maybe_dictionarize(data)
40 | x = data['images']
41 | y = data['labels']
42 |
43 |             x = x.to(device)
44 |             y = y.to(device)
45 | logits = utils.get_logits(x, model)
46 | pred = logits.argmax(dim=1, keepdim=True).to(device)
47 | correct += pred.eq(y.view_as(pred)).sum().item()
48 | n += y.size(0)
49 |
50 | top1 = correct / n
51 |
52 | metrics = {'top1': top1}
53 | print(f'Accuracy: {100*top1:.2f}%')
54 |
55 | return metrics
56 |
57 | def eval_single_dataset_head(image_encoder, head, dataset_name, args):
58 | model = ImageClassifier(image_encoder, head)
59 | model.eval()
60 | test_dataset, test_loader = get_dataset_cifar_mnist(dataset_name, 'test', model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
61 | device = args.device
62 |
63 | with torch.no_grad():
64 | top1, correct, n = 0., 0., 0.
65 | for i, data in enumerate(tqdm.tqdm(test_loader)):
66 | data = maybe_dictionarize(data)
67 | x = data['images'].to(device)
68 | y = data['labels'].to(device)
69 | logits = utils.get_logits(x, model)
70 | pred = logits.argmax(dim=1, keepdim=True).to(device)
71 | correct += pred.eq(y.view_as(pred)).sum().item()
72 | n += y.size(0)
73 | top1 = correct / n
74 |
75 | metrics = {'top1': top1}
76 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%')
77 | return metrics
78 |
79 | def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args):
80 | model = ImageClassifier(image_encoder, head)
81 | model.eval()
82 |     test_dataset, test_loader = get_dataset_cifar_mnist(dataset_name, 'test', model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
83 | device = args.device
84 |
85 | with torch.no_grad():
86 | top1, correct, n = 0., 0., 0.
87 | for i, data in enumerate(tqdm.tqdm(test_loader)):
88 | data = maybe_dictionarize(data)
89 | x = data['images'].to(device)
90 | y = data['labels'].to(device)
91 | logits = utils.get_logits(x, model)
92 | pred = logits.argmax(dim=1, keepdim=True).to(device)
93 | correct += pred.eq(y.view_as(pred)).sum().item()
94 | n += y.size(0)
95 | top1 = correct / n
96 | metrics = {'top1': top1}
97 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%')
98 | return metrics
99 |
100 | def evaluate(image_encoder, args, backdoor_info=None):
101 | if args.eval_datasets is None:
102 | return
103 | info = vars(args)
104 | for i, dataset_name in enumerate(args.eval_datasets):
105 | print('Evaluating on', dataset_name)
106 |
107 | results = eval_single_dataset(image_encoder, dataset_name, args, backdoor_info)
108 |
109 | for key, val in results.items():
110 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
111 | print(f"{dataset_name} {key}: {val:.4f}")
112 | if backdoor_info is not None:
113 | info[dataset_name + '-B:' + key] = val # trigger
114 | else:
115 | info[dataset_name + ':' + key] = val # clean
116 | return info
--------------------------------------------------------------------------------
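A minimal sketch of driving the evaluation loop above (the `args` fields follow `src/args.py` conventions; the dataset names are illustrative):

```python
from src.args import parse_arguments
from src.modeling import ImageEncoder
from src.eval import evaluate

args = parse_arguments()
args.data_location = './data'
args.eval_datasets = ['CIFAR10', 'MNIST']

image_encoder = ImageEncoder(args, keep_lang=False)  # or a finetuned encoder
info = evaluate(image_encoder, args)
print(info['CIFAR10:top1'], info['MNIST:top1'])
```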
/src/figures/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/comparison.png
--------------------------------------------------------------------------------
/src/figures/exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/exp.png
--------------------------------------------------------------------------------
/src/figures/figures.txt:
--------------------------------------------------------------------------------
1 | Some figures in the paper are saved here.
2 |
--------------------------------------------------------------------------------
/src/figures/main_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/main_table.png
--------------------------------------------------------------------------------
/src/figures/neulig_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/neulig_overview.png
--------------------------------------------------------------------------------
/src/figures/neulig_train_pip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiQiiiii/Neural-Ligand/a51a30e9600e366511ceca6d650b09b08267033a/src/figures/neulig_train_pip.png
--------------------------------------------------------------------------------
/src/finetune_clean.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import sys
4 | sys.path.append(os.path.abspath('.'))
5 | import torch
6 | from src.args import parse_arguments
7 | from src.datasets.common import get_dataloader, maybe_dictionarize
8 | from src.datasets.registry import get_dataset_cifar_mnist
9 | from src.eval import evaluate
10 | from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier
11 | from src.utils import cosine_lr, LabelSmoothing
12 | from src.heads import get_classification_head
13 | import src.datasets as datasets
14 | from torch.utils.data import Subset
15 | import pickle
16 | import numpy as np
17 | import copy
18 | import tqdm
19 |
20 | def save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, save_dir):
21 | os.makedirs(save_dir, exist_ok=True)
22 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'wb') as f:
23 | pickle.dump(train_dataset, f)
24 |
25 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'wb') as f:
26 | pickle.dump(test_dataset, f)
27 |
28 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'wb') as f:
29 | pickle.dump(shadowtrain_dataset, f)
30 |
31 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'wb') as f:
32 | pickle.dump(shadowtest_dataset, f)
33 |
34 | def load_datasets(save_dir):
35 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'rb') as f:
36 | train_dataset = pickle.load(f)
37 |
38 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'rb') as f:
39 | test_dataset = pickle.load(f)
40 |
41 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'rb') as f:
42 | shadowtrain_dataset = pickle.load(f)
43 |
44 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'rb') as f:
45 | shadowtest_dataset = pickle.load(f)
46 |     return train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset
47 | 
48 | 
49 | def check_datasets_exist(save_dir):
50 | return (os.path.exists(os.path.join(save_dir, 'train_dataset.pkl')) and
51 | os.path.exists(os.path.join(save_dir, 'test_dataset.pkl')) and
52 | os.path.exists(os.path.join(save_dir, 'shadowtrain_dataset.pkl')) and
53 | os.path.exists(os.path.join(save_dir, 'shadowtest_dataset.pkl')))
54 |
55 | def load_dataset_splits(save_dir):
56 | with open(os.path.join(save_dir, 'train_indices.pkl'), 'rb') as f:
57 | train_indices = pickle.load(f)
58 | with open(os.path.join(save_dir, 'test_indices.pkl'), 'rb') as f:
59 | test_indices = pickle.load(f)
60 | with open(os.path.join(save_dir, 'shadowtrain_indices.pkl'), 'rb') as f:
61 | shadowtrain_indices = pickle.load(f)
62 | with open(os.path.join(save_dir, 'shadowtest_indices.pkl'), 'rb') as f:
63 | shadowtest_indices = pickle.load(f)
64 |
65 | return train_indices, test_indices, shadowtrain_indices, shadowtest_indices
66 |
67 | def finetune(model, args):
68 | dataset = args.dataset
69 | preprocess_fn = model.train_preprocess
70 |
71 | print_every = 100
72 | dataset_save_dir = os.path.join("{}/{}/dataset_splits".format(args.save, args.dataset))
73 |
74 | if check_datasets_exist(dataset_save_dir):
75 |         print("Subsets already exist...")
76 | from torch.utils.data import DataLoader
77 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir)
78 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
79 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True)
80 | else:
81 | print("Subsets do not exist...")
82 | train_dataset, train_loader = get_dataset_cifar_mnist(
83 | dataset,
84 | 'train',
85 | preprocess_fn,
86 | location=args.data_location,
87 | batch_size=args.batch_size
88 | )
89 | test_dataset, test_loader = get_dataset_cifar_mnist(
90 | dataset,
91 | 'test',
92 | preprocess_fn,
93 | location=args.data_location,
94 | batch_size=args.batch_size
95 | )
96 | shadowtrain_dataset, shadowtrain_loader = get_dataset_cifar_mnist(
97 | dataset,
98 | 'shadowtrain',
99 | preprocess_fn,
100 | location=args.data_location,
101 | batch_size=args.batch_size
102 | )
103 | shadowtest_dataset, shadowtest_loader = get_dataset_cifar_mnist(
104 | dataset,
105 | 'shadowtest',
106 | preprocess_fn,
107 | location=args.data_location,
108 | batch_size=args.batch_size
109 | )
110 | save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, dataset_save_dir)
111 |
112 | num_batches = len(train_loader)
113 | print("train_length: {}, val_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset)))
114 | # save pre-trained model
115 |
116 | # dataset_dir = dataset + '_1epoch'
117 | # ckpdir = os.path.join(args.save, dataset_dir)
118 |
119 | ckpdir = os.path.join(args.save, dataset)
120 |
121 | if args.save is not None:
122 | os.makedirs(ckpdir, exist_ok=True)
123 |         model_path = os.path.join(args.save, 'zeroshot.pt')
124 | if not os.path.exists(model_path):
125 | model.image_encoder.save(model_path)
126 | # evaluate pre-trained model
127 | print("Initial evaluation:")
128 | image_encoder = model.image_encoder
129 | args.eval_datasets = [dataset]
130 | evaluate(image_encoder, args)
131 |
132 | # test_loaders = [test_loader]
133 | # evaluate_single(model, test_loaders, args.device)
134 |
135 | # train model for target train set
136 | loss_fn = torch.nn.CrossEntropyLoss()
137 | params = [p for p in model.parameters() if p.requires_grad]
138 | optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
139 | scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)
140 | for epoch in range(args.epochs):
141 | model = model.cuda()
142 | model.train()
143 | for i, batch in enumerate(train_loader):
144 | start_time = time.time()
145 | step = i + epoch * num_batches
146 | scheduler(step)
147 | optimizer.zero_grad()
148 |
149 | batch = maybe_dictionarize(batch)
150 | inputs = batch['images'].to('cuda:0')
151 | labels = batch['labels'].to('cuda:0')
152 | data_time = time.time() - start_time
153 |
154 | logits = model(inputs)
155 | loss = loss_fn(logits, labels)
156 | loss.backward()
157 | torch.nn.utils.clip_grad_norm_(params, 1.0)
158 | optimizer.step()
159 | batch_time = time.time() - start_time
160 |
161 | if step % print_every == 0:
162 | percent_complete = 100 * i / len(train_loader)
163 | print(
164 | f"Target Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_loader)}]\t"
165 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
166 | )
167 | # evaluate target model
168 | image_encoder = model.image_encoder
169 | args.eval_datasets = [dataset] # eval dataset
170 | evaluate(image_encoder, args)
171 |
172 | # Save the finetuned model
173 | if args.save is not None:
174 | ft_path = os.path.join(ckpdir, 'finetuned.pt')
175 | image_encoder.save(ft_path)
176 |
177 | def finetune_dev(model_shadow, args):
178 | dataset = args.dataset
179 | preprocess_fn = model_shadow.train_preprocess
180 |
181 | print_every = 100
182 | dataset_save_dir = os.path.join("{}/{}/dataset_splits".format(args.save, args.dataset))
183 |
184 | if check_datasets_exist(dataset_save_dir):
185 |         print("Subsets already exist...")
186 | from torch.utils.data import DataLoader
187 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir)
188 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
189 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True)
190 | else:
191 | print("Subsets do not exist...")
192 | train_dataset, train_loader = get_dataset_cifar_mnist(
193 | dataset,
194 | 'train',
195 | preprocess_fn,
196 | location=args.data_location,
197 | batch_size=args.batch_size
198 | )
199 | test_dataset, test_loader = get_dataset_cifar_mnist(
200 | dataset,
201 | 'test',
202 | preprocess_fn,
203 | location=args.data_location,
204 | batch_size=args.batch_size
205 | )
206 | shadowtrain_dataset, shadowtrain_loader = get_dataset_cifar_mnist(
207 | dataset,
208 | 'shadowtrain',
209 | preprocess_fn,
210 | location=args.data_location,
211 | batch_size=args.batch_size
212 | )
213 | shadowtest_dataset, shadowtest_loader = get_dataset_cifar_mnist(
214 | dataset,
215 | 'shadowtest',
216 | preprocess_fn,
217 | location=args.data_location,
218 | batch_size=args.batch_size
219 | )
220 | save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, dataset_save_dir)
221 |
222 | num_batches = len(train_loader)
223 | print("train_length: {}, val_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset)))
224 |
225 | # train model for shadow train set
226 | model_shadow = model_shadow.to(args.device)
227 | loss_fn_shadow = torch.nn.CrossEntropyLoss()
228 | params_shadow = [p for p in model_shadow.parameters() if p.requires_grad]
229 | optimizer_shadow = torch.optim.AdamW(params_shadow, lr=args.lr, weight_decay=args.wd)
230 | scheduler_shadow = cosine_lr(optimizer_shadow, args.lr, args.warmup_length, args.epochs * num_batches)
231 | for epoch in range(args.epochs):
232 | model_shadow = model_shadow.cuda()
233 | model_shadow.train()
234 | for i, batch in enumerate(shadowtrain_loader):
235 | start_time = time.time()
236 | step = i + epoch * num_batches
237 | scheduler_shadow(step)
238 | optimizer_shadow.zero_grad()
239 |
240 | batch = maybe_dictionarize(batch)
241 | inputs = batch['images'].to('cuda:0')
242 | labels = batch['labels'].to('cuda:0')
243 | data_time = time.time() - start_time
244 |
245 | logits = model_shadow(inputs)
246 | loss = loss_fn_shadow(logits, labels)
247 | loss.backward()
248 | torch.nn.utils.clip_grad_norm_(params_shadow, 1.0)
249 | optimizer_shadow.step()
250 | batch_time = time.time() - start_time
251 |
252 | if step % print_every == 0:
253 | percent_complete = 100 * i / len(shadowtrain_loader)
254 | print(
255 | f"Shadow Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(shadowtrain_loader)}]\t"
256 | f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
257 | )
258 |
259 | # evaluate shadow model
260 | image_encoder_shadow = model_shadow.image_encoder
261 | args.eval_datasets = [dataset] # eval dataset
262 | evaluate(image_encoder_shadow, args)
263 |
264 | ckpdir = os.path.join(args.save, dataset)
265 | if args.save is not None:
266 | dev_ft_path = os.path.join(ckpdir, 'finetuned_dev.pt')
267 | image_encoder_shadow.save(dev_ft_path)
268 |
269 | if __name__ == '__main__':
270 | data_location = "./data"
271 | models = ['RN50', 'ViT-B-32', 'ViT-L-14']
272 | datasets = ['CIFAR10', 'MNIST', 'GTSRB', 'RESISC45', 'CIFAR100', 'SVHN', 'STL10']
273 |
274 | epochs = {
275 | 'GTSRB': 11,
276 | 'MNIST': 5,
277 | 'RESISC45': 15,
278 | 'SVHN': 4,
279 | 'STL10': 50,
280 | 'CIFAR100': 5,
281 | 'CIFAR10': 5,
282 | }
283 |
284 | for model_name in models:
285 | for dataset in datasets:
286 | print('='*100)
287 | print(f'Finetuning {model_name} on {dataset}')
288 | print('='*100)
289 | args = parse_arguments()
290 |
291 | args.lr = 1e-5
292 | args.epochs = epochs[dataset]
293 | args.data_location = data_location
294 | args.dataset = dataset
295 | args.batch_size = 32
296 |
297 | args.model = model_name
298 | args.save = f'./checkpoints/{args.model}'
299 | args.cache_dir = ''
300 | args.openclip_cachedir = './open_clip'
301 | image_encoder = ImageEncoder(args, keep_lang=False)
302 | classification_head = get_classification_head(args, dataset)
303 | model = ImageClassifier(image_encoder, classification_head)
304 | model.freeze_head()
305 | model_shadow = copy.deepcopy(model)
306 |
307 | finetune(model, args)
308 | finetune_dev(model_shadow, args)
309 |
--------------------------------------------------------------------------------
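
Note: `finetune` and `finetune_dev` above save full modules (`finetuned.pt`, `finetuned_dev.pt`) under `./checkpoints/<model>/<dataset>/`, so a saved encoder can be re-evaluated without re-training. A minimal sketch, assuming `src/` is importable and that `evaluate` lives in `src/eval.py` as the call sites above suggest:

```
import torch
from src.args import parse_arguments
from src.eval import evaluate  # assumption: evaluate() is defined in src/eval.py

args = parse_arguments()
args.model = 'RN50'
args.data_location = './data'
args.save = f'./checkpoints/{args.model}'
args.eval_datasets = ['CIFAR10']  # same convention as in finetune()

# checkpoints store the full module, so torch.load restores it directly
image_encoder = torch.load(f'{args.save}/CIFAR10/finetuned.pt')
evaluate(image_encoder, args)
```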
/src/heads.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm import tqdm
4 | import open_clip
5 | from src.datasets.templates import get_templates
6 | from src.datasets.registry import get_dataset_classnames
7 | from src.modeling import ClassificationHead, ImageEncoder
8 |
9 | def build_classification_head(model, dataset_name, template, data_location, device):
10 |     template = get_templates(dataset_name)  # NOTE: overrides the `template` argument with the dataset's registered templates
11 |
12 | logit_scale = model.logit_scale
13 | classnames = get_dataset_classnames(
14 | dataset_name,
15 | None,
16 | location=data_location
17 | )
18 | model.eval()
19 | model.to(device)
20 |
21 | print('Building classification head.')
22 | with torch.no_grad():
23 | zeroshot_weights = []
24 | for classname in tqdm(classnames):
25 | texts = []
26 | for t in template:
27 | texts.append(t(classname))
28 | texts = open_clip.tokenize(texts).to(device) # tokenize
29 | embeddings = model.encode_text(texts) # embed with text encoder
30 | embeddings /= embeddings.norm(dim=-1, keepdim=True)
31 |
32 | embeddings = embeddings.mean(dim=0, keepdim=True)
33 | embeddings /= embeddings.norm()
34 |
35 | zeroshot_weights.append(embeddings)
36 |
37 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
38 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
39 |
40 | zeroshot_weights *= logit_scale.exp()
41 |
42 | zeroshot_weights = zeroshot_weights.squeeze().float()
43 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
44 |
45 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
46 |
47 | return classification_head
48 |
49 |
50 | def get_classification_head(args, dataset):
51 | filename = os.path.join(args.save, f'head_{dataset}.pt')
52 | if os.path.exists(filename):
53 | print(f'Classification head for {args.model} on {dataset} exists at {filename}')
54 | return ClassificationHead.load(filename)
55 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.')
56 | model = ImageEncoder(args, keep_lang=True).model
57 | template = get_templates(dataset)
58 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device)
59 | os.makedirs(args.save, exist_ok=True)
60 | classification_head.save(filename)
61 | return classification_head
62 |
63 | def get_classification_head_dev(args, model, dataset, flag):
64 |     # `flag` selects the checkpoint suffix: 'shadow' or 'target'
65 |     if flag not in ('shadow', 'target'):
66 |         raise ValueError(f'Unknown flag: {flag}')
67 |     filename = os.path.join(args.save, f'head_{dataset}_{flag}.pt')
68 |     if os.path.exists(filename):
69 |         print(f'Classification head for {args.model} on {dataset} exists at {filename}')
70 |         return ClassificationHead.load(filename)
71 |     print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.')
72 |     model_todo = model.model
73 |     template = get_templates(dataset)
74 |     classification_head = build_classification_head(model_todo, dataset, template, args.data_location, args.device)
75 |     os.makedirs(args.save, exist_ok=True)
76 |     classification_head.save(filename)
77 |     return classification_head
--------------------------------------------------------------------------------
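
For intuition, the head built above is a linear layer whose rows are prompt-ensembled CLIP text embeddings, scaled by the model's logit scale. A minimal standalone sketch of the same computation (toy classes and a single illustrative template; the real templates come from `src/datasets/templates.py`):

```
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms('RN50', pretrained='openai')
model.eval()

classnames = ['cat', 'dog']                   # toy classes
templates = [lambda c: f'a photo of a {c}.']  # illustrative single template

with torch.no_grad():
    rows = []
    for classname in classnames:
        tokens = open_clip.tokenize([t(classname) for t in templates])
        emb = model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        emb = emb.mean(dim=0)
        emb = emb / emb.norm()                      # re-normalize the prompt average
        rows.append(emb)
    weights = torch.stack(rows) * model.logit_scale.exp()  # (num_classes, embed_dim)

# logits for an L2-normalized image feature f are then weights @ f
```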
/src/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import open_clip
4 |
5 | import utils
6 | import math
7 |
8 |
9 | class ImageEncoder(torch.nn.Module):
10 | def __init__(self, args, keep_lang=False):
11 | super().__init__()
12 |
13 |         print(f'Creating {args.model} with OpenAI pre-trained weights.')
14 |
15 | self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
16 | args.model, pretrained='openai', cache_dir=args.openclip_cachedir) # pretrained=None
17 |
18 | self.cache_dir = args.cache_dir
19 |
20 | if not keep_lang and hasattr(self.model, 'transformer'):
21 | delattr(self.model, 'transformer')
22 |
23 | # self._initialize_weights()
24 |
25 | def _initialize_weights(self):
26 | for module in self.model.modules():
27 | if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
28 | torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
29 | if module.bias is not None:
30 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
31 | bound = 1 / math.sqrt(fan_in)
32 | torch.nn.init.uniform_(module.bias, -bound, bound)
33 |
34 | def forward(self, images):
35 | assert self.model is not None
36 | return self.model.encode_image(images)
37 |
38 | def __call__(self, inputs):
39 | return self.forward(inputs)
40 |
41 | def save(self, filename):
42 | print(f'Saving image encoder to {filename}')
43 | utils.torch_save(self, filename)
44 |
45 |     @classmethod
46 |     def load(cls, model_name, filename):
47 |         print(f'Loading image encoder from {filename}')
48 |         # `save` pickles the whole module, so loading restores it directly
49 |         return utils.torch_load(filename)
50 |
51 |     def load_from_state_dict(self, model_name, checkpoint_path):
52 |         print("Loading visual state dict from {}".format(checkpoint_path))
53 |         self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
54 |             model_name, pretrained='openai', device='cpu')
55 |         # the checkpoint stores only the visual tower's parameters
56 |         checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
57 |         self.model.visual.load_state_dict(checkpoint)
58 |         delattr(self.model, 'transformer')
59 |         print("Successfully loaded state dict!")
60 |
61 | class ClassificationHead(torch.nn.Linear):
62 | def __init__(self, normalize, weights, biases=None):
63 | output_size, input_size = weights.shape
64 | super().__init__(input_size, output_size)
65 | self.normalize = normalize
66 | if weights is not None:
67 | self.weight = torch.nn.Parameter(weights.clone())
68 | if biases is not None:
69 | self.bias = torch.nn.Parameter(biases.clone())
70 | else:
71 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
72 |
73 | def forward(self, inputs):
74 | if self.normalize:
75 | inputs = inputs / inputs.norm(dim=-1, keepdim=True)
76 | return super().forward(inputs)
77 |
78 | def __call__(self, inputs):
79 | return self.forward(inputs)
80 |
81 | def save(self, filename):
82 | print(f'Saving classification head to {filename}')
83 | utils.torch_save(self, filename)
84 |
85 | @classmethod
86 | def load(cls, filename):
87 | print(f'Loading classification head from {filename}')
88 | return utils.torch_load(filename)
89 |
90 |
91 | class ImageClassifier(torch.nn.Module):
92 | def __init__(self, image_encoder, classification_head):
93 | super().__init__()
94 | self.image_encoder = image_encoder
95 | self.classification_head = classification_head
96 | if self.image_encoder is not None:
97 | if hasattr(self.image_encoder, 'train_preprocess'):
98 | self.train_preprocess = self.image_encoder.train_preprocess
99 | self.val_preprocess = self.image_encoder.val_preprocess
100 | elif hasattr(self.image_encoder.model, 'train_preprocess'):
101 | self.train_preprocess = self.image_encoder.model.train_preprocess
102 | self.val_preprocess = self.image_encoder.model.val_preprocess
103 |
104 | def freeze_head(self):
105 | self.classification_head.weight.requires_grad_(False)
106 | self.classification_head.bias.requires_grad_(False)
107 |
108 | def forward(self, inputs):
109 | features = self.image_encoder(inputs)
110 | outputs = self.classification_head(features)
111 | return outputs
112 |
113 | def __call__(self, inputs):
114 | return self.forward(inputs)
115 |
116 | def save(self, filename):
117 | print(f'Saving image classifier to {filename}')
118 | utils.torch_save(self, filename)
119 |
120 | @classmethod
121 | def load(cls, filename):
122 | print(f'Loading image classifier from {filename}')
123 | return utils.torch_load(filename)
124 |
125 | class ImageClassifier_debug(torch.nn.Module):
126 | def __init__(self, image_encoder, image_encoder2, classification_head):
127 | super().__init__()
128 | self.image_encoder = image_encoder
129 | self.image_encoder2 = image_encoder2
130 | self.classification_head = classification_head
131 | if self.image_encoder is not None:
132 | self.train_preprocess = self.image_encoder.train_preprocess
133 | self.val_preprocess = self.image_encoder.val_preprocess
134 |
135 | def freeze_head(self):
136 | self.classification_head.weight.requires_grad_(False)
137 | self.classification_head.bias.requires_grad_(False)
138 |
139 | def forward(self, inputs):
140 | features = self.image_encoder(inputs)
141 | features2 = self.image_encoder2(inputs)
142 | outputs = self.classification_head(features + features2)
143 | return outputs
144 |
145 | def __call__(self, inputs):
146 | return self.forward(inputs)
147 |
148 | def save(self, filename):
149 | print(f'Saving image classifier to {filename}')
150 | utils.torch_save(self, filename)
151 |
152 | @classmethod
153 | def load(cls, filename):
154 | print(f'Loading image classifier from {filename}')
155 | return utils.torch_load(filename)
156 |
157 | class MultiHeadImageClassifier(torch.nn.Module):
158 | def __init__(self, image_encoder, classification_heads):
159 | super().__init__()
160 | self.image_encoder = image_encoder
161 | self.classification_heads = torch.nn.ModuleList(classification_heads)
162 | if self.image_encoder is not None:
163 | self.train_preprocess = self.image_encoder.train_preprocess
164 | self.val_preprocess = self.image_encoder.val_preprocess
165 |
166 | def freeze_head(self):
167 | for idx in range(len(self.classification_heads)):
168 | self.classification_heads[idx].weight.requires_grad_(False)
169 | self.classification_heads[idx].bias.requires_grad_(False)
170 |
171 | def forward(self, inputs, head_idx):
172 | features = self.image_encoder(inputs)
173 | outputs = self.classification_heads[head_idx](features)
174 | return outputs
175 |
176 | def __call__(self, inputs, head_idx):
177 | return self.forward(inputs, head_idx)
178 |
179 | def save(self, filename):
180 | print(f'Saving image classifier to {filename}')
181 | utils.torch_save(self, filename)
182 |
183 | @classmethod
184 | def load(cls, filename):
185 | print(f'Loading image classifier from {filename}')
186 | return utils.torch_load(filename)
187 |
--------------------------------------------------------------------------------
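
The wrappers above compose as encoder -> (normalized) linear head. A minimal sketch of wiring them together for a forward pass, assuming `args` carries the same fields as in `finetune_clean.py` and that the head file already exists under `args.save`:

```
import torch
from src.args import parse_arguments
from src.modeling import ImageEncoder, ImageClassifier
from src.heads import get_classification_head

args = parse_arguments()
args.model = 'ViT-B-32'
args.save = f'./checkpoints/{args.model}'
args.cache_dir = ''
args.openclip_cachedir = './open_clip'

encoder = ImageEncoder(args, keep_lang=False)    # vision tower only
head = get_classification_head(args, 'CIFAR10')  # zero-shot linear head
clf = ImageClassifier(encoder, head)
clf.freeze_head()                                # head stays fixed during fine-tuning

x = torch.randn(1, 3, 224, 224)                  # dummy image batch
with torch.no_grad():
    logits = clf(x)                              # encode_image -> normalized linear head
print(logits.shape)                              # (1, num_classes)
```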
/src/neulig_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | import sys
5 | import tqdm
6 | sys.path.append('.')
7 | sys.path.append('./src')
8 | from src.modeling import ImageEncoder
9 | from task_vectors import TaskVector
10 | # from eval import eval_single_dataset
11 | from args import parse_arguments
12 | from utils import *
13 | import torchvision.transforms as transforms
14 | from PIL import Image
15 | import time
16 | import torchvision.utils as vutils
17 | # from src.datasets.registry import get_dataset
18 | from src.heads import get_classification_head
19 | import torch
20 | from collections import Counter
21 | import torch.nn.functional as F
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | from src.datasets.common import get_dataloader, maybe_dictionarize
25 | import timm
26 | from itertools import cycle
27 | from modeling import ImageClassifier, ImageEncoder, ClassificationHead
28 | from open_clip import create_model_and_transforms
29 | from torch.utils.data import DataLoader, TensorDataset
30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31 |
32 | def merge_ckps(fusion_model, sample_weights):
33 |     # flatten each checkpoint, then take a weighted sum using the first sample's weights
34 |     tv_flat_checks = torch.vstack([state_dict_to_vector(check, []).to('cpu') for check in fusion_model.ckpts]).to('cpu')
35 |     final_ck = None
36 |     for j in range(fusion_model.num_models):
37 |         weighted_value = sample_weights[0, j].to('cpu') * tv_flat_checks[j]
38 |         if final_ck is None:
39 |             final_ck = weighted_value
40 |         else:
41 |             final_ck += weighted_value
42 |     final_ck = final_ck.to(device)
43 |     return final_ck
44 |
45 |
46 | def save_dataset_splits(train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset, save_dir):
47 | os.makedirs(save_dir, exist_ok=True)
48 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'wb') as f:
49 | pickle.dump(train_dataset, f)
50 |
51 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'wb') as f:
52 | pickle.dump(test_dataset, f)
53 |
54 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'wb') as f:
55 | pickle.dump(shadowtrain_dataset, f)
56 |
57 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'wb') as f:
58 | pickle.dump(shadowtest_dataset, f)
59 |
60 | def load_datasets(save_dir):
61 | with open(os.path.join(save_dir, 'train_dataset.pkl'), 'rb') as f:
62 | train_dataset = pickle.load(f)
63 |
64 | with open(os.path.join(save_dir, 'test_dataset.pkl'), 'rb') as f:
65 | test_dataset = pickle.load(f)
66 |
67 | with open(os.path.join(save_dir, 'shadowtrain_dataset.pkl'), 'rb') as f:
68 | shadowtrain_dataset = pickle.load(f)
69 |
70 | with open(os.path.join(save_dir, 'shadowtest_dataset.pkl'), 'rb') as f:
71 | shadowtest_dataset = pickle.load(f)
72 |
73 | return train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset
74 | def check_datasets_exist(save_dir):
75 | return (os.path.exists(os.path.join(save_dir, 'train_dataset.pkl')) and
76 | os.path.exists(os.path.join(save_dir, 'test_dataset.pkl')) and
77 | os.path.exists(os.path.join(save_dir, 'shadowtrain_dataset.pkl')) and
78 | os.path.exists(os.path.join(save_dir, 'shadowtest_dataset.pkl'))
79 | )
80 | def load_dataset_splits(save_dir):
81 | with open(os.path.join(save_dir, 'train_indices.pkl'), 'rb') as f:
82 | train_indices = pickle.load(f)
83 | with open(os.path.join(save_dir, 'test_indices.pkl'), 'rb') as f:
84 | test_indices = pickle.load(f)
85 | with open(os.path.join(save_dir, 'shadowtrain_indices.pkl'), 'rb') as f:
86 | shadowtrain_indices = pickle.load(f)
87 | with open(os.path.join(save_dir, 'shadowtest_indices.pkl'), 'rb') as f:
88 | shadowtest_indices = pickle.load(f)
89 |
90 | return train_indices, test_indices, shadowtrain_indices, shadowtest_indices
91 |
92 |
93 | def evaluate_ori(fusion_model, test_loaders, criterion, device):
94 | fusion_model.eval()
95 | total_loss = 0.0
96 | merged_total_loss = 0.0
97 | total_correct = []
98 | total_samples = []
99 | merged_total_correct = []
100 |
101 | with torch.no_grad():
102 | for loader_idx, test_loader in enumerate(test_loaders):
103 | cur_correct = 0
104 | cur_samples = 0
105 | merged_cur_correct = 0
106 | for i, data in enumerate(tqdm.tqdm(test_loader)):
107 | data = maybe_dictionarize(data)
108 | inputs = data['images'].to(device)
109 | labels = data['labels'].to(device)
110 |
111 | outputs, _ = fusion_model(inputs, dataset_index=loader_idx)
112 |
113 | model_outputs = []
114 |                 for model in fusion_model.models:
115 | model.eval()
116 | with torch.no_grad():
117 | output = model(inputs)
118 | model_outputs.append(output)
119 | weighting_model = fusion_model.get_weighting_model()
120 | stacked_outputs = torch.cat(model_outputs, dim=1)
121 | merge_weights = weighting_model(stacked_outputs)
122 |
123 | merged_checks = merge_ckps(fusion_model, merge_weights)
124 |                 merged_state_dict = vector_to_state_dict(merged_checks, ptm_check, remove_keys=[])  # ptm_check is a module-level global defined below
125 | image_encoder.load_state_dict(merged_state_dict, strict=False)
126 | image_encoder.to(device)
127 | merged_model = ImageClassifier(image_encoder, fusion_model.prediction_heads[loader_idx])
128 | merged_outputs = merged_model(inputs)
129 | loss = criterion(outputs, labels)
130 | total_loss += loss.item()
131 |
132 | merged_loss = criterion(merged_outputs, labels)
133 | merged_total_loss += merged_loss.item()
134 |
135 | cur_samples += labels.size(0)
136 |
137 | _, predicted = torch.max(outputs.data, 1)
138 | cur_correct += (predicted == labels).sum().item()
139 |
140 | _, merged_predicted = torch.max(merged_outputs.data, 1)
141 | merged_cur_correct += (merged_predicted == labels).sum().item()
142 |
143 | total_samples.append(cur_samples)
144 |
145 | total_correct.append(cur_correct)
146 | merged_total_correct.append(merged_cur_correct)
147 |
148 | accuracies = [100.0 * total_correct[i] / total_samples[i] for i in range(len(total_samples))]
149 | print("accuracy per task: ", accuracies)
150 | merged_accuracies = [100.0 * merged_total_correct[i] / total_samples[i] for i in range(len(total_samples))]
151 | print("merged_accuracy per task: ", merged_accuracies)
152 | avg_accuracy = sum(accuracies) / len(accuracies)
153 | avg_loss = total_loss / sum(total_samples)
154 |
155 | merged_avg_accuracy = sum(merged_accuracies) / len(merged_accuracies)
156 | merged_avg_loss = merged_total_loss / sum(total_samples)
157 |
158 | return avg_loss, avg_accuracy, merged_avg_loss, merged_avg_accuracy
159 |
160 | class WeightingModel(nn.Module):
161 | def __init__(self, input_dim=512, num_models=6):
162 | super(WeightingModel, self).__init__()
163 | self.num_models = num_models
164 | self.fc = nn.Linear(input_dim * num_models, num_models)
165 |     def forward(self, x):
166 |         # map concatenated model outputs to one softmax weight per model
167 |         logits = self.fc(x)
168 | weights = F.softmax(logits, dim=1)
169 | return weights
170 |
171 |
172 | class FusionModel(nn.Module):
173 | def __init__(self, ckpts, models, prediction_heads, input_dim=1024): # ViT-L-14: 768, RN50: 1024, ViT-B-32: 512
174 | super(FusionModel, self).__init__()
175 | self.models = models
176 | self.prediction_heads = prediction_heads
177 | self.num_models = len(models)
178 | self.weighting_model = WeightingModel(input_dim=input_dim, num_models=self.num_models)
179 |
180 | self.ckpts = ckpts
181 | self.flat_ft = torch.vstack([state_dict_to_vector(check, []).to('cpu') for check in self.ckpts]).to('cpu')
182 | self.mean_ft = torch.mean(self.flat_ft, dim=0)
183 | self.diff_ft = self.flat_ft - self.mean_ft.unsqueeze(0) # ksi
184 | self.sum_ft = torch.sum(self.diff_ft, dim=1).to(device)
185 |
186 | mean = torch.mean(self.sum_ft)
187 | std = torch.std(self.sum_ft)
188 | self.sum_ft = (self.sum_ft - mean) / std
189 |
190 | def forward(self, inputs, dataset_index):
191 | model_outputs = []
192 | self.weighting_model.train()
193 |         for model in self.models:
194 | model.eval()
195 | with torch.no_grad():
196 | output = model(inputs)
197 | model_outputs.append(output)
198 |
199 | stacked_outputs = torch.cat(model_outputs, dim=1)
200 |
201 | weights = self.weighting_model(stacked_outputs)
202 | reg_loss = torch.matmul(weights, self.sum_ft)
203 |
204 |         # ensemble: weighted sum of the frozen models' outputs
205 | 
206 | weighted_sum = 0
207 | for i in range(self.num_models):
208 | weighted_output = model_outputs[i] * weights[:, i].unsqueeze(1)
209 | weighted_sum += weighted_output
210 | final_output = self.prediction_heads[dataset_index](weighted_sum)
211 | return final_output, reg_loss
212 |
213 | def get_weighting_model(self):
214 | return self.weighting_model
215 |
216 | args = parse_arguments()
217 | args.save = './checkpoints/{}'.format(args.model)
218 |
219 | exam_datasets = ['GTSRB', 'CIFAR100', 'RESISC45', 'CIFAR10', 'MNIST', 'STL10', 'SVHN']
220 | num_classes = [43, 100, 45, 10, 10, 10, 10]
221 | use_merged_model = True
222 |
223 | classification_heads = [get_classification_head(args, dataset_name).to(device) for dataset_name in exam_datasets]
224 |
225 | import itertools
226 | exam_datasets_list = [list(comb) for comb in itertools.combinations(exam_datasets, args.num_co_models)]
227 | num_classes_list = [list(comb) for comb in itertools.combinations(num_classes, args.num_co_models)]
228 | classification_heads_list = [list(comb) for comb in itertools.combinations(classification_heads, args.num_co_models)]
229 |
230 | for mm in range(len(exam_datasets_list)):
231 | exam_datasets = exam_datasets_list[mm]
232 | num_classes = num_classes_list[mm]
233 | classification_heads = classification_heads_list[mm]
234 |
235 |     # checkpoints live under ./checkpoints/<model>
236 |     args.save = './checkpoints/{}'.format(args.model)
237 | pretrained_checkpoint = os.path.join(args.save, 'zeroshot.pt')
238 | image_encoder = torch.load(pretrained_checkpoint)
239 | image_encoder_shadow = torch.load(pretrained_checkpoint)
240 | ptm_check = torch.load(pretrained_checkpoint).state_dict()
241 |
242 | from tm_utils import *
243 | ft_checks, ft_checks_shadow = [], []
244 | ft_archs, ft_archs_shadow = [], []
245 |
246 | for dataset_name in exam_datasets:
247 | ckpt_name = os.path.join(args.save, dataset_name, 'finetuned.pt')
248 | ckpt_name_shadow = os.path.join(args.save, dataset_name, 'finetuned_dev.pt')
249 | ft_archs.append(torch.load(ckpt_name).to(device))
250 | ft_archs_shadow.append(torch.load(ckpt_name_shadow).to(device))
251 | ft_checks.append(torch.load(ckpt_name).state_dict())
252 | ft_checks_shadow.append(torch.load(ckpt_name_shadow).state_dict())
253 | print(ckpt_name)
254 | print(ckpt_name_shadow)
255 |
256 | if args.model == 'RN50':
257 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=1024)
258 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=1024)
259 | elif args.model == 'ViT-B-32':
260 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=512)
261 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=512)
262 | elif args.model == 'ViT-L-14':
263 | fusion_model = FusionModel(ft_checks, ft_archs, classification_heads, input_dim=768)
264 | fusion_model_shadow = FusionModel(ft_checks_shadow, ft_archs_shadow, classification_heads, input_dim=768)
265 | test_loaders, train_loaders, shadowtrain_loaders, shadowtest_loaders, adv_test_loaders = [], [], [], [], []
266 |
267 | for num_ld in range(len(exam_datasets)):
268 |         dataset_save_dir = os.path.join(args.save, exam_datasets[num_ld], "dataset_splits")
269 | print("cur_process_dataset: ", dataset_save_dir)
270 |         if check_datasets_exist(dataset_save_dir):  # splits are created by src/finetune_clean.py
271 |             print("Subsets already exist...")
272 | from torch.utils.data import DataLoader
273 |
274 | train_dataset, test_dataset, shadowtrain_dataset, shadowtest_dataset = load_datasets(dataset_save_dir)
275 |
276 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
277 |
278 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)
279 |
280 | shadowtrain_loader = DataLoader(shadowtrain_dataset, batch_size=args.batch_size, shuffle=True)
281 |
282 | shadowtest_loader = DataLoader(shadowtest_dataset, batch_size=args.batch_size, shuffle=True)
283 |
284 | test_loaders.append(test_loader)
285 | train_loaders.append(train_loader)
286 |
287 | print("dataset: {}, train_length: {}, test_length: {}, shadowtrain_length: {}, shadowtest_length: {}".format(exam_datasets[num_ld], len(train_dataset), len(test_dataset), len(shadowtrain_dataset), len(shadowtest_dataset)))
288 |
289 | fusion_model = fusion_model.to(device)
290 | fusion_model_shadow = fusion_model_shadow.to(device)
291 |
292 | optimizer = optim.Adam(fusion_model.weighting_model.parameters(), lr=0.001)
293 | optimizer_shadow = optim.Adam(fusion_model_shadow.weighting_model.parameters(), lr=0.001)
294 |
295 | criterion = nn.CrossEntropyLoss()
296 | criterion_reg = nn.MSELoss()
297 |
298 | print("#########################################################")
299 | print("###############PortLand Training Begins##################")
300 | print("#########################################################")
301 | avg_loss, accuracy, merged_avg_loss, merged_accuracy = evaluate_ori(fusion_model, test_loaders, criterion, device)
302 | print(f"Initial Evaluation - Avg Loss: {avg_loss:.4f}, Merged Avg Loss: {merged_avg_loss:.4f}, Ensembling Accuracy: {accuracy:.2f}%, Merging Accuracy: {merged_accuracy:.2f}%")
303 | best_accuracy = 0.0
304 | for glb_ep in range(args.global_epoch):
305 | fusion_model.train()
306 | loaders_cycle = [cycle(loader) for loader in train_loaders]
307 | total_batches = min(len(loader) for loader in train_loaders)
308 |
309 | for batch_idx in range(total_batches):
310 | for loader_idx, loader in enumerate(loaders_cycle):
311 | data = next(loader)
312 |
313 | data = maybe_dictionarize(data)
314 | inputs = data['images'].to(device)
315 | labels = data['labels'].to(device)
316 |
317 | outputs, reg_loss = fusion_model(inputs, dataset_index=loader_idx)
318 |
319 | target = torch.zeros_like(reg_loss).to(device)
320 | loss_reg = criterion_reg(reg_loss, target) / args.scaling
321 |
322 | if args.alignment_type == 'sup':
323 | loss_ce = criterion(outputs, labels)
324 | elif args.alignment_type == 'semi': # semi-supervised (entropy minimization)
325 | probs = F.softmax(outputs, dim=1)
326 | loss_ce = -torch.mean(torch.sum(probs * torch.log(probs + 1e-6), dim=1))
327 |
328 | loss = loss_ce + loss_reg
329 |
330 | print(f"Epoch: {glb_ep}, Current Dataset Index: {loader_idx}, Batch: {batch_idx + 1}/{total_batches}, Loss: {loss.item():.4f}")
331 |
332 | optimizer.zero_grad()
333 | loss.backward()
334 | optimizer.step()
335 |
336 |         if (glb_ep + 1) % 10 == 0:
337 | avg_loss, accuracy, merged_avg_loss, merged_accuracy = evaluate_ori(
338 | fusion_model, test_loaders, criterion, device
339 | )
340 | print(
341 | f"Epoch [{glb_ep + 1}/{args.global_epoch}] Evaluation - "
342 | f"Ensembling Avg Loss: {avg_loss:.4f}, Merging Avg Loss: {merged_avg_loss:.4f}, "
343 | f"Ensembling Accuracy: {accuracy:.2f}%, Merging Accuracy: {merged_accuracy:.2f}%"
344 | )
345 |
346 | if merged_accuracy > best_accuracy:
347 | best_accuracy = merged_accuracy
348 | print(f"New best model found with accuracy: {best_accuracy:.2f}%")
349 |
350 | print("#########################################################")
351 | print("################PortLand Training Ends###################")
352 | print("#########################################################")
353 |
354 |
--------------------------------------------------------------------------------
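
The training loop above uses the learned softmax weights in two ways: ensembling (a weighted sum of the frozen models' features through a prediction head) and merging (the same weights applied to the flattened checkpoints). A minimal sketch of the merge step in isolation, using the flatten/unflatten helpers from `tm_utils.py` (the weights here are hard-coded for illustration):

```
import torch
from src.tm_utils import state_dict_to_vector, vector_to_state_dict

sd_a = torch.load('./checkpoints/RN50/GTSRB/finetuned.pt').state_dict()
sd_b = torch.load('./checkpoints/RN50/MNIST/finetuned.pt').state_dict()
ptm = torch.load('./checkpoints/RN50/zeroshot.pt')  # pretrained encoder module

flat = torch.vstack([state_dict_to_vector(sd, []) for sd in (sd_a, sd_b)])
w = torch.tensor([0.6, 0.4])                        # illustrative merge weights
merged_vec = (w.unsqueeze(1) * flat).sum(dim=0)     # weighted parameter average

merged_sd = vector_to_state_dict(merged_vec, ptm.state_dict(), remove_keys=[])
ptm.load_state_dict(merged_sd, strict=False)        # ptm now holds the merged encoder
```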
/src/pgbar.py:
--------------------------------------------------------------------------------
1 | import os, shutil, torch
2 | import sys
3 | import time
4 |
5 | # fall back to 80 columns when stdout is not a TTY (stty would fail there)
6 | term_width = shutil.get_terminal_size(fallback=(80, 24)).columns
7 | TOTAL_BAR_LENGTH = 65.
8 | last_time = time.time()
9 | begin_time = last_time
10 |
11 |
12 | def progress_bar(current, total, msg=None):
13 | global last_time, begin_time
14 | if current == 0:
15 | begin_time = time.time() # Reset for new bar.
16 |
17 | cur_len = int(TOTAL_BAR_LENGTH * current / total)
18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
19 |
20 | sys.stdout.write(' [')
21 | for i in range(cur_len):
22 | sys.stdout.write('=')
23 | sys.stdout.write('>')
24 | for i in range(rest_len):
25 | sys.stdout.write('.')
26 | sys.stdout.write(']')
27 |
28 | cur_time = time.time()
29 | step_time = cur_time - last_time
30 | last_time = cur_time
31 | tot_time = cur_time - begin_time
32 |
33 | L = []
34 | L.append(' Step: %s' % format_time(step_time))
35 | L.append(' | Tot: %s' % format_time(tot_time))
36 | if msg:
37 | L.append(' | ' + msg)
38 |
39 | msg = ''.join(L)
40 | sys.stdout.write(msg)
41 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
42 | sys.stdout.write(' ')
43 |
44 | # Go back to the center of the bar.
45 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
46 | sys.stdout.write('\b')
47 | sys.stdout.write(' %d/%d ' % (current + 1, total))
48 |
49 | if current < total - 1:
50 | sys.stdout.write('\r')
51 | else:
52 | sys.stdout.write('\n')
53 | sys.stdout.flush()
54 |
55 |
56 | def format_time(seconds):
57 | days = int(seconds / 3600 / 24)
58 | seconds = seconds - days * 3600 * 24
59 | hours = int(seconds / 3600)
60 | seconds = seconds - hours * 3600
61 | minutes = int(seconds / 60)
62 | seconds = seconds - minutes * 60
63 | secondsf = int(seconds)
64 | seconds = seconds - secondsf
65 | millis = int(seconds * 1000)
66 |
67 | f = ''
68 | i = 1
69 | if days > 0:
70 | f += str(days) + 'D'
71 | i += 1
72 | if hours > 0 and i <= 2:
73 | f += str(hours) + 'h'
74 | i += 1
75 | if minutes > 0 and i <= 2:
76 | f += str(minutes) + 'm'
77 | i += 1
78 | if secondsf > 0 and i <= 2:
79 | f += str(secondsf) + 's'
80 | i += 1
81 | if millis > 0 and i <= 2:
82 | f += str(millis) + 'ms'
83 | i += 1
84 | if f == '':
85 | f = '0ms'
86 | return f
--------------------------------------------------------------------------------
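
Typical usage of the progress bar, for reference; `msg` is any status string (the loss value here is illustrative):

```
import time
from src.pgbar import progress_bar

n = 50
for i in range(n):
    time.sleep(0.01)                       # stand-in for real work
    progress_bar(i, n, msg='Loss: 0.123')  # redraws the bar in place
```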
/src/task_vectors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TaskVector():
5 | def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
6 | """Initializes the task vector from a pretrained and a finetuned checkpoints.
7 |
8 | This can either be done by passing two state dicts (one corresponding to the
9 | pretrained model, and another to the finetuned model), or by directly passying in
10 | the task vector state dict.
11 | """
12 | if vector is not None:
13 | self.vector = vector
14 | else:
15 | print(pretrained_checkpoint, finetuned_checkpoint)
16 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
17 | with torch.no_grad():
18 | print('TaskVector:' + finetuned_checkpoint)
19 | pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict()
20 | finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict()
21 | self.vector = {}
22 | for key in pretrained_state_dict:
23 | if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
24 | continue
25 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]
26 | print(len(self.vector))
27 |
28 | def __add__(self, other):
29 | """Add two task vectors together."""
30 | with torch.no_grad():
31 | new_vector = {}
32 | for key in self.vector:
33 | if key not in other.vector:
34 | print(f'Warning, key {key} is not present in both task vectors.')
35 | continue
36 | new_vector[key] = self.vector[key] + other.vector[key]
37 | return TaskVector(vector=new_vector)
38 |
39 | def __radd__(self, other):
40 | if other is None or isinstance(other, int):
41 | return self
42 | return self.__add__(other)
43 |
44 | def __neg__(self):
45 | """Negate a task vector."""
46 | with torch.no_grad():
47 | new_vector = {}
48 | for key in self.vector:
49 | new_vector[key] = - self.vector[key]
50 | return TaskVector(vector=new_vector)
51 |
52 | def weightmerging(self, taskvectors, coefficients):
53 | with torch.no_grad():
54 | new_vector = {}
55 | for key in taskvectors[0].vector:
56 |             new_vector[key] = sum(coefficients[k] * taskvectors[k].vector[key] for k in range(len(taskvectors)))
57 | return TaskVector(vector=new_vector)
58 |
59 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
60 | """Apply a task vector to a pretrained model."""
61 | with torch.no_grad():
62 | pretrained_model = torch.load(pretrained_checkpoint)
63 | new_state_dict = {}
64 | pretrained_state_dict = pretrained_model.state_dict()
65 | for key in pretrained_state_dict:
66 | if key not in self.vector:
67 | print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
68 | continue
69 | new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
70 | pretrained_model.load_state_dict(new_state_dict, strict=False)
71 | return pretrained_model
72 |
73 |
--------------------------------------------------------------------------------
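
A minimal sketch of task arithmetic with `TaskVector`, following the checkpoint layout used elsewhere in this repo:

```
from src.task_vectors import TaskVector

ptm = './checkpoints/RN50/zeroshot.pt'
tv_a = TaskVector(ptm, './checkpoints/RN50/GTSRB/finetuned.pt')
tv_b = TaskVector(ptm, './checkpoints/RN50/MNIST/finetuned.pt')

multi_task = tv_a + tv_b  # element-wise sum of the two parameter deltas
# theta = theta_ptm + 0.5 * (tv_a + tv_b), returned as a loaded model
merged_model = multi_task.apply_to(ptm, scaling_coef=0.5)
```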
/src/tm_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os, copy
3 | import torch
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import re
7 | from collections import OrderedDict
8 | import torch.nn.functional as F
9 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10 |
11 | ## Model conversion utils
12 | def state_dict_to_vector(state_dict, remove_keys=[]):
13 | shared_state_dict = copy.deepcopy(state_dict)
14 | for key in remove_keys:
15 | if key in shared_state_dict:
16 | del shared_state_dict[key]
17 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
18 | return torch.nn.utils.parameters_to_vector(
19 | [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
20 | )
21 |
22 |
23 | def vector_to_state_dict(vector, state_dict, remove_keys=[]):
24 | # create a reference dict to define the order of the vector
25 | reference_dict = copy.deepcopy(state_dict)
26 | for key in remove_keys:
27 | if key in reference_dict:
28 | del reference_dict[key]
29 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
30 |
31 |     # create a shared state dict using the reference dict
32 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
33 |
34 | # add back the encoder and decoder embedding weights.
35 | if "transformer.shared.weight" in sorted_reference_dict:
36 | for key in remove_keys:
37 | sorted_reference_dict[key] = sorted_reference_dict[
38 | "transformer.shared.weight"
39 | ]
40 | return sorted_reference_dict
41 |
42 |
43 | def add_ptm_to_tv(tv_dict, ptm_dict):
44 | assert set(tv_dict.keys()) == set(
45 | ptm_dict.keys()
46 | ), "Differing parameter names in models."
47 | final_dict = copy.deepcopy(tv_dict)
48 | for k, v in ptm_dict.items():
49 | final_dict[k] = tv_dict[k] + v
50 | return final_dict
51 |
52 |
53 | def check_parameterNamesMatch(checkpoints):
54 | parameter_names = set(checkpoints[0].keys())
55 |
56 | if len(checkpoints) >= 2:
57 |         # compare every checkpoint against the first one
58 | for checkpoint in checkpoints[1:]:
59 | current_parameterNames = set(checkpoint.keys())
60 | if current_parameterNames != parameter_names:
61 | raise ValueError(
62 | "Differing parameter names in models. "
63 | f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
64 | )
65 |
66 | def check_state_dicts_equal(state_dict1, state_dict2):
67 | if set(state_dict1.keys()) != set(state_dict2.keys()):
68 | return False
69 |
70 | for key in state_dict1.keys():
71 | if not torch.equal(state_dict1[key], state_dict2[key]):
72 | return False
73 |
74 | return True
75 |
76 |
77 |
78 | ## TIES MERGING UTILS
79 |
80 | def topk_values_mask(M, K=0.7, return_mask=False):
81 | if K > 1:
82 | K /= 100
83 |
84 | original_shape = M.shape
85 | if M.dim() == 1:
86 | M = M.unsqueeze(0)
87 |
88 | n, d = M.shape
89 | k = int(d * K)
90 | k = d - k # Keep top k elements instead of bottom k elements
91 |
92 | # Find the k-th smallest element by magnitude for each row
93 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
94 | # Create a mask tensor with True for the top k elements in each row
95 | mask = M.abs() >= kth_values
96 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
97 |
98 | if return_mask:
99 | return M * final_mask, final_mask.float().mean(dim=1), final_mask
100 | return M * final_mask, final_mask.float().mean(dim=1)
101 |
102 |
103 | def resolve_zero_signs(sign_to_mult, method="majority"):
104 | majority_sign = torch.sign(sign_to_mult.sum())
105 |
106 | if method == "majority":
107 | sign_to_mult[sign_to_mult == 0] = majority_sign
108 | elif method == "minority":
109 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
110 | return sign_to_mult
111 |
112 |
113 | def resolve_sign(Tensor):
114 | sign_to_mult = torch.sign(Tensor.sum(dim=0))
115 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
116 | return sign_to_mult
117 |
118 |
119 | def disjoint_merge(Tensor, merge_func, sign_to_mult):
120 | merge_func = merge_func.split("-")[-1]
121 |
122 | # If sign is provided then we select the corresponding entries and aggregate.
123 | if sign_to_mult is not None:
124 | rows_to_keep = torch.where(
125 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
126 | )
127 | selected_entries = Tensor * rows_to_keep
128 | # Else we select all non-zero entries and aggregate.
129 | else:
130 | rows_to_keep = Tensor != 0
131 | selected_entries = Tensor * rows_to_keep
132 |
133 | if merge_func == "mean":
134 | non_zero_counts = (selected_entries != 0).sum(dim=0).float()
135 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1)
136 | elif merge_func == "sum":
137 | disjoint_aggs = torch.sum(selected_entries, dim=0)
138 | elif merge_func == "max":
139 | disjoint_aggs = selected_entries.abs().max(dim=0)[0]
140 | disjoint_aggs *= sign_to_mult
141 | else:
142 | raise ValueError(f"Merge method {merge_func} is not defined.")
143 |
144 | return disjoint_aggs
145 |
146 |
147 | def ties_merging(
148 | flat_task_checks,
149 | reset_thresh=None,
150 | merge_func="",
151 | ):
152 | all_checks = flat_task_checks.clone()
153 | updated_checks, *_ = topk_values_mask(
154 | all_checks, K=reset_thresh, return_mask=False
155 | )
156 | print(f"RESOLVING SIGN")
157 | final_signs = resolve_sign(updated_checks)
158 | assert final_signs is not None
159 |
160 | print(f"Disjoint AGGREGATION: {merge_func}")
161 | merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
162 |
163 | return merged_tv
164 |
165 | def disjoint_merge_split(Tensor, merge_func, sign_to_mult):
166 | merge_func = merge_func.split("-")[-1]
167 |
168 | # If sign is provided then we select the corresponding entries and aggregate.
169 | if sign_to_mult is not None:
170 | rows_to_keep = torch.where(
171 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
172 | )
173 | selected_entries = Tensor * rows_to_keep
174 | # Else we select all non-zero entries and aggregate.
175 | else:
176 | rows_to_keep = Tensor != 0
177 | selected_entries = Tensor * rows_to_keep
178 |
179 | if merge_func == "sum":
180 | disjoint_aggs = torch.sum(selected_entries, dim=0)
181 | else:
182 | raise ValueError(f"Merge method {merge_func} is not defined.")
183 |
184 | return selected_entries, disjoint_aggs
185 |
186 |
187 | def ties_merging_split(
188 | flat_task_checks,
189 | reset_thresh=None,
190 | merge_func="",
191 | ):
192 | all_checks = flat_task_checks.clone()
193 | updated_checks, *_ = topk_values_mask(
194 | all_checks, K=reset_thresh, return_mask=False
195 | )
196 | print(f"RESOLVING SIGN")
197 | final_signs = resolve_sign(updated_checks)
198 | assert final_signs is not None
199 |
200 | print(f"Disjoint AGGREGATION: {merge_func}")
201 | selected_entries, merged_tv = disjoint_merge_split(updated_checks, merge_func, final_signs)
202 |
203 | return selected_entries, merged_tv
204 |
--------------------------------------------------------------------------------
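
A minimal sketch of `ties_merging` on toy data: trim each row to its largest-magnitude entries, resolve per-coordinate signs by majority, then aggregate the surviving entries:

```
import torch
from src.tm_utils import ties_merging

flat_task_checks = torch.randn(3, 10)  # 3 flattened task vectors, 10 params each
merged_tv = ties_merging(flat_task_checks, reset_thresh=20, merge_func='dis-mean')
print(merged_tv.shape)                 # torch.Size([10])
```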
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import math
5 | import numpy as np
6 | import torchvision
7 |
8 | class NormalizeInverse(torchvision.transforms.Normalize):
9 | def __init__(self, mean, std):
10 | mean = torch.as_tensor(mean)
11 | std = torch.as_tensor(std)
12 | std_inv = 1 / (std + 1e-7)
13 | mean_inv = -mean * std_inv
14 | super().__init__(mean=mean_inv, std=std_inv)
15 |
16 | def __call__(self, tensor):
17 | return super().__call__(tensor.clone())
18 |
19 | def corner_mask_generation(patch=None, image_size=(3, 224, 224)):
20 | applied_patch = np.zeros(image_size)
21 | x_location = image_size[1]-patch.shape[1]
22 | y_location = image_size[2]-patch.shape[2]
23 | applied_patch[:, x_location:x_location + patch.shape[1], y_location:y_location + patch.shape[2]] = patch
24 | mask = applied_patch.copy()
25 | mask[mask != 0] = 1.0
26 | return applied_patch, mask, x_location, y_location
27 |
28 | def assign_learning_rate(param_group, new_lr):
29 | param_group["lr"] = new_lr
30 |
31 |
32 | def _warmup_lr(base_lr, warmup_length, step):
33 | return base_lr * (step + 1) / warmup_length
34 |
35 |
36 | def cosine_lr(optimizer, base_lrs, warmup_length, steps):
37 | if not isinstance(base_lrs, list):
38 | base_lrs = [base_lrs for _ in optimizer.param_groups]
39 | assert len(base_lrs) == len(optimizer.param_groups)
40 | def _lr_adjuster(step):
41 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
42 | if step < warmup_length:
43 | lr = _warmup_lr(base_lr, warmup_length, step)
44 | else:
45 | e = step - warmup_length
46 | es = steps - warmup_length
47 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
48 | assign_learning_rate(param_group, lr)
49 | return _lr_adjuster
50 |
51 |
52 | def accuracy(output, target, topk=(1,)):
53 | pred = output.topk(max(topk), 1, True, True)[1].t()
54 | correct = pred.eq(target.view(1, -1).expand_as(pred))
55 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
56 |
57 |
58 | def torch_load_old(save_path, device=None):
59 | with open(save_path, 'rb') as f:
60 | classifier = pickle.load(f)
61 | if device is not None:
62 | classifier = classifier.to(device)
63 | return classifier
64 |
65 |
66 | def torch_save(model, save_path):
67 | if os.path.dirname(save_path) != '':
68 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
69 | torch.save(model.cpu(), save_path)
70 |
71 |
72 | def torch_load(save_path, device=None):
73 | model = torch.load(save_path)
74 | if device is not None:
75 | model = model.to(device)
76 | return model
77 |
78 |
79 |
80 | def get_logits(inputs, classifier):
81 | assert callable(classifier)
82 | if hasattr(classifier, 'to'):
83 | classifier = classifier.to(inputs.device)
84 | return classifier(inputs)
85 |
86 |
87 | def get_probs(inputs, classifier):
88 | if hasattr(classifier, 'predict_proba'):
89 | probs = classifier.predict_proba(inputs.detach().cpu().numpy())
90 | return torch.from_numpy(probs)
91 | logits = get_logits(inputs, classifier)
92 | return logits.softmax(dim=1)
93 |
94 |
95 | class LabelSmoothing(torch.nn.Module):
96 | def __init__(self, smoothing=0.0):
97 | super(LabelSmoothing, self).__init__()
98 | self.confidence = 1.0 - smoothing
99 | self.smoothing = smoothing
100 |
101 | def forward(self, x, target):
102 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
103 |
104 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
105 | nll_loss = nll_loss.squeeze(1)
106 | smooth_loss = -logprobs.mean(dim=-1)
107 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
108 | return loss.mean()
109 |
--------------------------------------------------------------------------------
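
A minimal usage sketch of `cosine_lr`: the returned adjuster is called with the global step before each optimizer step, exactly as in `finetune_clean.py`:

```
import torch
from src.utils import cosine_lr

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
total_steps, warmup = 100, 10
scheduler = cosine_lr(optimizer, 1e-3, warmup, total_steps)

for step in range(total_steps):
    scheduler(step)        # linear warmup, then cosine decay, per param group
    # ... forward / backward ...
    optimizer.step()
    optimizer.zero_grad()
```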