├── README.md ├── assets └── image.png ├── dataloader ├── DataSets.py ├── __init__.py ├── augmentations.py ├── class_maps.py ├── jigsaw │ ├── jigsaw_process.py │ ├── permutations_100.npy │ ├── permutations_1000.npy │ ├── permutations_30.npy │ └── permutations_31.npy ├── text_list_generator.py └── text_lists │ ├── MiniDomainNet │ ├── classes.txt │ ├── clipart_test.txt │ ├── clipart_test_shuffled.txt │ ├── clipart_train.txt │ ├── painting_test.txt │ ├── painting_test_shuffled.txt │ ├── painting_train.txt │ ├── real_test.txt │ ├── real_test_shuffled.txt │ ├── real_train.txt │ ├── sketch_test.txt │ ├── sketch_test_shuffled.txt │ └── sketch_train.txt │ ├── OfficeHome │ ├── Art_test.txt │ ├── Art_test_shuffled.txt │ ├── Art_train.txt │ ├── Clipart_test.txt │ ├── Clipart_test_shuffled.txt │ ├── Clipart_train.txt │ ├── Product_test.txt │ ├── Product_test_shuffled.txt │ ├── Product_train.txt │ ├── RealWorld_test.txt │ ├── RealWorld_test_shuffled.txt │ ├── RealWorld_train.txt │ └── classes.txt │ ├── PACS │ ├── art_painting_test.txt │ ├── art_painting_test_shuffled.txt │ ├── art_painting_train.txt │ ├── art_painting_val.txt │ ├── cartoon_test.txt │ ├── cartoon_test_shuffled.txt │ ├── cartoon_train.txt │ ├── cartoon_val.txt │ ├── classes.txt │ ├── photo_test.txt │ ├── photo_test_shuffled.txt │ ├── photo_train.txt │ ├── photo_val.txt │ ├── sketch_test.txt │ ├── sketch_test_shuffled.txt │ ├── sketch_train.txt │ └── sketch_val.txt │ └── VLCS │ ├── CALTECH_test.txt │ ├── CALTECH_test_shuffled.txt │ ├── CALTECH_train.txt │ ├── CALTECH_val.txt │ ├── LABELME_test.txt │ ├── LABELME_test_shuffled.txt │ ├── LABELME_train.txt │ ├── LABELME_val.txt │ ├── PASCAL_test.txt │ ├── PASCAL_test_shuffled.txt │ ├── PASCAL_train.txt │ ├── PASCAL_val.txt │ ├── SUN_test.txt │ ├── SUN_test_shuffled.txt │ ├── SUN_train.txt │ ├── SUN_val.txt │ └── classes.txt ├── framework ├── ERM.py ├── __init__.py ├── backbones.py ├── basic_train_funcs.py ├── engine.py ├── exp.py ├── log.py ├── loss_and_acc.py ├── 
meta_util.py └── registry.py ├── main.py ├── models ├── AdaptorHeads.py ├── AdaptorHelper.py ├── DomainAdaptor.py ├── LAME.py ├── MetaModel.py ├── TTA.py └── __init__.py ├── script ├── TTA.sh ├── TTA_jigsaw.sh ├── TTA_rotation.sh ├── deepall.sh └── meta.sh └── utils ├── draw_figures.py ├── tensor_utils.py └── visualize.py /README.md: -------------------------------------------------------------------------------- 1 | # DomainAdaptor 2 | 3 | The implementation of ICCV 2023 paper 《 [DomainAdaptor: A Novel Approach to Test-time Adaptation](https://arxiv.org/abs/2308.10297) 》 4 | 5 | ![](assets/image.png) 6 | 7 | ## Install packages 8 | 9 | ```bash 10 | conda install pytorch torchvision cudatoolkit 11 | conda install matplotlib tqdm tensorboardX 12 | ``` 13 | 14 | 15 | ## Dataset structure 16 | 17 | ``` 18 | PACS 19 | ├── kfold 20 | │ ├── art_painting 21 | │ ├── cartoon 22 | │ ├── photo 23 | │ └── sketch 24 | VLCS 25 | ├── CALTECH 26 | │ ├── crossval 27 | │ ├── full 28 | │ ├── test 29 | │ └── train 30 | ├── LABELME 31 | │ ├── crossval 32 | │ ├── full 33 | | ... 34 | OfficeHome 35 | ├── Art 36 | │ ├── Alarm_Clock 37 | │ ├── Backpack 38 | │ ├── Batteries 39 | │ ├── Bed 40 | │ ├── Bike 41 | │ ├── Bottle 42 | | ... 43 | ``` 44 | 45 | The data root can be modified in [main.py](main.py) or pase the args `--data-root your_data_root`. 46 | 47 | ## Run the code 48 | 49 | The code of DomainAdaptor is in [models/DomainAdaptor.py](models/DomainAdaptor.py). 50 | 51 | The pretrained deepall models are available at [Google Drive](https://drive.google.com/drive/folders/1Ne7FiEVv45JHJqELZ1F_0c5cknwyal40?usp=drive_link). 
class MetaDataLoader(DataLoader):
    """DataLoader that can be stepped manually with next()/__next__ and that
    transparently restarts a fresh epoch when the underlying iterator is exhausted.
    """

    def __init__(self, *args, **kwargs):
        super(MetaDataLoader, self).__init__(*args, **kwargs)
        self.iter = None  # lazily-created persistent iterator over self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next batch, restarting the epoch on exhaustion.

        Fix: the previous bare ``except:`` swallowed *every* exception
        (including KeyboardInterrupt and real data-loading errors) and silently
        restarted the loader; only StopIteration should trigger a restart.
        """
        if self.iter is None:
            self.iter = iter(self)
        try:
            ret = next(self.iter)
        except StopIteration:
            # epoch finished -> begin a new one and take its first batch
            self.iter = iter(self)
            ret = next(self.iter)
        return ret
ret.update({'domain_label': torch.tensor(domain).long()}) 97 | 98 | if self.args.TTAug: 99 | ret.update({'tta': self.augmentations['tta'].test_aug(origin_image, self.args.TTA_bs)}) 100 | 101 | if self.extra_aug_func_dict is not None: 102 | for key, func in self.extra_aug_func_dict.items(): 103 | ret.update({key: func(origin_image)}) 104 | 105 | if self.args.jigsaw: 106 | jigsaw_data, jigsaw_label = self.augmentations['jigsaw'](image_o) 107 | ret.update({'jigsaw_x': jigsaw_data, 'jigsaw_label': jigsaw_label}) 108 | 109 | if self.args.rot: 110 | rot_data, rot_label = self.augmentations['rot'](image_o) 111 | ret.update({'rot_x': rot_data, 'rot_label': rot_label}) 112 | 113 | if self.args.data_path: 114 | ret.update({'data_path': path}) 115 | return ret 116 | 117 | 118 | class BaseDatasetConfig(object): 119 | Name = 'Base' 120 | NumClasses = 0 121 | Domains = [] 122 | SplitRatio = -1 123 | RelativePath = '' 124 | Classes = None 125 | ClassOffset = 0 126 | ClassMaps = None 127 | 128 | def __init__(self, args): 129 | self.args = args 130 | self.root_dir = os.path.join(args.data_root, self.RelativePath) 131 | self.source_domains, self.target_domains = self.split_train_test_domains(args) 132 | self.aug_funcs = None 133 | self.to_memory = ToMemory 134 | self.loader_args = { 135 | 'pin_memory': True, 136 | # shuffle=True is required for Tent. If set False, the data imbalance problem will be encountered. 
def split_train_test_domains(self, args):
    """Partition ``self.Domains`` into source/target domain-name lists.

    If ``args.src[0]`` is not -1, the source/target domains are picked by the
    indices in ``args.src`` / ``args.tgt``; otherwise the single domain at
    ``args.exp_num[0]`` becomes the target and all remaining domains the source.
    Prints both lists and returns ``(source_domain, target_domain)``.
    """
    if args.src[0] == -1:
        # leave-one-domain-out: exp_num selects the held-out target
        source_domain = list(self.Domains)
        target_domain = [source_domain.pop(args.exp_num[0])]
    else:
        source_domain = [self.Domains[i] for i in args.src]
        target_domain = [self.Domains[j] for j in args.tgt]

    print('Source domain: ', end='')
    for domain in source_domain:
        print(domain, end=', ')
    print('Target domain: ', end='')
    for domain in target_domain:
        print(domain)
    return source_domain, target_domain
def random_split(self, dataset: 'DGDataset'):
    """Randomly split one DGDataset into (train, val) DGDatasets by ``self.SplitRatio``.

    Fix: the previous implementation wrapped ``dataset.samples`` in
    ``np.array(...)``, which coerces every field of the (path, class, domain)
    tuples to ``str`` (a NumPy array takes a single common dtype). We index the
    plain Python list instead, so class/domain labels stay ints and paths stay
    untouched; the random permutation and split point are unchanged.
    """
    n = len(dataset)
    cut = int(n * self.SplitRatio)  # first `cut` shuffled samples -> train
    perm = torch.randperm(n).tolist()
    samples = dataset.samples
    train_samples = [samples[i] for i in perm[:cut]]
    val_samples = [samples[i] for i in perm[cut:]]
    return (DGDataset(train_samples, 'train', self.args, self.aug_funcs),
            DGDataset(val_samples, 'val', self.args, self.aug_funcs))
def analyze_datasets(self, datasets):
    """Print the per-class sample counts aggregated over the given datasets.

    Each dataset is expected to expose ``samples`` as (path, class, domain)
    tuples; the printed list has ``self.NumClasses`` entries.
    """
    counts = [0] * self.NumClasses
    all_labels = (int(sample[1]) for dataset in datasets for sample in dataset.samples)
    for label in all_labels:
        counts[label] += 1
    print(counts)
268 | 269 | 270 | class DomainSampler(object): 271 | def __init__(self, concatedDataset, batch_size, replace=False, mvrml=False): 272 | assert isinstance(concatedDataset, ConcatDataset) 273 | self.domain_sizes = concatedDataset.cumulative_sizes 274 | self.domains = len(self.domain_sizes) 275 | self.batch_size = batch_size 276 | self.num_batches = self.domain_sizes[-1] // (batch_size * self.domains) 277 | self.domain_sizes = [0] + self.domain_sizes 278 | 279 | self.replace = replace # if replace is set to True, samples are put back 280 | self.sample_temperature = 0 # equally sampled from all domains 281 | self.mvrml = mvrml 282 | if self.mvrml: 283 | print("Training with mvrml!!!") 284 | 285 | def __iter__(self): 286 | domains = range(len(self.domain_sizes) - 1) 287 | 288 | real_batch_size = self.batch_size * self.domains 289 | domain_prob = (torch.rand(self.domains) * self.sample_temperature).softmax(0) 290 | batch_sizes = [int(p.item() * real_batch_size) for p in domain_prob] 291 | left_samples = real_batch_size - np.sum(batch_sizes) 292 | if left_samples > 0: 293 | batch_sizes[-1] += left_samples 294 | 295 | for iter_idx in range(self.num_batches): 296 | rand_domains = list(np.random.choice(domains, size=len(domains), replace=self.replace)) 297 | # if self.mvrml: 298 | # rand_domains += [rand_domains[-1]] 299 | sampled_idx = [np.random.choice(range(self.domain_sizes[idx], self.domain_sizes[idx + 1]), size=batch_sizes[idx], replace=False) 300 | for idx in rand_domains] 301 | sampled_idx = np.concatenate(sampled_idx) 302 | yield sampled_idx 303 | 304 | def __len__(self): 305 | return self.num_batches 306 | 307 | 308 | class InterleavedSampler(object): 309 | def __init__(self, concatedDataset, batch_size): 310 | assert isinstance(concatedDataset, ConcatDataset) 311 | self.original_dataset = concatedDataset 312 | self.batch_size = batch_size 313 | 314 | def init(self): 315 | self.datasets = copy.deepcopy(self.original_dataset.datasets) 316 | self.original_counts = 
@Datasets.register("PACS")
class PACS(BaseDatasetConfig):
    # PACS (photo / art_painting / cartoon / sketch) dataset configuration.
    # SplitRatio = -1 means the official train/val text lists are used instead
    # of a random split (see BaseDatasetConfig.get_datasets); the official
    # kfold split is roughly 0.9 train vs 0.1 val.
    Name = 'PACS'
    NumClasses = 7
    SplitRatio = -1
    RelativePath = 'PACS/kfold'
    Domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    ClassOffset = -1  # text_lists start from 1 not 0
    Classes = {i : k for i, k in enumerate(['dog', 'elephant', 'giraffe', 'guitar','horse', 'house', 'person'])}
def import_all_modules_in_current_folders():
    """Import every sibling ``.py`` module of this package (except __init__)
    so that their module-level registration decorators run."""
    loaded = []
    for fname in os.listdir(os.path.dirname(__file__)):
        if not fname.endswith('.py') or fname == '__init__.py':
            continue
        # the leading '.' makes this a relative import within this package
        importlib.import_module('.' + fname[:-3], __package__)
        loaded.append(fname)
    print('Successfully imported modules : ', loaded)
def SolarizeAdd(img, addition=0, threshold=128):
    """Shift all pixel values by ``addition`` (clipped to [0, 255]), then
    solarize: pixels at or above ``threshold`` are inverted.

    Fix: ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24, so
    ``astype(np.int)`` raises AttributeError on modern NumPy. The builtin
    ``int`` gives the same platform integer dtype and is the documented
    replacement.
    """
    img_np = np.array(img).astype(int)  # widen before adding to avoid uint8 wrap-around
    img_np = np.clip(img_np + addition, 0, 255).astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)
class Lighting(object):
    """Lighting noise (AlexNet-style PCA-based color noise).

    eigval / eigvec are the PCA eigenvalues and eigenvectors of the training
    set's RGB pixel distribution; each call perturbs the image along the
    principal color axes with per-channel gaussian strength ``alphastd``.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        # alphastd == 0 disables the augmentation entirely
        if self.alphastd == 0:
            return img

        # per-channel gaussian coefficients (same RNG op as before)
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        weights = alpha.view(1, 3).expand(3, 3) * self.eigval.view(1, 3).expand(3, 3)
        rgb = (self.eigvec.type_as(img).clone() * weights).sum(1).squeeze()

        # broadcast the per-channel offset over the (3, H, W) image
        return img.add(rgb.view(3, 1, 1).expand_as(img))
class RandAugment:
    """RandAugment-style augmentation: apply ``n`` randomly chosen PIL ops at
    magnitude ``m``, then normalize to a tensor.

    NOTE(review): the ``n``/``m`` stored in __init__ are never read by
    ``__call__`` (which takes its own n/m); only ``aug_batch`` and
    ``aug_sequential`` have their own defaults — confirm this is intended.
    """

    def __init__(self, n=4, m=5):
        self.n = n
        self.m = m
        self.augment_list = augment_list()
        from torchvision import transforms
        # ImageNet mean/std normalization applied after the PIL-level ops
        self.post_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __call__(self, img, n, m):
        # n <= 0 disables augmentation: the raw PIL image is returned
        # (NOT normalized / converted to tensor)
        if n <= 0:
            return img
        idxes = [random.randint(0, len(self.augment_list) - 1) for i in range(n)]
        return self.post_transform(self.aug_img(idxes, img, m))

    def aug_img(self, idxes, img, m):
        # Apply the ops selected by `idxes` in order.
        # m == -1 draws a fresh random magnitude for every op.
        for idx in idxes:
            op, minval, maxval = self.augment_list[idx]
            if m == -1:
                new_m = np.random.randint(0, 30)
            else:
                new_m = m
            # map magnitude in [0, 30) linearly onto the op's [minval, maxval]
            val = (float(new_m) / 30) * float(maxval - minval) + minval
            img = op(img, val)
        return img

    def aug_batch(self, batch, n=5, m=4):
        # Augment each image with its own independently drawn ops (PIL in, PIL out).
        imgs = []
        for img in batch:
            idxes = [random.randint(0, len(self.augment_list) - 1) for i in range(n)]
            img = self.aug_img(idxes, img, m)
            imgs.append(img)
        return imgs

    def aug_sequential(self, img, n=5, m=4):
        # Return [img, img+op1, img+op1+op2, ...]: n+1 progressively augmented images.
        imgs = [img]
        idxes = [random.randint(0, len(self.augment_list) - 1) for i in range(n)]
        for idx in idxes:
            img = self.aug_img([idx], img, m)
            imgs.append(img)
        return imgs
class Rotation(object):
    """Rotate image tensors by multiples of 90 degrees over the last two
    (spatial) dims — used as a self-supervised rotation-prediction task."""

    def tensor_rot_90(self, x):
        return x.transpose(-2, -1).flip(-2)

    def tensor_rot_180(self, x):
        return x.flip(-2).flip(-1)

    def tensor_rot_270(self, x):
        return x.flip(-2).transpose(-2, -1)

    def rotate_batch_with_labels(self, batch, labels):
        # label 0 keeps the image unchanged; 1/2/3 rotate by 90/180/270 degrees
        ops = {1: self.tensor_rot_90, 2: self.tensor_rot_180, 3: self.tensor_rot_270}
        rotated = []
        for img, label in zip(batch, labels):
            op = ops.get(int(label))
            rotated.append(img if op is None else op(img))
        return torch.stack(rotated)

    def __call__(self, img, rot_type='rand'):
        """Rotate a single image; returns (rotated_image, rotation_label)."""
        data, label = self.rotate_batch(img.unsqueeze(0), rot_type)
        return data[0], label[0]

    def rotate_batch(self, batch, rot_type='rand'):
        """Rotate a batch of images.

        rot_type: 'rand' draws a random label per image, 'expand' replicates
        the batch 4x with all rotations (labels [0]*N + [1]*N + ...), and an
        int applies that fixed rotation to every image.
        """
        if rot_type == 'rand':
            labels = torch.randint(4, (len(batch),), dtype=torch.long)
        elif rot_type == 'expand':
            labels = torch.arange(4, dtype=torch.long).repeat_interleave(len(batch))
            batch = batch.repeat((4, 1, 1, 1))
        else:
            assert isinstance(rot_type, int)
            labels = torch.full((len(batch),), rot_type, dtype=torch.long)
        return self.rotate_batch_with_labels(batch, labels), labels.to(batch.device)
'15': 'barn', '16': 'baseball', '17': 'baseball_bat', '18': 'basket', '19': 'basketball', '20': 'bat', '21': 'bathtub', 16 | '22': 'beach', '23': 'bear', '24': 'beard', '25': 'bed', '26': 'bee', '27': 'belt', '28': 'bench', '29': 'bicycle', 17 | '30': 'binoculars', '31': 'bird', '32': 'birthday_cake', '33': 'blackberry', '34': 'blueberry', '35': 'book', '36': 'boomerang', 18 | '37': 'bottlecap', '38': 'bowtie', '39': 'bracelet', '40': 'brain', '41': 'bread', '42': 'bridge', '43': 'broccoli', '44': 'broom', 19 | '45': 'bucket', '46': 'bulldozer', '47': 'bus', '48': 'bush', '49': 'butterfly', '50': 'cactus', '51': 'cake', '52': 'calculator', 20 | '53': 'calendar', '54': 'camel', '55': 'camera', '56': 'camouflage', '57': 'campfire', '58': 'candle', '59': 'cannon', 21 | '60': 'canoe', '61': 'car', '62': 'carrot', '63': 'castle', '64': 'cat', '65': 'ceiling_fan', '66': 'cello', '67': 'cell_phone', 22 | '68': 'chair', '69': 'chandelier', '70': 'church', '71': 'circle', '72': 'clarinet', '73': 'clock', '74': 'cloud', 23 | '75': 'coffee_cup', '76': 'compass', '77': 'computer', '78': 'cookie', '79': 'cooler', '80': 'couch', '81': 'cow', '82': 'crab', 24 | '83': 'crayon', '84': 'crocodile', '85': 'crown', '86': 'cruise_ship', '87': 'cup', '88': 'diamond', '89': 'dishwasher', 25 | '90': 'diving_board', '91': 'dog', '92': 'dolphin', '93': 'donut', '94': 'door', '95': 'dragon', '96': 'dresser', '97': 'drill', 26 | '98': 'drums', '99': 'duck', '100': 'dumbbell', '101': 'ear', '102': 'elbow', '103': 'elephant', '104': 'envelope', '105': 'eraser', 27 | '106': 'eye', '107': 'eyeglasses', '108': 'face', '109': 'fan', '110': 'feather', '111': 'fence', '112': 'finger', 28 | '113': 'fire_hydrant', '114': 'fireplace', '115': 'firetruck', '116': 'fish', '117': 'flamingo', '118': 'flashlight', 29 | '119': 'flip_flops', '120': 'floor_lamp', '121': 'flower', '122': 'flying_saucer', '123': 'foot', '124': 'fork', '125': 'frog', 30 | '126': 'frying_pan', '127': 'garden', '128': 'garden_hose', 
'129': 'giraffe', '130': 'goatee', '131': 'golf_club', '132': 'grapes', 31 | '133': 'grass', '134': 'guitar', '135': 'hamburger', '136': 'hammer', '137': 'hand', '138': 'harp', '139': 'hat', 32 | '140': 'headphones', '141': 'hedgehog', '142': 'helicopter', '143': 'helmet', '144': 'hexagon', '145': 'hockey_puck', 33 | '146': 'hockey_stick', '147': 'horse', '148': 'hospital', '149': 'hot_air_balloon', '150': 'hot_dog', '151': 'hot_tub', 34 | '152': 'hourglass', '153': 'house', '154': 'house_plant', '155': 'hurricane', '156': 'ice_cream', '157': 'jacket', '158': 'jail', 35 | '159': 'kangaroo', '160': 'key', '161': 'keyboard', '162': 'knee', '163': 'knife', '164': 'ladder', '165': 'lantern', 36 | '166': 'laptop', '167': 'leaf', '168': 'leg', '169': 'light_bulb', '170': 'lighter', '171': 'lighthouse', '172': 'lightning', 37 | '173': 'line', '174': 'lion', '175': 'lipstick', '176': 'lobster', '177': 'lollipop', '178': 'mailbox', '179': 'map', 38 | '180': 'marker', '181': 'matches', '182': 'megaphone', '183': 'mermaid', '184': 'microphone', '185': 'microwave', '186': 'monkey', 39 | '187': 'moon', '188': 'mosquito', '189': 'motorbike', '190': 'mountain', '191': 'mouse', '192': 'moustache', '193': 'mouth', 40 | '194': 'mug', '195': 'mushroom', '196': 'nail', '197': 'necklace', '198': 'nose', '199': 'ocean', '200': 'octagon', 41 | '201': 'octopus', '202': 'onion', '203': 'oven', '204': 'owl', '205': 'paintbrush', '206': 'paint_can', '207': 'palm_tree', 42 | '208': 'panda', '209': 'pants', '210': 'paper_clip', '211': 'parachute', '212': 'parrot', '213': 'passport', '214': 'peanut', 43 | '215': 'pear', '216': 'peas', '217': 'pencil', '218': 'penguin', '219': 'piano', '220': 'pickup_truck', '221': 'picture_frame', 44 | '222': 'pig', '223': 'pillow', '224': 'pineapple', '225': 'pizza', '226': 'pliers', '227': 'police_car', '228': 'pond', 45 | '229': 'pool', '230': 'popsicle', '231': 'postcard', '232': 'potato', '233': 'power_outlet', '234': 'purse', '235': 'rabbit', 46 | '236': 
'raccoon', '237': 'radio', '238': 'rain', '239': 'rainbow', '240': 'rake', '241': 'remote_control', '242': 'rhinoceros', 47 | '243': 'rifle', '244': 'river', '245': 'roller_coaster', '246': 'rollerskates', '247': 'sailboat', '248': 'sandwich', '249': 'saw', 48 | '250': 'saxophone', '251': 'school_bus', '252': 'scissors', '253': 'scorpion', '254': 'screwdriver', '255': 'sea_turtle', 49 | '256': 'see_saw', '257': 'shark', '258': 'sheep', '259': 'shoe', '260': 'shorts', '261': 'shovel', '262': 'sink', 50 | '263': 'skateboard', '264': 'skull', '265': 'skyscraper', '266': 'sleeping_bag', '267': 'smiley_face', '268': 'snail', 51 | '269': 'snake', '270': 'snorkel', '271': 'snowflake', '272': 'snowman', '273': 'soccer_ball', '274': 'sock', '275': 'speedboat', 52 | '276': 'spider', '277': 'spoon', '278': 'spreadsheet', '279': 'square', '280': 'squiggle', '281': 'squirrel', '282': 'stairs', 53 | '283': 'star', '284': 'steak', '285': 'stereo', '286': 'stethoscope', '287': 'stitches', '288': 'stop_sign', '289': 'stove', 54 | '290': 'strawberry', '291': 'streetlight', '292': 'string_bean', '293': 'submarine', '294': 'suitcase', '295': 'sun', '296': 'swan', 55 | '297': 'sweater', '298': 'swing_set', '299': 'sword', '300': 'syringe', '301': 'table', '302': 'teapot', '303': 'teddy-bear', 56 | '304': 'telephone', '305': 'television', '306': 'tennis_racquet', '307': 'tent', '308': 'The_Eiffel_Tower', 57 | '309': 'The_Great_Wall_of_China', '310': 'The_Mona_Lisa', '311': 'tiger', '312': 'toaster', '313': 'toe', '314': 'toilet', 58 | '315': 'tooth', '316': 'toothbrush', '317': 'toothpaste', '318': 'tornado', '319': 'tractor', '320': 'traffic_light', 59 | '321': 'train', '322': 'tree', '323': 'triangle', '324': 'trombone', '325': 'truck', '326': 'trumpet', '327': 't-shirt', 60 | '328': 'umbrella', '329': 'underwear', '330': 'van', '331': 'vase', '332': 'violin', '333': 'washing_machine', '334': 'watermelon', 61 | '335': 'waterslide', '336': 'whale', '337': 'wheel', '338': 'windmill', 
'339': 'wine_bottle', '340': 'wine_glass', 62 | '341': 'wristwatch', '342': 'yoga', '343': 'zebra', '344': 'zigzag'} 63 | 64 | OfficeHome31 = {'5': 'calculator', '24': 'ring_binder', '21': 'printer', '11': 'keyboard', '26': 'scissors', '12': 'laptop_computer', 65 | '16': 'mouse', '15': 'monitor', '17': 'mug', '29': 'tape_dispenser', '19': 'pen', '1': 'bike', '23': 'punchers', 66 | '0': 'back_pack', '8': 'desktop_computer', '27': 'speaker', '14': 'mobile_phone', '18': 'paper_notebook', '25': 'ruler', 67 | '13': 'letter_tray', '9': 'file_cabinet', '20': 'phone', '3': 'bookcase', '22': 'projector', '28': 'stapler', '30': 'trash_can', 68 | '2': 'bike_helmet', '10': 'headphones', '7': 'desk_lamp', '6': 'desk_chair', '4': 'bottle'} 69 | Visda17 = {'0': 'aeroplane', '1': 'bicycle', '2': 'bus', '3': 'car', '4': 'horse', '5': 'knife', '6': 'motorcycle', '7': 'person', '8': 'plant', 70 | '9': 'skateboard', '10': 'train', '11': 'truck'} 71 | 72 | 73 | if __name__ == '__main__': 74 | dataset = digits_dg 75 | idx = [int(k) for k in dataset.keys()] 76 | for i in idx: 77 | print(dataset[str(i)]) 78 | -------------------------------------------------------------------------------- /dataloader/jigsaw/jigsaw_process.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | 7 | 8 | class JigsawDataset(): 9 | def __init__(self, jig_classes=31, bias_whole_image=None): 10 | self.permutations = self.__retrieve_permutations(jig_classes) 11 | self.grid_size = 3 12 | self.bias_whole_image = bias_whole_image 13 | self.make_grid = lambda x: torchvision.utils.make_grid(x, self.grid_size, padding=0) 14 | 15 | def __call__(self, img): 16 | n_grids = self.grid_size ** 2 17 | tiles = [None] * n_grids 18 | for n in range(n_grids): 19 | tiles[n] = self.get_tile(img, n) 20 | 21 | order = np.random.randint(len(self.permutations) + 1) # added 1 for class 0: unsorted 22 | if 
self.bias_whole_image: 23 | if self.bias_whole_image > random(): 24 | order = 0 25 | 26 | if order == 0: 27 | data = tiles 28 | else: 29 | data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)] 30 | 31 | data = torch.stack(data, 0) 32 | return self.make_grid(data), int(order) 33 | # return {'jigsaw': self.make_grid(data), 'jigsaw_label': int(order)} 34 | 35 | def get_tile(self, img, n): 36 | w = int(img.shape[-1] / self.grid_size) 37 | y = int(n / self.grid_size) 38 | x = int(n % self.grid_size) 39 | tile = img[:, y * w:(y + 1) * w, x * w:(x + 1) * w] 40 | # tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 41 | return tile 42 | 43 | def __retrieve_permutations(self, classes): 44 | all_perm = np.load('dataloader/jigsaw/permutations_%d.npy' % (classes)) 45 | # from range [1,9] to [0,8] 46 | if all_perm.min() == 1: 47 | all_perm = all_perm - 1 48 | return all_perm 49 | -------------------------------------------------------------------------------- /dataloader/jigsaw/permutations_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koncle/DomainAdaptor/5c3d24efb465a56a9a8b57454ea0d23b4ed819c8/dataloader/jigsaw/permutations_100.npy -------------------------------------------------------------------------------- /dataloader/jigsaw/permutations_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koncle/DomainAdaptor/5c3d24efb465a56a9a8b57454ea0d23b4ed819c8/dataloader/jigsaw/permutations_1000.npy -------------------------------------------------------------------------------- /dataloader/jigsaw/permutations_30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koncle/DomainAdaptor/5c3d24efb465a56a9a8b57454ea0d23b4ed819c8/dataloader/jigsaw/permutations_30.npy -------------------------------------------------------------------------------- 
/dataloader/jigsaw/permutations_31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koncle/DomainAdaptor/5c3d24efb465a56a9a8b57454ea0d23b4ed819c8/dataloader/jigsaw/permutations_31.npy -------------------------------------------------------------------------------- /dataloader/text_list_generator.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from utils.tensor_utils import mkdir 4 | 5 | 6 | def modify_image_list_paths(path): 7 | file_dir = Path(path) 8 | for i, file in enumerate(file_dir.iterdir()): 9 | pre = str(file).split('/')[-1].split('_')[0] + '/' # + '/images/' 10 | with open(str(file).split('.')[0] + '_.txt', 'w') as fw: 11 | with open(str(file), 'r') as fr: 12 | for line in fr.readlines(): 13 | a = pre + line 14 | fw.write(a) 15 | 16 | 17 | def get_class_path_map_from_txt(path, class_idx=2): 18 | file_dir = Path(path) 19 | class_maps = {} 20 | with open(str(file_dir), 'r') as f: 21 | for line in f.readlines(): 22 | path, claz = line.strip().split(' ') 23 | # target = path.split('/')[class_idx] 24 | if claz in class_maps: 25 | class_maps[claz].append(path) 26 | else: 27 | class_maps[claz] = [path] 28 | return class_maps 29 | 30 | 31 | def get_dataset_class_map_in_folder(folder_path, class_idx): 32 | # only txt contains 'train' and 'test' are included in class_map 33 | file_dir = Path(folder_path) 34 | for file in file_dir.iterdir(): 35 | if not ('train' in str(file) or 'test' in str(file)): 36 | continue 37 | clas_maps = get_class_path_map_from_txt(file, class_idx) 38 | class_num = max([int(i) for i in list(clas_maps.keys())]) + 1 39 | print(clas_maps) 40 | print('Class num : {}'.format(class_num)) 41 | break 42 | 43 | 44 | def shuffle_text(path, to_path): 45 | import numpy as np 46 | 47 | with Path(path).open('r') as f: 48 | lines = f.readlines() 49 | lines = np.random.permutation(lines) 50 | 51 | with 
Path(to_path).open('w') as f: 52 | for l in lines: 53 | f.write(l) 54 | 55 | 56 | def generate_shuffled_text(dataset_path): 57 | dataset_path = Path(dataset_path) 58 | for file in dataset_path.iterdir(): 59 | if file.name.endswith('test.txt'): 60 | target_file_name = str(file)[:-4] + '_shuffled.txt' 61 | shuffle_text(str(file), target_file_name) 62 | print(f'generated {target_file_name}') 63 | 64 | 65 | def generate_text_list_no_val(folder, save_path): 66 | folder, save_path = Path(folder), Path(save_path) 67 | mkdir(save_path) 68 | for domain in folder.iterdir(): 69 | if not domain.is_dir(): 70 | continue 71 | text_list = [] 72 | for i, class_folder in enumerate(domain.iterdir()): 73 | for img in class_folder.iterdir(): 74 | text_list.append(f'{domain.name}/{class_folder.name}/{img.name} {i}\n') 75 | 76 | with open((save_path / f'{domain.name}_train.txt').absolute(), 'w') as f: 77 | f.writelines(text_list) 78 | 79 | with open((save_path / f'{domain.name}_test.txt').absolute(), 'w') as f: 80 | f.writelines(text_list) 81 | print(f'Writed {domain.name}') 82 | 83 | 84 | if __name__ == '__main__': 85 | generate_text_list_no_val('/data/DataSets/PACS', 'text_lists/PACS') 86 | folder = Path('./text_lists') 87 | # for dataset in folder.iterdir(): 88 | # if 'PACS' not in str(dataset): 89 | # continue 90 | generate_shuffled_text('text_lists/ColoredMNIST') 91 | -------------------------------------------------------------------------------- /dataloader/text_lists/MiniDomainNet/classes.txt: -------------------------------------------------------------------------------- 1 | aircraft_carrier 2 | airplane 3 | alarm_clock 4 | ambulance 5 | angel 6 | animal_migration 7 | ant 8 | anvil 9 | apple 10 | arm 11 | asparagus 12 | axe 13 | backpack 14 | banana 15 | bandage 16 | barn 17 | baseball 18 | baseball_bat 19 | basket 20 | basketball 21 | bat 22 | bathtub 23 | beach 24 | bear 25 | beard 26 | bed 27 | bee 28 | belt 29 | bench 30 | bicycle 31 | binoculars 32 | bird 33 | 
birthday_cake 34 | blackberry 35 | blueberry 36 | book 37 | boomerang 38 | bottlecap 39 | bowtie 40 | bracelet 41 | brain 42 | bread 43 | bridge 44 | broccoli 45 | broom 46 | bucket 47 | bulldozer 48 | bus 49 | bush 50 | butterfly 51 | cactus 52 | cake 53 | calculator 54 | calendar 55 | camel 56 | camera 57 | camouflage 58 | campfire 59 | candle 60 | cannon 61 | canoe 62 | car 63 | carrot 64 | castle 65 | cat 66 | ceiling_fan 67 | cello 68 | cell_phone 69 | chair 70 | chandelier 71 | church 72 | circle 73 | clarinet 74 | clock 75 | cloud 76 | coffee_cup 77 | compass 78 | computer 79 | cookie 80 | cooler 81 | couch 82 | cow 83 | crab 84 | crayon 85 | crocodile 86 | crown 87 | cruise_ship 88 | cup 89 | diamond 90 | dishwasher 91 | diving_board 92 | dog 93 | dolphin 94 | donut 95 | door 96 | dragon 97 | dresser 98 | drill 99 | drums 100 | duck 101 | dumbbell 102 | ear 103 | elbow 104 | elephant 105 | envelope 106 | eraser 107 | eye 108 | eyeglasses 109 | face 110 | fan 111 | feather 112 | fence 113 | finger 114 | fire_hydrant 115 | fireplace 116 | firetruck 117 | fish 118 | flamingo 119 | flashlight 120 | flip_flops 121 | floor_lamp 122 | flower 123 | flying_saucer 124 | foot 125 | fork 126 | frog 127 | frying_pan 128 | garden 129 | garden_hose 130 | giraffe 131 | goatee 132 | golf_club 133 | grapes 134 | grass 135 | guitar 136 | hamburger 137 | hammer 138 | hand 139 | harp 140 | hat 141 | headphones 142 | hedgehog 143 | helicopter 144 | helmet 145 | hexagon 146 | hockey_puck 147 | hockey_stick 148 | horse 149 | hospital 150 | hot_air_balloon 151 | hot_dog 152 | hot_tub 153 | hourglass 154 | house 155 | house_plant 156 | hurricane 157 | ice_cream 158 | jacket 159 | jail 160 | kangaroo 161 | key 162 | keyboard 163 | knee 164 | knife 165 | ladder 166 | lantern 167 | laptop 168 | leaf 169 | leg 170 | light_bulb 171 | lighter 172 | lighthouse 173 | lightning 174 | line 175 | lion 176 | lipstick 177 | lobster 178 | lollipop 179 | mailbox 180 | map 181 | marker 182 | 
matches 183 | megaphone 184 | mermaid 185 | microphone 186 | microwave 187 | monkey 188 | moon 189 | mosquito 190 | motorbike 191 | mountain 192 | mouse 193 | moustache 194 | mouth 195 | mug 196 | mushroom 197 | nail 198 | necklace 199 | nose 200 | ocean 201 | octagon 202 | octopus 203 | onion 204 | oven 205 | owl 206 | paintbrush 207 | paint_can 208 | palm_tree 209 | panda 210 | pants 211 | paper_clip 212 | parachute 213 | parrot 214 | passport 215 | peanut 216 | pear 217 | peas 218 | pencil 219 | penguin 220 | piano 221 | pickup_truck 222 | picture_frame 223 | pig 224 | pillow 225 | pineapple 226 | pizza 227 | pliers 228 | police_car 229 | pond 230 | pool 231 | popsicle 232 | postcard 233 | potato 234 | power_outlet 235 | purse 236 | rabbit 237 | raccoon 238 | radio 239 | rain 240 | rainbow 241 | rake 242 | remote_control 243 | rhinoceros 244 | rifle 245 | river 246 | roller_coaster 247 | rollerskates 248 | sailboat 249 | sandwich 250 | saw 251 | saxophone 252 | school_bus 253 | scissors 254 | scorpion 255 | screwdriver 256 | sea_turtle 257 | see_saw 258 | shark 259 | sheep 260 | shoe 261 | shorts 262 | shovel 263 | sink 264 | skateboard 265 | skull 266 | skyscraper 267 | sleeping_bag 268 | smiley_face 269 | snail 270 | snake 271 | snorkel 272 | snowflake 273 | snowman 274 | soccer_ball 275 | sock 276 | speedboat 277 | spider 278 | spoon 279 | spreadsheet 280 | square 281 | squiggle 282 | squirrel 283 | stairs 284 | star 285 | steak 286 | stereo 287 | stethoscope 288 | stitches 289 | stop_sign 290 | stove 291 | strawberry 292 | streetlight 293 | string_bean 294 | submarine 295 | suitcase 296 | sun 297 | swan 298 | sweater 299 | swing_set 300 | sword 301 | syringe 302 | table 303 | teapot 304 | teddy-bear 305 | telephone 306 | television 307 | tennis_racquet 308 | tent 309 | The_Eiffel_Tower 310 | The_Great_Wall_of_China 311 | The_Mona_Lisa 312 | tiger 313 | toaster 314 | toe 315 | toilet 316 | tooth 317 | toothbrush 318 | toothpaste 319 | tornado 320 | tractor 
321 | traffic_light 322 | train 323 | tree 324 | triangle 325 | trombone 326 | truck 327 | trumpet 328 | t-shirt 329 | umbrella 330 | underwear 331 | van 332 | vase 333 | violin 334 | washing_machine 335 | watermelon 336 | waterslide 337 | whale 338 | wheel 339 | windmill 340 | wine_bottle 341 | wine_glass 342 | wristwatch 343 | yoga 344 | zebra 345 | zigzag -------------------------------------------------------------------------------- /dataloader/text_lists/OfficeHome/classes.txt: -------------------------------------------------------------------------------- 1 | Alarm_Clock 2 | Backpack 3 | Batteries 4 | Bed 5 | Bike 6 | Bottle 7 | Bucket 8 | Calculator 9 | Calendar 10 | Candles 11 | Chair 12 | Clipboards 13 | Computer 14 | Couch 15 | Curtains 16 | Desk_Lamp 17 | Drill 18 | Eraser 19 | Exit_Sign 20 | Fan 21 | File_Cabinet 22 | Flipflops 23 | Flowers 24 | Folder 25 | Fork 26 | Glasses 27 | Hammer 28 | Helmet 29 | Kettle 30 | Keyboard 31 | Knives 32 | Lamp_Shade 33 | Laptop 34 | Marker 35 | Monitor 36 | Mop 37 | Mouse 38 | Mug 39 | Notebook 40 | Oven 41 | Pan 42 | Paper_Clip 43 | Pen 44 | Pencil 45 | Postit_Notes 46 | Printer 47 | Push_Pin 48 | Radio 49 | Refrigerator 50 | Ruler 51 | Scissors 52 | Screwdriver 53 | Shelf 54 | Sink 55 | Sneakers 56 | Soda 57 | Speaker 58 | Spoon 59 | TV 60 | Table 61 | Telephone 62 | ToothBrush 63 | Toys 64 | Trash_Can 65 | Webcam -------------------------------------------------------------------------------- /dataloader/text_lists/PACS/art_painting_val.txt: -------------------------------------------------------------------------------- 1 | art_painting/dog/pic_225.jpg 1 2 | art_painting/dog/pic_249.jpg 1 3 | art_painting/dog/pic_306.jpg 1 4 | art_painting/dog/pic_241.jpg 1 5 | art_painting/dog/pic_219.jpg 1 6 | art_painting/dog/pic_252.jpg 1 7 | art_painting/dog/pic_309.jpg 1 8 | art_painting/dog/pic_255.jpg 1 9 | art_painting/dog/pic_310.jpg 1 10 | art_painting/dog/pic_247.jpg 1 11 | art_painting/dog/pic_236.jpg 1 12 | 
art_painting/dog/pic_242.jpg 1 13 | art_painting/dog/pic_257.jpg 1 14 | art_painting/dog/pic_314.jpg 1 15 | art_painting/dog/pic_317.jpg 1 16 | art_painting/dog/pic_315.jpg 1 17 | art_painting/dog/pic_248.jpg 1 18 | art_painting/dog/pic_250.jpg 1 19 | art_painting/dog/pic_282.jpg 1 20 | art_painting/dog/pic_260.jpg 1 21 | art_painting/dog/pic_316.jpg 1 22 | art_painting/dog/pic_305.jpg 1 23 | art_painting/dog/pic_300.jpg 1 24 | art_painting/dog/pic_365.jpg 1 25 | art_painting/dog/pic_296.jpg 1 26 | art_painting/dog/pic_301.jpg 1 27 | art_painting/dog/pic_298.jpg 1 28 | art_painting/dog/pic_291.jpg 1 29 | art_painting/dog/pic_313.jpg 1 30 | art_painting/dog/pic_311.jpg 1 31 | art_painting/dog/pic_312.jpg 1 32 | art_painting/dog/pic_308.jpg 1 33 | art_painting/dog/pic_329.jpg 1 34 | art_painting/dog/pic_322.jpg 1 35 | art_painting/dog/pic_323.jpg 1 36 | art_painting/dog/pic_330.jpg 1 37 | art_painting/dog/pic_371.jpg 1 38 | art_painting/dog/pic_339.jpg 1 39 | art_painting/elephant/pic_243.jpg 2 40 | art_painting/elephant/pic_154.jpg 2 41 | art_painting/elephant/pic_239.jpg 2 42 | art_painting/elephant/pic_156.jpg 2 43 | art_painting/elephant/pic_167.jpg 2 44 | art_painting/elephant/pic_168.jpg 2 45 | art_painting/elephant/pic_162.jpg 2 46 | art_painting/elephant/pic_161.jpg 2 47 | art_painting/elephant/pic_159.jpg 2 48 | art_painting/elephant/pic_160.jpg 2 49 | art_painting/elephant/pic_158.jpg 2 50 | art_painting/elephant/pic_157.jpg 2 51 | art_painting/elephant/pic_166.jpg 2 52 | art_painting/elephant/pic_171.jpg 2 53 | art_painting/elephant/pic_169.jpg 2 54 | art_painting/elephant/pic_170.jpg 2 55 | art_painting/elephant/pic_176.jpg 2 56 | art_painting/elephant/pic_175.jpg 2 57 | art_painting/elephant/pic_173.jpg 2 58 | art_painting/elephant/pic_172.jpg 2 59 | art_painting/elephant/pic_082.jpg 2 60 | art_painting/elephant/pic_081.jpg 2 61 | art_painting/elephant/pic_080.jpg 2 62 | art_painting/elephant/pic_078.jpg 2 63 | art_painting/elephant/pic_079.jpg 2 64 | 
art_painting/elephant/pic_093.jpg 2 65 | art_painting/giraffe/pic_134.jpg 3 66 | art_painting/giraffe/pic_129.jpg 3 67 | art_painting/giraffe/pic_127.jpg 3 68 | art_painting/giraffe/pic_151.jpg 3 69 | art_painting/giraffe/pic_131.jpg 3 70 | art_painting/giraffe/pic_158.jpg 3 71 | art_painting/giraffe/pic_144.jpg 3 72 | art_painting/giraffe/pic_238.jpg 3 73 | art_painting/giraffe/pic_222.jpg 3 74 | art_painting/giraffe/pic_185.jpg 3 75 | art_painting/giraffe/pic_160.jpg 3 76 | art_painting/giraffe/pic_155.jpg 3 77 | art_painting/giraffe/pic_209.jpg 3 78 | art_painting/giraffe/pic_228.jpg 3 79 | art_painting/giraffe/pic_169.jpg 3 80 | art_painting/giraffe/pic_198.jpg 3 81 | art_painting/giraffe/pic_145.jpg 3 82 | art_painting/giraffe/pic_273.jpg 3 83 | art_painting/giraffe/pic_303.jpg 3 84 | art_painting/giraffe/pic_284.jpg 3 85 | art_painting/giraffe/pic_302.jpg 3 86 | art_painting/giraffe/pic_286.jpg 3 87 | art_painting/giraffe/pic_287.jpg 3 88 | art_painting/giraffe/pic_301.jpg 3 89 | art_painting/giraffe/pic_295.jpg 3 90 | art_painting/giraffe/pic_296.jpg 3 91 | art_painting/giraffe/pic_311.jpg 3 92 | art_painting/giraffe/pic_309.jpg 3 93 | art_painting/giraffe/pic_310.jpg 3 94 | art_painting/guitar/pic_125.jpg 4 95 | art_painting/guitar/pic_124.jpg 4 96 | art_painting/guitar/pic_179.jpg 4 97 | art_painting/guitar/pic_147.jpg 4 98 | art_painting/guitar/pic_146.jpg 4 99 | art_painting/guitar/pic_183.jpg 4 100 | art_painting/guitar/pic_126.jpg 4 101 | art_painting/guitar/pic_172.jpg 4 102 | art_painting/guitar/pic_137.jpg 4 103 | art_painting/guitar/pic_180.jpg 4 104 | art_painting/guitar/pic_150.jpg 4 105 | art_painting/guitar/pic_176.jpg 4 106 | art_painting/guitar/pic_187.jpg 4 107 | art_painting/guitar/pic_186.jpg 4 108 | art_painting/guitar/pic_184.jpg 4 109 | art_painting/guitar/pic_174.jpg 4 110 | art_painting/guitar/pic_165.jpg 4 111 | art_painting/guitar/pic_161.jpg 4 112 | art_painting/guitar/pic_162.jpg 4 113 | art_painting/horse/pic_034.jpg 5 114 | 
art_painting/horse/pic_040.jpg 5 115 | art_painting/horse/pic_039.jpg 5 116 | art_painting/horse/pic_042.jpg 5 117 | art_painting/horse/pic_028.jpg 5 118 | art_painting/horse/pic_037.jpg 5 119 | art_painting/horse/pic_041.jpg 5 120 | art_painting/horse/pic_033.jpg 5 121 | art_painting/horse/pic_038.jpg 5 122 | art_painting/horse/pic_025.jpg 5 123 | art_painting/horse/pic_023.jpg 5 124 | art_painting/horse/pic_045.jpg 5 125 | art_painting/horse/pic_030.jpg 5 126 | art_painting/horse/pic_043.jpg 5 127 | art_painting/horse/pic_021.jpg 5 128 | art_painting/horse/pic_026.jpg 5 129 | art_painting/horse/pic_046.jpg 5 130 | art_painting/horse/pic_001.jpg 5 131 | art_painting/horse/pic_002.jpg 5 132 | art_painting/horse/pic_003.jpg 5 133 | art_painting/horse/pic_004.jpg 5 134 | art_painting/house/pic_313.jpg 6 135 | art_painting/house/pic_169.jpg 6 136 | art_painting/house/pic_168.jpg 6 137 | art_painting/house/pic_308.jpg 6 138 | art_painting/house/pic_167.jpg 6 139 | art_painting/house/pic_310.jpg 6 140 | art_painting/house/pic_314.jpg 6 141 | art_painting/house/pic_170.jpg 6 142 | art_painting/house/pic_316.jpg 6 143 | art_painting/house/pic_175.jpg 6 144 | art_painting/house/pic_173.jpg 6 145 | art_painting/house/pic_322.jpg 6 146 | art_painting/house/pic_321.jpg 6 147 | art_painting/house/pic_320.jpg 6 148 | art_painting/house/pic_178.jpg 6 149 | art_painting/house/pic_331.jpg 6 150 | art_painting/house/pic_001.jpg 6 151 | art_painting/house/pic_002.jpg 6 152 | art_painting/house/pic_003.jpg 6 153 | art_painting/house/pic_004.jpg 6 154 | art_painting/house/pic_005.jpg 6 155 | art_painting/house/pic_006.jpg 6 156 | art_painting/house/pic_007.jpg 6 157 | art_painting/house/pic_008.jpg 6 158 | art_painting/house/pic_009.jpg 6 159 | art_painting/house/pic_010.jpg 6 160 | art_painting/house/pic_011.jpg 6 161 | art_painting/house/pic_013.jpg 6 162 | art_painting/house/pic_015.jpg 6 163 | art_painting/house/pic_030.jpg 6 164 | art_painting/person/pic_280.jpg 7 165 | 
art_painting/person/pic_278.jpg 7 166 | art_painting/person/pic_277.jpg 7 167 | art_painting/person/pic_276.jpg 7 168 | art_painting/person/pic_275.jpg 7 169 | art_painting/person/pic_273.jpg 7 170 | art_painting/person/pic_284.jpg 7 171 | art_painting/person/pic_283.jpg 7 172 | art_painting/person/pic_281.jpg 7 173 | art_painting/person/pic_282.jpg 7 174 | art_painting/person/pic_285.jpg 7 175 | art_painting/person/pic_269.jpg 7 176 | art_painting/person/pic_297.jpg 7 177 | art_painting/person/pic_298.jpg 7 178 | art_painting/person/pic_296.jpg 7 179 | art_painting/person/pic_295.jpg 7 180 | art_painting/person/pic_134.jpg 7 181 | art_painting/person/pic_133.jpg 7 182 | art_painting/person/pic_135.jpg 7 183 | art_painting/person/pic_310.jpg 7 184 | art_painting/person/pic_141.jpg 7 185 | art_painting/person/pic_001.jpg 7 186 | art_painting/person/pic_002.jpg 7 187 | art_painting/person/pic_003.jpg 7 188 | art_painting/person/pic_004.jpg 7 189 | art_painting/person/pic_005.jpg 7 190 | art_painting/person/pic_048.jpg 7 191 | art_painting/person/pic_050.jpg 7 192 | art_painting/person/pic_052.jpg 7 193 | art_painting/person/pic_055.jpg 7 194 | art_painting/person/pic_056.jpg 7 195 | art_painting/person/pic_065.jpg 7 196 | art_painting/person/pic_331.jpg 7 197 | art_painting/person/pic_330.jpg 7 198 | art_painting/person/pic_176.jpg 7 199 | art_painting/person/pic_416.jpg 7 200 | art_painting/person/pic_420.jpg 7 201 | art_painting/person/pic_426.jpg 7 202 | art_painting/person/pic_424.jpg 7 203 | art_painting/person/pic_423.jpg 7 204 | art_painting/person/pic_421.jpg 7 205 | art_painting/person/pic_183.jpg 7 206 | art_painting/person/pic_428.jpg 7 207 | art_painting/person/pic_430.jpg 7 208 | art_painting/person/pic_429.jpg 7 209 | -------------------------------------------------------------------------------- /dataloader/text_lists/PACS/cartoon_val.txt: -------------------------------------------------------------------------------- 1 | cartoon/dog/pic_383.jpg 1 2 
| cartoon/dog/pic_382.jpg 1 3 | cartoon/dog/pic_386.jpg 1 4 | cartoon/dog/pic_384.jpg 1 5 | cartoon/dog/pic_385.jpg 1 6 | cartoon/dog/pic_391.jpg 1 7 | cartoon/dog/pic_390.jpg 1 8 | cartoon/dog/pic_392.jpg 1 9 | cartoon/dog/pic_393.jpg 1 10 | cartoon/dog/pic_405.jpg 1 11 | cartoon/dog/pic_403.jpg 1 12 | cartoon/dog/pic_417.jpg 1 13 | cartoon/dog/pic_416.jpg 1 14 | cartoon/dog/pic_415.jpg 1 15 | cartoon/dog/pic_150.jpg 1 16 | cartoon/dog/pic_233.jpg 1 17 | cartoon/dog/pic_232.jpg 1 18 | cartoon/dog/pic_227.jpg 1 19 | cartoon/dog/pic_228.jpg 1 20 | cartoon/dog/pic_229.jpg 1 21 | cartoon/dog/pic_226.jpg 1 22 | cartoon/dog/pic_230.jpg 1 23 | cartoon/dog/pic_286.jpg 1 24 | cartoon/dog/pic_285.jpg 1 25 | cartoon/dog/pic_276.jpg 1 26 | cartoon/dog/pic_262.jpg 1 27 | cartoon/dog/pic_259.jpg 1 28 | cartoon/dog/pic_257.jpg 1 29 | cartoon/dog/pic_254.jpg 1 30 | cartoon/dog/pic_252.jpg 1 31 | cartoon/dog/pic_249.jpg 1 32 | cartoon/dog/pic_001.jpg 1 33 | cartoon/dog/pic_003.jpg 1 34 | cartoon/dog/pic_004.jpg 1 35 | cartoon/dog/pic_005.jpg 1 36 | cartoon/dog/pic_006.jpg 1 37 | cartoon/dog/pic_031.jpg 1 38 | cartoon/dog/pic_043.jpg 1 39 | cartoon/dog/pic_025.jpg 1 40 | cartoon/elephant/pic_211.jpg 2 41 | cartoon/elephant/pic_154.jpg 2 42 | cartoon/elephant/pic_153.jpg 2 43 | cartoon/elephant/pic_237.jpg 2 44 | cartoon/elephant/pic_227.jpg 2 45 | cartoon/elephant/pic_226.jpg 2 46 | cartoon/elephant/pic_225.jpg 2 47 | cartoon/elephant/pic_155.jpg 2 48 | cartoon/elephant/pic_165.jpg 2 49 | cartoon/elephant/pic_164.jpg 2 50 | cartoon/elephant/pic_162.jpg 2 51 | cartoon/elephant/pic_157.jpg 2 52 | cartoon/elephant/pic_156.jpg 2 53 | cartoon/elephant/pic_166.jpg 2 54 | cartoon/elephant/pic_168.jpg 2 55 | cartoon/elephant/pic_167.jpg 2 56 | cartoon/elephant/pic_169.jpg 2 57 | cartoon/elephant/pic_171.jpg 2 58 | cartoon/elephant/pic_170.jpg 2 59 | cartoon/elephant/pic_240.jpg 2 60 | cartoon/elephant/pic_243.jpg 2 61 | cartoon/elephant/pic_242.jpg 2 62 | cartoon/elephant/pic_244.jpg 2 63 
| cartoon/elephant/pic_172.jpg 2 64 | cartoon/elephant/pic_247.jpg 2 65 | cartoon/elephant/pic_248.jpg 2 66 | cartoon/elephant/pic_251.jpg 2 67 | cartoon/elephant/pic_250.jpg 2 68 | cartoon/elephant/pic_249.jpg 2 69 | cartoon/elephant/pic_252.jpg 2 70 | cartoon/elephant/pic_258.jpg 2 71 | cartoon/elephant/pic_257.jpg 2 72 | cartoon/elephant/pic_173.jpg 2 73 | cartoon/elephant/pic_412.jpg 2 74 | cartoon/elephant/pic_411.jpg 2 75 | cartoon/elephant/pic_408.jpg 2 76 | cartoon/elephant/pic_409.jpg 2 77 | cartoon/elephant/pic_413.jpg 2 78 | cartoon/elephant/pic_410.jpg 2 79 | cartoon/elephant/pic_431.jpg 2 80 | cartoon/elephant/pic_430.jpg 2 81 | cartoon/elephant/pic_428.jpg 2 82 | cartoon/elephant/pic_425.jpg 2 83 | cartoon/elephant/pic_423.jpg 2 84 | cartoon/elephant/pic_424.jpg 2 85 | cartoon/elephant/pic_420.jpg 2 86 | cartoon/giraffe/pic_005.jpg 3 87 | cartoon/giraffe/pic_006.jpg 3 88 | cartoon/giraffe/pic_007.jpg 3 89 | cartoon/giraffe/pic_008.jpg 3 90 | cartoon/giraffe/pic_009.jpg 3 91 | cartoon/giraffe/pic_010.jpg 3 92 | cartoon/giraffe/pic_011.jpg 3 93 | cartoon/giraffe/pic_012.jpg 3 94 | cartoon/giraffe/pic_013.jpg 3 95 | cartoon/giraffe/pic_014.jpg 3 96 | cartoon/giraffe/pic_015.jpg 3 97 | cartoon/giraffe/pic_016.jpg 3 98 | cartoon/giraffe/pic_017.jpg 3 99 | cartoon/giraffe/pic_018.jpg 3 100 | cartoon/giraffe/pic_019.jpg 3 101 | cartoon/giraffe/pic_020.jpg 3 102 | cartoon/giraffe/pic_022.jpg 3 103 | cartoon/giraffe/pic_025.jpg 3 104 | cartoon/giraffe/pic_024.jpg 3 105 | cartoon/giraffe/pic_091.jpg 3 106 | cartoon/giraffe/pic_090.jpg 3 107 | cartoon/giraffe/pic_087.jpg 3 108 | cartoon/giraffe/pic_086.jpg 3 109 | cartoon/giraffe/pic_085.jpg 3 110 | cartoon/giraffe/pic_095.jpg 3 111 | cartoon/giraffe/pic_096.jpg 3 112 | cartoon/giraffe/pic_093.jpg 3 113 | cartoon/giraffe/pic_094.jpg 3 114 | cartoon/giraffe/pic_106.jpg 3 115 | cartoon/giraffe/pic_108.jpg 3 116 | cartoon/giraffe/pic_104.jpg 3 117 | cartoon/giraffe/pic_103.jpg 3 118 | cartoon/giraffe/pic_101.jpg 3 
119 | cartoon/giraffe/pic_100.jpg 3 120 | cartoon/giraffe/pic_099.jpg 3 121 | cartoon/guitar/pic_072.jpg 4 122 | cartoon/guitar/pic_003.jpg 4 123 | cartoon/guitar/pic_004.jpg 4 124 | cartoon/guitar/pic_005.jpg 4 125 | cartoon/guitar/pic_006.jpg 4 126 | cartoon/guitar/pic_007.jpg 4 127 | cartoon/guitar/pic_009.jpg 4 128 | cartoon/guitar/pic_010.jpg 4 129 | cartoon/guitar/pic_011.jpg 4 130 | cartoon/guitar/pic_012.jpg 4 131 | cartoon/guitar/pic_013.jpg 4 132 | cartoon/guitar/pic_016.jpg 4 133 | cartoon/guitar/pic_017.jpg 4 134 | cartoon/guitar/pic_020.jpg 4 135 | cartoon/horse/pic_329.jpg 5 136 | cartoon/horse/pic_317.jpg 5 137 | cartoon/horse/pic_331.jpg 5 138 | cartoon/horse/pic_333.jpg 5 139 | cartoon/horse/pic_332.jpg 5 140 | cartoon/horse/pic_334.jpg 5 141 | cartoon/horse/pic_324.jpg 5 142 | cartoon/horse/pic_318.jpg 5 143 | cartoon/horse/pic_338.jpg 5 144 | cartoon/horse/pic_337.jpg 5 145 | cartoon/horse/pic_341.jpg 5 146 | cartoon/horse/pic_340.jpg 5 147 | cartoon/horse/pic_335.jpg 5 148 | cartoon/horse/pic_342.jpg 5 149 | cartoon/horse/pic_347.jpg 5 150 | cartoon/horse/pic_346.jpg 5 151 | cartoon/horse/pic_343.jpg 5 152 | cartoon/horse/pic_336.jpg 5 153 | cartoon/horse/pic_348.jpg 5 154 | cartoon/horse/pic_339.jpg 5 155 | cartoon/horse/pic_349.jpg 5 156 | cartoon/horse/pic_139.jpg 5 157 | cartoon/horse/pic_132.jpg 5 158 | cartoon/horse/pic_141.jpg 5 159 | cartoon/horse/pic_133.jpg 5 160 | cartoon/horse/pic_162.jpg 5 161 | cartoon/horse/pic_155.jpg 5 162 | cartoon/horse/pic_159.jpg 5 163 | cartoon/horse/pic_156.jpg 5 164 | cartoon/horse/pic_151.jpg 5 165 | cartoon/horse/pic_149.jpg 5 166 | cartoon/horse/pic_147.jpg 5 167 | cartoon/horse/pic_161.jpg 5 168 | cartoon/house/pic_103.jpg 6 169 | cartoon/house/pic_091.jpg 6 170 | cartoon/house/pic_089.jpg 6 171 | cartoon/house/pic_092.jpg 6 172 | cartoon/house/pic_093.jpg 6 173 | cartoon/house/pic_107.jpg 6 174 | cartoon/house/pic_104.jpg 6 175 | cartoon/house/pic_114.jpg 6 176 | cartoon/house/pic_112.jpg 6 177 | 
cartoon/house/pic_109.jpg 6 178 | cartoon/house/pic_108.jpg 6 179 | cartoon/house/pic_102.jpg 6 180 | cartoon/house/pic_099.jpg 6 181 | cartoon/house/pic_098.jpg 6 182 | cartoon/house/pic_097.jpg 6 183 | cartoon/house/pic_111.jpg 6 184 | cartoon/house/pic_320.jpg 6 185 | cartoon/house/pic_321.jpg 6 186 | cartoon/house/pic_315.jpg 6 187 | cartoon/house/pic_322.jpg 6 188 | cartoon/house/pic_323.jpg 6 189 | cartoon/house/pic_311.jpg 6 190 | cartoon/house/pic_324.jpg 6 191 | cartoon/house/pic_327.jpg 6 192 | cartoon/house/pic_312.jpg 6 193 | cartoon/house/pic_314.jpg 6 194 | cartoon/house/pic_328.jpg 6 195 | cartoon/house/pic_069.jpg 6 196 | cartoon/house/pic_079.jpg 6 197 | cartoon/person/pic_308.jpg 7 198 | cartoon/person/pic_307.jpg 7 199 | cartoon/person/pic_306.jpg 7 200 | cartoon/person/pic_313.jpg 7 201 | cartoon/person/pic_323.jpg 7 202 | cartoon/person/pic_319.jpg 7 203 | cartoon/person/pic_320.jpg 7 204 | cartoon/person/pic_321.jpg 7 205 | cartoon/person/pic_318.jpg 7 206 | cartoon/person/pic_317.jpg 7 207 | cartoon/person/pic_316.jpg 7 208 | cartoon/person/pic_324.jpg 7 209 | cartoon/person/pic_334.jpg 7 210 | cartoon/person/pic_331.jpg 7 211 | cartoon/person/pic_332.jpg 7 212 | cartoon/person/pic_333.jpg 7 213 | cartoon/person/pic_144.jpg 7 214 | cartoon/person/pic_145.jpg 7 215 | cartoon/person/pic_143.jpg 7 216 | cartoon/person/pic_138.jpg 7 217 | cartoon/person/pic_154.jpg 7 218 | cartoon/person/pic_151.jpg 7 219 | cartoon/person/pic_152.jpg 7 220 | cartoon/person/pic_148.jpg 7 221 | cartoon/person/pic_149.jpg 7 222 | cartoon/person/pic_155.jpg 7 223 | cartoon/person/pic_157.jpg 7 224 | cartoon/person/pic_159.jpg 7 225 | cartoon/person/pic_176.jpg 7 226 | cartoon/person/pic_171.jpg 7 227 | cartoon/person/pic_168.jpg 7 228 | cartoon/person/pic_169.jpg 7 229 | cartoon/person/pic_167.jpg 7 230 | cartoon/person/pic_056.jpg 7 231 | cartoon/person/pic_071.jpg 7 232 | cartoon/person/pic_070.jpg 7 233 | cartoon/person/pic_069.jpg 7 234 | 
cartoon/person/pic_073.jpg 7 235 | cartoon/person/pic_075.jpg 7 236 | cartoon/person/pic_076.jpg 7 237 | cartoon/person/pic_068.jpg 7 238 | -------------------------------------------------------------------------------- /dataloader/text_lists/PACS/classes.txt: -------------------------------------------------------------------------------- 1 | dog 2 | elephant 3 | giraffe 4 | guitar 5 | horse 6 | house 7 | person -------------------------------------------------------------------------------- /dataloader/text_lists/PACS/photo_val.txt: -------------------------------------------------------------------------------- 1 | photo/dog/056_0001.jpg 1 2 | photo/dog/056_0002.jpg 1 3 | photo/dog/056_0003.jpg 1 4 | photo/dog/056_0004.jpg 1 5 | photo/dog/056_0005.jpg 1 6 | photo/dog/056_0006.jpg 1 7 | photo/dog/056_0007.jpg 1 8 | photo/dog/056_0009.jpg 1 9 | photo/dog/056_0010.jpg 1 10 | photo/dog/056_0011.jpg 1 11 | photo/dog/056_0012.jpg 1 12 | photo/dog/056_0013.jpg 1 13 | photo/dog/056_0014.jpg 1 14 | photo/dog/056_0015.jpg 1 15 | photo/dog/056_0016.jpg 1 16 | photo/dog/056_0017.jpg 1 17 | photo/dog/056_0018.jpg 1 18 | photo/dog/056_0020.jpg 1 19 | photo/dog/056_0021.jpg 1 20 | photo/elephant/064_0001.jpg 2 21 | photo/elephant/064_0002.jpg 2 22 | photo/elephant/064_0003.jpg 2 23 | photo/elephant/064_0004.jpg 2 24 | photo/elephant/064_0005.jpg 2 25 | photo/elephant/064_0006.jpg 2 26 | photo/elephant/064_0007.jpg 2 27 | photo/elephant/064_0008.jpg 2 28 | photo/elephant/064_0009.jpg 2 29 | photo/elephant/064_0010.jpg 2 30 | photo/elephant/064_0011.jpg 2 31 | photo/elephant/064_0012.jpg 2 32 | photo/elephant/064_0013.jpg 2 33 | photo/elephant/064_0014.jpg 2 34 | photo/elephant/064_0015.jpg 2 35 | photo/elephant/064_0016.jpg 2 36 | photo/elephant/064_0017.jpg 2 37 | photo/elephant/064_0018.jpg 2 38 | photo/elephant/064_0019.jpg 2 39 | photo/elephant/064_0020.jpg 2 40 | photo/elephant/064_0021.jpg 2 41 | photo/giraffe/084_0001.jpg 3 42 | photo/giraffe/084_0002.jpg 3 43 | 
photo/giraffe/084_0003.jpg 3 44 | photo/giraffe/084_0004.jpg 3 45 | photo/giraffe/084_0005.jpg 3 46 | photo/giraffe/084_0006.jpg 3 47 | photo/giraffe/084_0007.jpg 3 48 | photo/giraffe/084_0008.jpg 3 49 | photo/giraffe/084_0009.jpg 3 50 | photo/giraffe/084_0010.jpg 3 51 | photo/giraffe/084_0011.jpg 3 52 | photo/giraffe/084_0012.jpg 3 53 | photo/giraffe/084_0013.jpg 3 54 | photo/giraffe/084_0014.jpg 3 55 | photo/giraffe/084_0015.jpg 3 56 | photo/giraffe/084_0016.jpg 3 57 | photo/giraffe/084_0017.jpg 3 58 | photo/giraffe/084_0018.jpg 3 59 | photo/giraffe/084_0019.jpg 3 60 | photo/guitar/063_0001.jpg 4 61 | photo/guitar/063_0002.jpg 4 62 | photo/guitar/063_0003.jpg 4 63 | photo/guitar/063_0004.jpg 4 64 | photo/guitar/063_0005.jpg 4 65 | photo/guitar/063_0006.jpg 4 66 | photo/guitar/063_0007.jpg 4 67 | photo/guitar/063_0008.jpg 4 68 | photo/guitar/063_0009.jpg 4 69 | photo/guitar/063_0010.jpg 4 70 | photo/guitar/063_0012.jpg 4 71 | photo/guitar/063_0013.jpg 4 72 | photo/guitar/063_0016.jpg 4 73 | photo/guitar/063_0018.jpg 4 74 | photo/guitar/063_0019.jpg 4 75 | photo/guitar/063_0020.jpg 4 76 | photo/guitar/063_0021.jpg 4 77 | photo/guitar/063_0022.jpg 4 78 | photo/guitar/063_0023.jpg 4 79 | photo/horse/105_0002.jpg 5 80 | photo/horse/105_0003.jpg 5 81 | photo/horse/105_0007.jpg 5 82 | photo/horse/105_0008.jpg 5 83 | photo/horse/105_0009.jpg 5 84 | photo/horse/105_0010.jpg 5 85 | photo/horse/105_0012.jpg 5 86 | photo/horse/105_0013.jpg 5 87 | photo/horse/105_0022.jpg 5 88 | photo/horse/105_0025.jpg 5 89 | photo/horse/105_0028.jpg 5 90 | photo/horse/105_0029.jpg 5 91 | photo/horse/105_0030.jpg 5 92 | photo/horse/105_0033.jpg 5 93 | photo/horse/105_0037.jpg 5 94 | photo/horse/105_0038.jpg 5 95 | photo/horse/105_0041.jpg 5 96 | photo/horse/105_0042.jpg 5 97 | photo/horse/105_0047.jpg 5 98 | photo/horse/105_0048.jpg 5 99 | photo/house/pic_010.jpg 6 100 | photo/house/pic_011.jpg 6 101 | photo/house/pic_012.jpg 6 102 | photo/house/pic_013.jpg 6 103 | photo/house/pic_014.jpg 6 
104 | photo/house/pic_015.jpg 6 105 | photo/house/pic_016.jpg 6 106 | photo/house/pic_017.jpg 6 107 | photo/house/pic_018.jpg 6 108 | photo/house/pic_021.jpg 6 109 | photo/house/pic_019.jpg 6 110 | photo/house/pic_022.jpg 6 111 | photo/house/pic_020.jpg 6 112 | photo/house/pic_023.jpg 6 113 | photo/house/pic_024.jpg 6 114 | photo/house/pic_026.jpg 6 115 | photo/house/pic_025.jpg 6 116 | photo/house/pic_027.jpg 6 117 | photo/house/pic_028.jpg 6 118 | photo/house/pic_029.jpg 6 119 | photo/house/pic_031.jpg 6 120 | photo/house/pic_239.jpg 6 121 | photo/house/pic_240.jpg 6 122 | photo/house/pic_241.jpg 6 123 | photo/house/pic_242.jpg 6 124 | photo/house/pic_248.jpg 6 125 | photo/house/pic_246.jpg 6 126 | photo/house/pic_247.jpg 6 127 | photo/house/pic_244.jpg 6 128 | photo/person/253_0001.jpg 7 129 | photo/person/253_0002.jpg 7 130 | photo/person/253_0003.jpg 7 131 | photo/person/253_0004.jpg 7 132 | photo/person/253_0005.jpg 7 133 | photo/person/253_0006.jpg 7 134 | photo/person/253_0007.jpg 7 135 | photo/person/253_0008.jpg 7 136 | photo/person/253_0009.jpg 7 137 | photo/person/253_0010.jpg 7 138 | photo/person/253_0011.jpg 7 139 | photo/person/253_0012.jpg 7 140 | photo/person/253_0013.jpg 7 141 | photo/person/253_0014.jpg 7 142 | photo/person/253_0015.jpg 7 143 | photo/person/253_0016.jpg 7 144 | photo/person/253_0017.jpg 7 145 | photo/person/253_0018.jpg 7 146 | photo/person/253_0019.jpg 7 147 | photo/person/253_0020.jpg 7 148 | photo/person/253_0021.jpg 7 149 | photo/person/253_0022.jpg 7 150 | photo/person/253_0023.jpg 7 151 | photo/person/253_0024.jpg 7 152 | photo/person/253_0025.jpg 7 153 | photo/person/253_0026.jpg 7 154 | photo/person/253_0027.jpg 7 155 | photo/person/253_0028.jpg 7 156 | photo/person/253_0029.jpg 7 157 | photo/person/253_0030.jpg 7 158 | photo/person/253_0031.jpg 7 159 | photo/person/253_0032.jpg 7 160 | photo/person/253_0033.jpg 7 161 | photo/person/253_0034.jpg 7 162 | photo/person/253_0035.jpg 7 163 | photo/person/253_0036.jpg 7 164 | 
photo/person/253_0037.jpg 7 165 | photo/person/253_0038.jpg 7 166 | photo/person/253_0039.jpg 7 167 | photo/person/253_0040.jpg 7 168 | photo/person/253_0041.jpg 7 169 | photo/person/253_0042.jpg 7 170 | photo/person/253_0043.jpg 7 171 | photo/person/253_0044.jpg 7 172 | -------------------------------------------------------------------------------- /dataloader/text_lists/PACS/sketch_val.txt: -------------------------------------------------------------------------------- 1 | sketch/dog/n02103406_343-1.png 1 2 | sketch/dog/n02103406_343-2.png 1 3 | sketch/dog/n02103406_343-3.png 1 4 | sketch/dog/n02103406_343-4.png 1 5 | sketch/dog/n02103406_343-5.png 1 6 | sketch/dog/n02103406_343-6.png 1 7 | sketch/dog/n02103406_343-7.png 1 8 | sketch/dog/n02103406_343-8.png 1 9 | sketch/dog/n02103406_343-9.png 1 10 | sketch/dog/n02103406_346-1.png 1 11 | sketch/dog/n02103406_346-2.png 1 12 | sketch/dog/n02103406_346-3.png 1 13 | sketch/dog/n02103406_346-4.png 1 14 | sketch/dog/n02103406_346-5.png 1 15 | sketch/dog/n02103406_346-6.png 1 16 | sketch/dog/n02103406_346-7.png 1 17 | sketch/dog/n02103406_371-1.png 1 18 | sketch/dog/n02103406_371-2.png 1 19 | sketch/dog/n02103406_371-3.png 1 20 | sketch/dog/n02103406_371-4.png 1 21 | sketch/dog/n02103406_371-5.png 1 22 | sketch/dog/n02103406_371-6.png 1 23 | sketch/dog/n02103406_371-7.png 1 24 | sketch/dog/n02103406_371-8.png 1 25 | sketch/dog/n02103406_371-9.png 1 26 | sketch/dog/n02103406_371-10.png 1 27 | sketch/dog/n02103406_371-11.png 1 28 | sketch/dog/n02103406_651-1.png 1 29 | sketch/dog/n02103406_651-2.png 1 30 | sketch/dog/n02103406_651-3.png 1 31 | sketch/dog/n02103406_651-4.png 1 32 | sketch/dog/n02103406_651-5.png 1 33 | sketch/dog/n02103406_651-6.png 1 34 | sketch/dog/n02103406_651-7.png 1 35 | sketch/dog/n02103406_865-1.png 1 36 | sketch/dog/n02103406_865-2.png 1 37 | sketch/dog/n02103406_865-3.png 1 38 | sketch/dog/n02103406_865-4.png 1 39 | sketch/dog/n02103406_865-5.png 1 40 | sketch/dog/n02103406_865-6.png 1 41 | 
sketch/dog/n02103406_865-7.png 1 42 | sketch/dog/n02103406_865-8.png 1 43 | sketch/dog/n02103406_865-9.png 1 44 | sketch/dog/n02103406_865-10.png 1 45 | sketch/dog/n02103406_865-11.png 1 46 | sketch/dog/n02103406_936-1.png 1 47 | sketch/dog/n02103406_936-2.png 1 48 | sketch/dog/n02103406_936-3.png 1 49 | sketch/dog/n02103406_936-4.png 1 50 | sketch/dog/n02103406_936-5.png 1 51 | sketch/dog/n02103406_936-6.png 1 52 | sketch/dog/n02103406_936-7.png 1 53 | sketch/dog/n02103406_936-8.png 1 54 | sketch/dog/n02103406_936-9.png 1 55 | sketch/dog/n02103406_995-1.png 1 56 | sketch/dog/n02103406_995-2.png 1 57 | sketch/dog/n02103406_995-3.png 1 58 | sketch/dog/n02103406_995-4.png 1 59 | sketch/dog/n02103406_995-5.png 1 60 | sketch/dog/n02103406_995-6.png 1 61 | sketch/dog/n02103406_1011-1.png 1 62 | sketch/dog/n02103406_1011-2.png 1 63 | sketch/dog/n02103406_1011-3.png 1 64 | sketch/dog/n02103406_1011-4.png 1 65 | sketch/dog/n02103406_1011-5.png 1 66 | sketch/dog/n02103406_1138-1.png 1 67 | sketch/dog/n02103406_1138-2.png 1 68 | sketch/dog/n02103406_1138-3.png 1 69 | sketch/dog/n02103406_1138-4.png 1 70 | sketch/dog/n02103406_1138-5.png 1 71 | sketch/dog/n02103406_1138-6.png 1 72 | sketch/dog/n02103406_1138-7.png 1 73 | sketch/dog/n02103406_1138-8.png 1 74 | sketch/dog/n02103406_1170-1.png 1 75 | sketch/dog/n02103406_1170-2.png 1 76 | sketch/dog/n02103406_1170-3.png 1 77 | sketch/dog/n02103406_1170-4.png 1 78 | sketch/dog/n02103406_1170-5.png 1 79 | sketch/elephant/n02503517_79-1.png 2 80 | sketch/elephant/n02503517_79-2.png 2 81 | sketch/elephant/n02503517_79-3.png 2 82 | sketch/elephant/n02503517_79-4.png 2 83 | sketch/elephant/n02503517_79-5.png 2 84 | sketch/elephant/n02503517_86-1.png 2 85 | sketch/elephant/n02503517_86-2.png 2 86 | sketch/elephant/n02503517_86-3.png 2 87 | sketch/elephant/n02503517_86-4.png 2 88 | sketch/elephant/n02503517_86-5.png 2 89 | sketch/elephant/n02503517_86-6.png 2 90 | sketch/elephant/n02503517_184-1.png 2 91 | 
sketch/elephant/n02503517_184-2.png 2 92 | sketch/elephant/n02503517_184-3.png 2 93 | sketch/elephant/n02503517_184-4.png 2 94 | sketch/elephant/n02503517_184-5.png 2 95 | sketch/elephant/n02503517_184-6.png 2 96 | sketch/elephant/n02503517_184-7.png 2 97 | sketch/elephant/n02503517_184-8.png 2 98 | sketch/elephant/n02503517_184-9.png 2 99 | sketch/elephant/n02503517_194-1.png 2 100 | sketch/elephant/n02503517_194-2.png 2 101 | sketch/elephant/n02503517_194-3.png 2 102 | sketch/elephant/n02503517_194-4.png 2 103 | sketch/elephant/n02503517_194-5.png 2 104 | sketch/elephant/n02503517_194-6.png 2 105 | sketch/elephant/n02503517_311-1.png 2 106 | sketch/elephant/n02503517_311-2.png 2 107 | sketch/elephant/n02503517_311-3.png 2 108 | sketch/elephant/n02503517_311-4.png 2 109 | sketch/elephant/n02503517_311-5.png 2 110 | sketch/elephant/n02503517_311-6.png 2 111 | sketch/elephant/n02503517_564-1.png 2 112 | sketch/elephant/n02503517_564-2.png 2 113 | sketch/elephant/n02503517_564-3.png 2 114 | sketch/elephant/n02503517_564-4.png 2 115 | sketch/elephant/n02503517_564-5.png 2 116 | sketch/elephant/n02503517_753-1.png 2 117 | sketch/elephant/n02503517_753-2.png 2 118 | sketch/elephant/n02503517_753-3.png 2 119 | sketch/elephant/n02503517_753-4.png 2 120 | sketch/elephant/n02503517_753-5.png 2 121 | sketch/elephant/n02503517_753-6.png 2 122 | sketch/elephant/n02503517_759-1.png 2 123 | sketch/elephant/n02503517_759-2.png 2 124 | sketch/elephant/n02503517_759-3.png 2 125 | sketch/elephant/n02503517_759-4.png 2 126 | sketch/elephant/n02503517_759-5.png 2 127 | sketch/elephant/n02503517_759-6.png 2 128 | sketch/elephant/n02503517_759-7.png 2 129 | sketch/elephant/n02503517_759-8.png 2 130 | sketch/elephant/n02503517_792-1.png 2 131 | sketch/elephant/n02503517_792-2.png 2 132 | sketch/elephant/n02503517_792-3.png 2 133 | sketch/elephant/n02503517_792-4.png 2 134 | sketch/elephant/n02503517_792-5.png 2 135 | sketch/elephant/n02503517_1292-1.png 2 136 | 
sketch/elephant/n02503517_1292-2.png 2 137 | sketch/elephant/n02503517_1292-3.png 2 138 | sketch/elephant/n02503517_1292-4.png 2 139 | sketch/elephant/n02503517_1292-5.png 2 140 | sketch/elephant/n02503517_1292-6.png 2 141 | sketch/elephant/n02503517_1292-7.png 2 142 | sketch/elephant/n02503517_1359-1.png 2 143 | sketch/elephant/n02503517_1359-2.png 2 144 | sketch/elephant/n02503517_1359-3.png 2 145 | sketch/elephant/n02503517_1359-4.png 2 146 | sketch/elephant/n02503517_1359-5.png 2 147 | sketch/elephant/n02503517_1359-6.png 2 148 | sketch/elephant/n02503517_1383-1.png 2 149 | sketch/elephant/n02503517_1383-2.png 2 150 | sketch/elephant/n02503517_1383-3.png 2 151 | sketch/elephant/n02503517_1383-4.png 2 152 | sketch/elephant/n02503517_1383-5.png 2 153 | sketch/elephant/n02503517_1383-6.png 2 154 | sketch/giraffe/n02439033_67-1.png 3 155 | sketch/giraffe/n02439033_67-2.png 3 156 | sketch/giraffe/n02439033_67-3.png 3 157 | sketch/giraffe/n02439033_67-4.png 3 158 | sketch/giraffe/n02439033_67-5.png 3 159 | sketch/giraffe/n02439033_67-6.png 3 160 | sketch/giraffe/n02439033_67-7.png 3 161 | sketch/giraffe/n02439033_221-1.png 3 162 | sketch/giraffe/n02439033_221-2.png 3 163 | sketch/giraffe/n02439033_221-3.png 3 164 | sketch/giraffe/n02439033_221-4.png 3 165 | sketch/giraffe/n02439033_221-5.png 3 166 | sketch/giraffe/n02439033_221-6.png 3 167 | sketch/giraffe/n02439033_376-1.png 3 168 | sketch/giraffe/n02439033_376-2.png 3 169 | sketch/giraffe/n02439033_376-3.png 3 170 | sketch/giraffe/n02439033_376-4.png 3 171 | sketch/giraffe/n02439033_376-5.png 3 172 | sketch/giraffe/n02439033_569-1.png 3 173 | sketch/giraffe/n02439033_569-2.png 3 174 | sketch/giraffe/n02439033_569-3.png 3 175 | sketch/giraffe/n02439033_569-4.png 3 176 | sketch/giraffe/n02439033_569-5.png 3 177 | sketch/giraffe/n02439033_628-1.png 3 178 | sketch/giraffe/n02439033_628-2.png 3 179 | sketch/giraffe/n02439033_628-3.png 3 180 | sketch/giraffe/n02439033_628-4.png 3 181 | sketch/giraffe/n02439033_628-5.png 
3 182 | sketch/giraffe/n02439033_628-6.png 3 183 | sketch/giraffe/n02439033_628-7.png 3 184 | sketch/giraffe/n02439033_628-8.png 3 185 | sketch/giraffe/n02439033_628-9.png 3 186 | sketch/giraffe/n02439033_628-10.png 3 187 | sketch/giraffe/n02439033_866-1.png 3 188 | sketch/giraffe/n02439033_866-2.png 3 189 | sketch/giraffe/n02439033_866-3.png 3 190 | sketch/giraffe/n02439033_866-4.png 3 191 | sketch/giraffe/n02439033_866-5.png 3 192 | sketch/giraffe/n02439033_866-6.png 3 193 | sketch/giraffe/n02439033_991-1.png 3 194 | sketch/giraffe/n02439033_991-2.png 3 195 | sketch/giraffe/n02439033_991-3.png 3 196 | sketch/giraffe/n02439033_991-4.png 3 197 | sketch/giraffe/n02439033_991-5.png 3 198 | sketch/giraffe/n02439033_991-6.png 3 199 | sketch/giraffe/n02439033_991-7.png 3 200 | sketch/giraffe/n02439033_991-8.png 3 201 | sketch/giraffe/n02439033_991-9.png 3 202 | sketch/giraffe/n02439033_1327-1.png 3 203 | sketch/giraffe/n02439033_1327-2.png 3 204 | sketch/giraffe/n02439033_1327-3.png 3 205 | sketch/giraffe/n02439033_1327-4.png 3 206 | sketch/giraffe/n02439033_1327-5.png 3 207 | sketch/giraffe/n02439033_1508-1.png 3 208 | sketch/giraffe/n02439033_1508-2.png 3 209 | sketch/giraffe/n02439033_1508-3.png 3 210 | sketch/giraffe/n02439033_1508-4.png 3 211 | sketch/giraffe/n02439033_1508-5.png 3 212 | sketch/giraffe/n02439033_1508-6.png 3 213 | sketch/giraffe/n02439033_1508-7.png 3 214 | sketch/giraffe/n02439033_1508-8.png 3 215 | sketch/giraffe/n02439033_1508-9.png 3 216 | sketch/giraffe/n02439033_2486-1.png 3 217 | sketch/giraffe/n02439033_2486-2.png 3 218 | sketch/giraffe/n02439033_2486-3.png 3 219 | sketch/giraffe/n02439033_2486-4.png 3 220 | sketch/giraffe/n02439033_2486-5.png 3 221 | sketch/giraffe/n02439033_2486-6.png 3 222 | sketch/giraffe/n02439033_2500-1.png 3 223 | sketch/giraffe/n02439033_2500-2.png 3 224 | sketch/giraffe/n02439033_2500-3.png 3 225 | sketch/giraffe/n02439033_2500-4.png 3 226 | sketch/giraffe/n02439033_2500-5.png 3 227 | 
sketch/giraffe/n02439033_2500-6.png 3 228 | sketch/giraffe/n02439033_2677-1.png 3 229 | sketch/giraffe/n02439033_2677-2.png 3 230 | sketch/guitar/7601.png 4 231 | sketch/guitar/7602.png 4 232 | sketch/guitar/7603.png 4 233 | sketch/guitar/7604.png 4 234 | sketch/guitar/7605.png 4 235 | sketch/guitar/7606.png 4 236 | sketch/guitar/7607.png 4 237 | sketch/guitar/7608.png 4 238 | sketch/guitar/7609.png 4 239 | sketch/guitar/7610.png 4 240 | sketch/guitar/7611.png 4 241 | sketch/guitar/7612.png 4 242 | sketch/guitar/7613.png 4 243 | sketch/guitar/7614.png 4 244 | sketch/guitar/7615.png 4 245 | sketch/guitar/7616.png 4 246 | sketch/guitar/7617.png 4 247 | sketch/guitar/7618.png 4 248 | sketch/guitar/7619.png 4 249 | sketch/guitar/7620.png 4 250 | sketch/guitar/7621.png 4 251 | sketch/guitar/7622.png 4 252 | sketch/guitar/7623.png 4 253 | sketch/guitar/7624.png 4 254 | sketch/guitar/7625.png 4 255 | sketch/guitar/7626.png 4 256 | sketch/guitar/7627.png 4 257 | sketch/guitar/7628.png 4 258 | sketch/guitar/7629.png 4 259 | sketch/guitar/7630.png 4 260 | sketch/guitar/7631.png 4 261 | sketch/guitar/7632.png 4 262 | sketch/guitar/7633.png 4 263 | sketch/guitar/7634.png 4 264 | sketch/guitar/7635.png 4 265 | sketch/guitar/7636.png 4 266 | sketch/guitar/7637.png 4 267 | sketch/guitar/7638.png 4 268 | sketch/guitar/7639.png 4 269 | sketch/guitar/7640.png 4 270 | sketch/guitar/7641.png 4 271 | sketch/guitar/7642.png 4 272 | sketch/guitar/7643.png 4 273 | sketch/guitar/7644.png 4 274 | sketch/guitar/7645.png 4 275 | sketch/guitar/7646.png 4 276 | sketch/guitar/7647.png 4 277 | sketch/guitar/7648.png 4 278 | sketch/guitar/7649.png 4 279 | sketch/guitar/7650.png 4 280 | sketch/guitar/7651.png 4 281 | sketch/guitar/7652.png 4 282 | sketch/guitar/7653.png 4 283 | sketch/guitar/7654.png 4 284 | sketch/guitar/7655.png 4 285 | sketch/guitar/7656.png 4 286 | sketch/guitar/7657.png 4 287 | sketch/guitar/7658.png 4 288 | sketch/guitar/7659.png 4 289 | sketch/guitar/7660.png 4 290 | 
sketch/guitar/7661.png 4 291 | sketch/horse/n02374451_54-1.png 5 292 | sketch/horse/n02374451_54-2.png 5 293 | sketch/horse/n02374451_54-3.png 5 294 | sketch/horse/n02374451_54-4.png 5 295 | sketch/horse/n02374451_54-5.png 5 296 | sketch/horse/n02374451_54-6.png 5 297 | sketch/horse/n02374451_54-7.png 5 298 | sketch/horse/n02374451_54-8.png 5 299 | sketch/horse/n02374451_54-9.png 5 300 | sketch/horse/n02374451_54-10.png 5 301 | sketch/horse/n02374451_245-1.png 5 302 | sketch/horse/n02374451_245-2.png 5 303 | sketch/horse/n02374451_245-3.png 5 304 | sketch/horse/n02374451_245-4.png 5 305 | sketch/horse/n02374451_245-5.png 5 306 | sketch/horse/n02374451_245-6.png 5 307 | sketch/horse/n02374451_257-1.png 5 308 | sketch/horse/n02374451_257-2.png 5 309 | sketch/horse/n02374451_257-3.png 5 310 | sketch/horse/n02374451_257-4.png 5 311 | sketch/horse/n02374451_257-5.png 5 312 | sketch/horse/n02374451_257-6.png 5 313 | sketch/horse/n02374451_257-7.png 5 314 | sketch/horse/n02374451_262-1.png 5 315 | sketch/horse/n02374451_262-2.png 5 316 | sketch/horse/n02374451_262-3.png 5 317 | sketch/horse/n02374451_262-4.png 5 318 | sketch/horse/n02374451_262-5.png 5 319 | sketch/horse/n02374451_262-6.png 5 320 | sketch/horse/n02374451_262-7.png 5 321 | sketch/horse/n02374451_262-8.png 5 322 | sketch/horse/n02374451_262-9.png 5 323 | sketch/horse/n02374451_262-10.png 5 324 | sketch/horse/n02374451_262-11.png 5 325 | sketch/horse/n02374451_262-12.png 5 326 | sketch/horse/n02374451_276-1.png 5 327 | sketch/horse/n02374451_276-2.png 5 328 | sketch/horse/n02374451_276-3.png 5 329 | sketch/horse/n02374451_276-4.png 5 330 | sketch/horse/n02374451_276-5.png 5 331 | sketch/horse/n02374451_276-6.png 5 332 | sketch/horse/n02374451_276-7.png 5 333 | sketch/horse/n02374451_276-8.png 5 334 | sketch/horse/n02374451_276-9.png 5 335 | sketch/horse/n02374451_276-10.png 5 336 | sketch/horse/n02374451_388-2.png 5 337 | sketch/horse/n02374451_388-3.png 5 338 | sketch/horse/n02374451_388-4.png 5 339 | 
sketch/horse/n02374451_388-5.png 5 340 | sketch/horse/n02374451_388-6.png 5 341 | sketch/horse/n02374451_388-7.png 5 342 | sketch/horse/n02374451_388-8.png 5 343 | sketch/horse/n02374451_388-9.png 5 344 | sketch/horse/n02374451_388-10.png 5 345 | sketch/horse/n02374451_468-1.png 5 346 | sketch/horse/n02374451_468-2.png 5 347 | sketch/horse/n02374451_468-3.png 5 348 | sketch/horse/n02374451_468-4.png 5 349 | sketch/horse/n02374451_468-5.png 5 350 | sketch/horse/n02374451_468-6.png 5 351 | sketch/horse/n02374451_468-7.png 5 352 | sketch/horse/n02374451_468-8.png 5 353 | sketch/horse/n02374451_468-9.png 5 354 | sketch/horse/n02374451_468-10.png 5 355 | sketch/horse/n02374451_490-1.png 5 356 | sketch/horse/n02374451_490-2.png 5 357 | sketch/horse/n02374451_490-3.png 5 358 | sketch/horse/n02374451_490-4.png 5 359 | sketch/horse/n02374451_490-5.png 5 360 | sketch/horse/n02374451_490-6.png 5 361 | sketch/horse/n02374451_490-7.png 5 362 | sketch/horse/n02374451_503-1.png 5 363 | sketch/horse/n02374451_503-2.png 5 364 | sketch/horse/n02374451_503-3.png 5 365 | sketch/horse/n02374451_503-4.png 5 366 | sketch/horse/n02374451_503-5.png 5 367 | sketch/horse/n02374451_503-6.png 5 368 | sketch/horse/n02374451_557-1.png 5 369 | sketch/horse/n02374451_557-2.png 5 370 | sketch/horse/n02374451_557-3.png 5 371 | sketch/horse/n02374451_557-4.png 5 372 | sketch/horse/n02374451_557-5.png 5 373 | sketch/house/8801.png 6 374 | sketch/house/8802.png 6 375 | sketch/house/8803.png 6 376 | sketch/house/8804.png 6 377 | sketch/house/8805.png 6 378 | sketch/house/8806.png 6 379 | sketch/house/8807.png 6 380 | sketch/house/8808.png 6 381 | sketch/house/8809.png 6 382 | sketch/person/12081.png 7 383 | sketch/person/12082.png 7 384 | sketch/person/12083.png 7 385 | sketch/person/12084.png 7 386 | sketch/person/12085.png 7 387 | sketch/person/12086.png 7 388 | sketch/person/12087.png 7 389 | sketch/person/12088.png 7 390 | sketch/person/12089.png 7 391 | sketch/person/12090.png 7 392 | 
sketch/person/12091.png 7 393 | sketch/person/12092.png 7 394 | sketch/person/12093.png 7 395 | sketch/person/12094.png 7 396 | sketch/person/12095.png 7 397 | sketch/person/12096.png 7 398 | sketch/person/12097.png 7 399 | -------------------------------------------------------------------------------- /dataloader/text_lists/VLCS/CALTECH_val.txt: -------------------------------------------------------------------------------- 1 | CALTECH/test/0/test_imgs_1.jpg 0 2 | CALTECH/test/0/test_imgs_10.jpg 0 3 | CALTECH/test/0/test_imgs_11.jpg 0 4 | CALTECH/test/0/test_imgs_12.jpg 0 5 | CALTECH/test/0/test_imgs_13.jpg 0 6 | CALTECH/test/0/test_imgs_14.jpg 0 7 | CALTECH/test/0/test_imgs_15.jpg 0 8 | CALTECH/test/0/test_imgs_16.jpg 0 9 | CALTECH/test/0/test_imgs_17.jpg 0 10 | CALTECH/test/0/test_imgs_18.jpg 0 11 | CALTECH/test/0/test_imgs_19.jpg 0 12 | CALTECH/test/0/test_imgs_2.jpg 0 13 | CALTECH/test/0/test_imgs_20.jpg 0 14 | CALTECH/test/0/test_imgs_21.jpg 0 15 | CALTECH/test/0/test_imgs_22.jpg 0 16 | CALTECH/test/0/test_imgs_23.jpg 0 17 | CALTECH/test/0/test_imgs_24.jpg 0 18 | CALTECH/test/0/test_imgs_25.jpg 0 19 | CALTECH/test/0/test_imgs_26.jpg 0 20 | CALTECH/test/0/test_imgs_27.jpg 0 21 | CALTECH/test/0/test_imgs_28.jpg 0 22 | CALTECH/test/0/test_imgs_29.jpg 0 23 | CALTECH/test/0/test_imgs_3.jpg 0 24 | CALTECH/test/0/test_imgs_30.jpg 0 25 | CALTECH/test/0/test_imgs_31.jpg 0 26 | CALTECH/test/0/test_imgs_32.jpg 0 27 | CALTECH/test/0/test_imgs_33.jpg 0 28 | CALTECH/test/0/test_imgs_34.jpg 0 29 | CALTECH/test/0/test_imgs_35.jpg 0 30 | CALTECH/test/0/test_imgs_36.jpg 0 31 | CALTECH/test/0/test_imgs_37.jpg 0 32 | CALTECH/test/0/test_imgs_38.jpg 0 33 | CALTECH/test/0/test_imgs_39.jpg 0 34 | CALTECH/test/0/test_imgs_4.jpg 0 35 | CALTECH/test/0/test_imgs_40.jpg 0 36 | CALTECH/test/0/test_imgs_41.jpg 0 37 | CALTECH/test/0/test_imgs_42.jpg 0 38 | CALTECH/test/0/test_imgs_43.jpg 0 39 | CALTECH/test/0/test_imgs_44.jpg 0 40 | CALTECH/test/0/test_imgs_45.jpg 0 41 | 
CALTECH/test/0/test_imgs_46.jpg 0 42 | CALTECH/test/0/test_imgs_47.jpg 0 43 | CALTECH/test/0/test_imgs_48.jpg 0 44 | CALTECH/test/0/test_imgs_49.jpg 0 45 | CALTECH/test/0/test_imgs_5.jpg 0 46 | CALTECH/test/0/test_imgs_50.jpg 0 47 | CALTECH/test/0/test_imgs_51.jpg 0 48 | CALTECH/test/0/test_imgs_52.jpg 0 49 | CALTECH/test/0/test_imgs_53.jpg 0 50 | CALTECH/test/0/test_imgs_54.jpg 0 51 | CALTECH/test/0/test_imgs_55.jpg 0 52 | CALTECH/test/0/test_imgs_56.jpg 0 53 | CALTECH/test/0/test_imgs_57.jpg 0 54 | CALTECH/test/0/test_imgs_58.jpg 0 55 | CALTECH/test/0/test_imgs_59.jpg 0 56 | CALTECH/test/0/test_imgs_6.jpg 0 57 | CALTECH/test/0/test_imgs_60.jpg 0 58 | CALTECH/test/0/test_imgs_61.jpg 0 59 | CALTECH/test/0/test_imgs_62.jpg 0 60 | CALTECH/test/0/test_imgs_63.jpg 0 61 | CALTECH/test/0/test_imgs_64.jpg 0 62 | CALTECH/test/0/test_imgs_65.jpg 0 63 | CALTECH/test/0/test_imgs_66.jpg 0 64 | CALTECH/test/0/test_imgs_67.jpg 0 65 | CALTECH/test/0/test_imgs_68.jpg 0 66 | CALTECH/test/0/test_imgs_69.jpg 0 67 | CALTECH/test/0/test_imgs_7.jpg 0 68 | CALTECH/test/0/test_imgs_70.jpg 0 69 | CALTECH/test/0/test_imgs_71.jpg 0 70 | CALTECH/test/0/test_imgs_8.jpg 0 71 | CALTECH/test/0/test_imgs_9.jpg 0 72 | CALTECH/test/1/test_imgs_100.jpg 1 73 | CALTECH/test/1/test_imgs_101.jpg 1 74 | CALTECH/test/1/test_imgs_102.jpg 1 75 | CALTECH/test/1/test_imgs_103.jpg 1 76 | CALTECH/test/1/test_imgs_104.jpg 1 77 | CALTECH/test/1/test_imgs_105.jpg 1 78 | CALTECH/test/1/test_imgs_106.jpg 1 79 | CALTECH/test/1/test_imgs_107.jpg 1 80 | CALTECH/test/1/test_imgs_108.jpg 1 81 | CALTECH/test/1/test_imgs_72.jpg 1 82 | CALTECH/test/1/test_imgs_73.jpg 1 83 | CALTECH/test/1/test_imgs_74.jpg 1 84 | CALTECH/test/1/test_imgs_75.jpg 1 85 | CALTECH/test/1/test_imgs_76.jpg 1 86 | CALTECH/test/1/test_imgs_77.jpg 1 87 | CALTECH/test/1/test_imgs_78.jpg 1 88 | CALTECH/test/1/test_imgs_79.jpg 1 89 | CALTECH/test/1/test_imgs_80.jpg 1 90 | CALTECH/test/1/test_imgs_81.jpg 1 91 | CALTECH/test/1/test_imgs_82.jpg 1 92 | 
CALTECH/test/1/test_imgs_83.jpg 1 93 | CALTECH/test/1/test_imgs_84.jpg 1 94 | CALTECH/test/1/test_imgs_85.jpg 1 95 | CALTECH/test/1/test_imgs_86.jpg 1 96 | CALTECH/test/1/test_imgs_87.jpg 1 97 | CALTECH/test/1/test_imgs_88.jpg 1 98 | CALTECH/test/1/test_imgs_89.jpg 1 99 | CALTECH/test/1/test_imgs_90.jpg 1 100 | CALTECH/test/1/test_imgs_91.jpg 1 101 | CALTECH/test/1/test_imgs_92.jpg 1 102 | CALTECH/test/1/test_imgs_93.jpg 1 103 | CALTECH/test/1/test_imgs_94.jpg 1 104 | CALTECH/test/1/test_imgs_95.jpg 1 105 | CALTECH/test/1/test_imgs_96.jpg 1 106 | CALTECH/test/1/test_imgs_97.jpg 1 107 | CALTECH/test/1/test_imgs_98.jpg 1 108 | CALTECH/test/1/test_imgs_99.jpg 1 109 | CALTECH/test/2/test_imgs_109.jpg 2 110 | CALTECH/test/2/test_imgs_110.jpg 2 111 | CALTECH/test/2/test_imgs_111.jpg 2 112 | CALTECH/test/2/test_imgs_112.jpg 2 113 | CALTECH/test/2/test_imgs_113.jpg 2 114 | CALTECH/test/2/test_imgs_114.jpg 2 115 | CALTECH/test/2/test_imgs_115.jpg 2 116 | CALTECH/test/2/test_imgs_116.jpg 2 117 | CALTECH/test/2/test_imgs_117.jpg 2 118 | CALTECH/test/2/test_imgs_118.jpg 2 119 | CALTECH/test/2/test_imgs_119.jpg 2 120 | CALTECH/test/2/test_imgs_120.jpg 2 121 | CALTECH/test/2/test_imgs_121.jpg 2 122 | CALTECH/test/2/test_imgs_122.jpg 2 123 | CALTECH/test/2/test_imgs_123.jpg 2 124 | CALTECH/test/2/test_imgs_124.jpg 2 125 | CALTECH/test/2/test_imgs_125.jpg 2 126 | CALTECH/test/2/test_imgs_126.jpg 2 127 | CALTECH/test/2/test_imgs_127.jpg 2 128 | CALTECH/test/2/test_imgs_128.jpg 2 129 | CALTECH/test/2/test_imgs_129.jpg 2 130 | CALTECH/test/2/test_imgs_130.jpg 2 131 | CALTECH/test/2/test_imgs_131.jpg 2 132 | CALTECH/test/2/test_imgs_132.jpg 2 133 | CALTECH/test/2/test_imgs_133.jpg 2 134 | CALTECH/test/2/test_imgs_134.jpg 2 135 | CALTECH/test/2/test_imgs_135.jpg 2 136 | CALTECH/test/2/test_imgs_136.jpg 2 137 | CALTECH/test/2/test_imgs_137.jpg 2 138 | CALTECH/test/2/test_imgs_138.jpg 2 139 | CALTECH/test/2/test_imgs_139.jpg 2 140 | CALTECH/test/2/test_imgs_140.jpg 2 141 | 
CALTECH/test/2/test_imgs_141.jpg 2 142 | CALTECH/test/2/test_imgs_142.jpg 2 143 | CALTECH/test/2/test_imgs_143.jpg 2 144 | CALTECH/test/3/test_imgs_144.jpg 3 145 | CALTECH/test/3/test_imgs_145.jpg 3 146 | CALTECH/test/3/test_imgs_146.jpg 3 147 | CALTECH/test/3/test_imgs_147.jpg 3 148 | CALTECH/test/3/test_imgs_148.jpg 3 149 | CALTECH/test/3/test_imgs_149.jpg 3 150 | CALTECH/test/3/test_imgs_150.jpg 3 151 | CALTECH/test/3/test_imgs_151.jpg 3 152 | CALTECH/test/3/test_imgs_152.jpg 3 153 | CALTECH/test/3/test_imgs_153.jpg 3 154 | CALTECH/test/3/test_imgs_154.jpg 3 155 | CALTECH/test/3/test_imgs_155.jpg 3 156 | CALTECH/test/3/test_imgs_156.jpg 3 157 | CALTECH/test/3/test_imgs_157.jpg 3 158 | CALTECH/test/3/test_imgs_158.jpg 3 159 | CALTECH/test/3/test_imgs_159.jpg 3 160 | CALTECH/test/3/test_imgs_160.jpg 3 161 | CALTECH/test/3/test_imgs_161.jpg 3 162 | CALTECH/test/3/test_imgs_162.jpg 3 163 | CALTECH/test/3/test_imgs_163.jpg 3 164 | CALTECH/test/4/test_imgs_164.jpg 4 165 | CALTECH/test/4/test_imgs_165.jpg 4 166 | CALTECH/test/4/test_imgs_166.jpg 4 167 | CALTECH/test/4/test_imgs_167.jpg 4 168 | CALTECH/test/4/test_imgs_168.jpg 4 169 | CALTECH/test/4/test_imgs_169.jpg 4 170 | CALTECH/test/4/test_imgs_170.jpg 4 171 | CALTECH/test/4/test_imgs_171.jpg 4 172 | CALTECH/test/4/test_imgs_172.jpg 4 173 | CALTECH/test/4/test_imgs_173.jpg 4 174 | CALTECH/test/4/test_imgs_174.jpg 4 175 | CALTECH/test/4/test_imgs_175.jpg 4 176 | CALTECH/test/4/test_imgs_176.jpg 4 177 | CALTECH/test/4/test_imgs_177.jpg 4 178 | CALTECH/test/4/test_imgs_178.jpg 4 179 | CALTECH/test/4/test_imgs_179.jpg 4 180 | CALTECH/test/4/test_imgs_180.jpg 4 181 | CALTECH/test/4/test_imgs_181.jpg 4 182 | CALTECH/test/4/test_imgs_182.jpg 4 183 | CALTECH/test/4/test_imgs_183.jpg 4 184 | CALTECH/test/4/test_imgs_184.jpg 4 185 | CALTECH/test/4/test_imgs_185.jpg 4 186 | CALTECH/test/4/test_imgs_186.jpg 4 187 | CALTECH/test/4/test_imgs_187.jpg 4 188 | CALTECH/test/4/test_imgs_188.jpg 4 189 | 
CALTECH/test/4/test_imgs_189.jpg 4 190 | CALTECH/test/4/test_imgs_190.jpg 4 191 | CALTECH/test/4/test_imgs_191.jpg 4 192 | CALTECH/test/4/test_imgs_192.jpg 4 193 | CALTECH/test/4/test_imgs_193.jpg 4 194 | CALTECH/test/4/test_imgs_194.jpg 4 195 | CALTECH/test/4/test_imgs_195.jpg 4 196 | CALTECH/test/4/test_imgs_196.jpg 4 197 | CALTECH/test/4/test_imgs_197.jpg 4 198 | CALTECH/test/4/test_imgs_198.jpg 4 199 | CALTECH/test/4/test_imgs_199.jpg 4 200 | CALTECH/test/4/test_imgs_200.jpg 4 201 | CALTECH/test/4/test_imgs_201.jpg 4 202 | CALTECH/test/4/test_imgs_202.jpg 4 203 | CALTECH/test/4/test_imgs_203.jpg 4 204 | CALTECH/test/4/test_imgs_204.jpg 4 205 | CALTECH/test/4/test_imgs_205.jpg 4 206 | CALTECH/test/4/test_imgs_206.jpg 4 207 | CALTECH/test/4/test_imgs_207.jpg 4 208 | CALTECH/test/4/test_imgs_208.jpg 4 209 | CALTECH/test/4/test_imgs_209.jpg 4 210 | CALTECH/test/4/test_imgs_210.jpg 4 211 | CALTECH/test/4/test_imgs_211.jpg 4 212 | CALTECH/test/4/test_imgs_212.jpg 4 213 | CALTECH/test/4/test_imgs_213.jpg 4 214 | CALTECH/test/4/test_imgs_214.jpg 4 215 | CALTECH/test/4/test_imgs_215.jpg 4 216 | CALTECH/test/4/test_imgs_216.jpg 4 217 | CALTECH/test/4/test_imgs_217.jpg 4 218 | CALTECH/test/4/test_imgs_218.jpg 4 219 | CALTECH/test/4/test_imgs_219.jpg 4 220 | CALTECH/test/4/test_imgs_220.jpg 4 221 | CALTECH/test/4/test_imgs_221.jpg 4 222 | CALTECH/test/4/test_imgs_222.jpg 4 223 | CALTECH/test/4/test_imgs_223.jpg 4 224 | CALTECH/test/4/test_imgs_224.jpg 4 225 | CALTECH/test/4/test_imgs_225.jpg 4 226 | CALTECH/test/4/test_imgs_226.jpg 4 227 | CALTECH/test/4/test_imgs_227.jpg 4 228 | CALTECH/test/4/test_imgs_228.jpg 4 229 | CALTECH/test/4/test_imgs_229.jpg 4 230 | CALTECH/test/4/test_imgs_230.jpg 4 231 | CALTECH/test/4/test_imgs_231.jpg 4 232 | CALTECH/test/4/test_imgs_232.jpg 4 233 | CALTECH/test/4/test_imgs_233.jpg 4 234 | CALTECH/test/4/test_imgs_234.jpg 4 235 | CALTECH/test/4/test_imgs_235.jpg 4 236 | CALTECH/test/4/test_imgs_236.jpg 4 237 | 
CALTECH/test/4/test_imgs_237.jpg 4 238 | CALTECH/test/4/test_imgs_238.jpg 4 239 | CALTECH/test/4/test_imgs_239.jpg 4 240 | CALTECH/test/4/test_imgs_240.jpg 4 241 | CALTECH/test/4/test_imgs_241.jpg 4 242 | CALTECH/test/4/test_imgs_242.jpg 4 243 | CALTECH/test/4/test_imgs_243.jpg 4 244 | CALTECH/test/4/test_imgs_244.jpg 4 245 | CALTECH/test/4/test_imgs_245.jpg 4 246 | CALTECH/test/4/test_imgs_246.jpg 4 247 | CALTECH/test/4/test_imgs_247.jpg 4 248 | CALTECH/test/4/test_imgs_248.jpg 4 249 | CALTECH/test/4/test_imgs_249.jpg 4 250 | CALTECH/test/4/test_imgs_250.jpg 4 251 | CALTECH/test/4/test_imgs_251.jpg 4 252 | CALTECH/test/4/test_imgs_252.jpg 4 253 | CALTECH/test/4/test_imgs_253.jpg 4 254 | CALTECH/test/4/test_imgs_254.jpg 4 255 | CALTECH/test/4/test_imgs_255.jpg 4 256 | CALTECH/test/4/test_imgs_256.jpg 4 257 | CALTECH/test/4/test_imgs_257.jpg 4 258 | CALTECH/test/4/test_imgs_258.jpg 4 259 | CALTECH/test/4/test_imgs_259.jpg 4 260 | CALTECH/test/4/test_imgs_260.jpg 4 261 | CALTECH/test/4/test_imgs_261.jpg 4 262 | CALTECH/test/4/test_imgs_262.jpg 4 263 | CALTECH/test/4/test_imgs_263.jpg 4 264 | CALTECH/test/4/test_imgs_264.jpg 4 265 | CALTECH/test/4/test_imgs_265.jpg 4 266 | CALTECH/test/4/test_imgs_266.jpg 4 267 | CALTECH/test/4/test_imgs_267.jpg 4 268 | CALTECH/test/4/test_imgs_268.jpg 4 269 | CALTECH/test/4/test_imgs_269.jpg 4 270 | CALTECH/test/4/test_imgs_270.jpg 4 271 | CALTECH/test/4/test_imgs_271.jpg 4 272 | CALTECH/test/4/test_imgs_272.jpg 4 273 | CALTECH/test/4/test_imgs_273.jpg 4 274 | CALTECH/test/4/test_imgs_274.jpg 4 275 | CALTECH/test/4/test_imgs_275.jpg 4 276 | CALTECH/test/4/test_imgs_276.jpg 4 277 | CALTECH/test/4/test_imgs_277.jpg 4 278 | CALTECH/test/4/test_imgs_278.jpg 4 279 | CALTECH/test/4/test_imgs_279.jpg 4 280 | CALTECH/test/4/test_imgs_280.jpg 4 281 | CALTECH/test/4/test_imgs_281.jpg 4 282 | CALTECH/test/4/test_imgs_282.jpg 4 283 | CALTECH/test/4/test_imgs_283.jpg 4 284 | CALTECH/test/4/test_imgs_284.jpg 4 285 | 
CALTECH/test/4/test_imgs_285.jpg 4 286 | CALTECH/test/4/test_imgs_286.jpg 4 287 | CALTECH/test/4/test_imgs_287.jpg 4 288 | CALTECH/test/4/test_imgs_288.jpg 4 289 | CALTECH/test/4/test_imgs_289.jpg 4 290 | CALTECH/test/4/test_imgs_290.jpg 4 291 | CALTECH/test/4/test_imgs_291.jpg 4 292 | CALTECH/test/4/test_imgs_292.jpg 4 293 | CALTECH/test/4/test_imgs_293.jpg 4 294 | CALTECH/test/4/test_imgs_294.jpg 4 295 | CALTECH/test/4/test_imgs_295.jpg 4 296 | CALTECH/test/4/test_imgs_296.jpg 4 297 | CALTECH/test/4/test_imgs_297.jpg 4 298 | CALTECH/test/4/test_imgs_298.jpg 4 299 | CALTECH/test/4/test_imgs_299.jpg 4 300 | CALTECH/test/4/test_imgs_300.jpg 4 301 | CALTECH/test/4/test_imgs_301.jpg 4 302 | CALTECH/test/4/test_imgs_302.jpg 4 303 | CALTECH/test/4/test_imgs_303.jpg 4 304 | CALTECH/test/4/test_imgs_304.jpg 4 305 | CALTECH/test/4/test_imgs_305.jpg 4 306 | CALTECH/test/4/test_imgs_306.jpg 4 307 | CALTECH/test/4/test_imgs_307.jpg 4 308 | CALTECH/test/4/test_imgs_308.jpg 4 309 | CALTECH/test/4/test_imgs_309.jpg 4 310 | CALTECH/test/4/test_imgs_310.jpg 4 311 | CALTECH/test/4/test_imgs_311.jpg 4 312 | CALTECH/test/4/test_imgs_312.jpg 4 313 | CALTECH/test/4/test_imgs_313.jpg 4 314 | CALTECH/test/4/test_imgs_314.jpg 4 315 | CALTECH/test/4/test_imgs_315.jpg 4 316 | CALTECH/test/4/test_imgs_316.jpg 4 317 | CALTECH/test/4/test_imgs_317.jpg 4 318 | CALTECH/test/4/test_imgs_318.jpg 4 319 | CALTECH/test/4/test_imgs_319.jpg 4 320 | CALTECH/test/4/test_imgs_320.jpg 4 321 | CALTECH/test/4/test_imgs_321.jpg 4 322 | CALTECH/test/4/test_imgs_322.jpg 4 323 | CALTECH/test/4/test_imgs_323.jpg 4 324 | CALTECH/test/4/test_imgs_324.jpg 4 325 | CALTECH/test/4/test_imgs_325.jpg 4 326 | CALTECH/test/4/test_imgs_326.jpg 4 327 | CALTECH/test/4/test_imgs_327.jpg 4 328 | CALTECH/test/4/test_imgs_328.jpg 4 329 | CALTECH/test/4/test_imgs_329.jpg 4 330 | CALTECH/test/4/test_imgs_330.jpg 4 331 | CALTECH/test/4/test_imgs_331.jpg 4 332 | CALTECH/test/4/test_imgs_332.jpg 4 333 | 
CALTECH/test/4/test_imgs_333.jpg 4 334 | CALTECH/test/4/test_imgs_334.jpg 4 335 | CALTECH/test/4/test_imgs_335.jpg 4 336 | CALTECH/test/4/test_imgs_336.jpg 4 337 | CALTECH/test/4/test_imgs_337.jpg 4 338 | CALTECH/test/4/test_imgs_338.jpg 4 339 | CALTECH/test/4/test_imgs_339.jpg 4 340 | CALTECH/test/4/test_imgs_340.jpg 4 341 | CALTECH/test/4/test_imgs_341.jpg 4 342 | CALTECH/test/4/test_imgs_342.jpg 4 343 | CALTECH/test/4/test_imgs_343.jpg 4 344 | CALTECH/test/4/test_imgs_344.jpg 4 345 | CALTECH/test/4/test_imgs_345.jpg 4 346 | CALTECH/test/4/test_imgs_346.jpg 4 347 | CALTECH/test/4/test_imgs_347.jpg 4 348 | CALTECH/test/4/test_imgs_348.jpg 4 349 | CALTECH/test/4/test_imgs_349.jpg 4 350 | CALTECH/test/4/test_imgs_350.jpg 4 351 | CALTECH/test/4/test_imgs_351.jpg 4 352 | CALTECH/test/4/test_imgs_352.jpg 4 353 | CALTECH/test/4/test_imgs_353.jpg 4 354 | CALTECH/test/4/test_imgs_354.jpg 4 355 | CALTECH/test/4/test_imgs_355.jpg 4 356 | CALTECH/test/4/test_imgs_356.jpg 4 357 | CALTECH/test/4/test_imgs_357.jpg 4 358 | CALTECH/test/4/test_imgs_358.jpg 4 359 | CALTECH/test/4/test_imgs_359.jpg 4 360 | CALTECH/test/4/test_imgs_360.jpg 4 361 | CALTECH/test/4/test_imgs_361.jpg 4 362 | CALTECH/test/4/test_imgs_362.jpg 4 363 | CALTECH/test/4/test_imgs_363.jpg 4 364 | CALTECH/test/4/test_imgs_364.jpg 4 365 | CALTECH/test/4/test_imgs_365.jpg 4 366 | CALTECH/test/4/test_imgs_366.jpg 4 367 | CALTECH/test/4/test_imgs_367.jpg 4 368 | CALTECH/test/4/test_imgs_368.jpg 4 369 | CALTECH/test/4/test_imgs_369.jpg 4 370 | CALTECH/test/4/test_imgs_370.jpg 4 371 | CALTECH/test/4/test_imgs_371.jpg 4 372 | CALTECH/test/4/test_imgs_372.jpg 4 373 | CALTECH/test/4/test_imgs_373.jpg 4 374 | CALTECH/test/4/test_imgs_374.jpg 4 375 | CALTECH/test/4/test_imgs_375.jpg 4 376 | CALTECH/test/4/test_imgs_376.jpg 4 377 | CALTECH/test/4/test_imgs_377.jpg 4 378 | CALTECH/test/4/test_imgs_378.jpg 4 379 | CALTECH/test/4/test_imgs_379.jpg 4 380 | CALTECH/test/4/test_imgs_380.jpg 4 381 | 
CALTECH/test/4/test_imgs_381.jpg 4 382 | CALTECH/test/4/test_imgs_382.jpg 4 383 | CALTECH/test/4/test_imgs_383.jpg 4 384 | CALTECH/test/4/test_imgs_384.jpg 4 385 | CALTECH/test/4/test_imgs_385.jpg 4 386 | CALTECH/test/4/test_imgs_386.jpg 4 387 | CALTECH/test/4/test_imgs_387.jpg 4 388 | CALTECH/test/4/test_imgs_388.jpg 4 389 | CALTECH/test/4/test_imgs_389.jpg 4 390 | CALTECH/test/4/test_imgs_390.jpg 4 391 | CALTECH/test/4/test_imgs_391.jpg 4 392 | CALTECH/test/4/test_imgs_392.jpg 4 393 | CALTECH/test/4/test_imgs_393.jpg 4 394 | CALTECH/test/4/test_imgs_394.jpg 4 395 | CALTECH/test/4/test_imgs_395.jpg 4 396 | CALTECH/test/4/test_imgs_396.jpg 4 397 | CALTECH/test/4/test_imgs_397.jpg 4 398 | CALTECH/test/4/test_imgs_398.jpg 4 399 | CALTECH/test/4/test_imgs_399.jpg 4 400 | CALTECH/test/4/test_imgs_400.jpg 4 401 | CALTECH/test/4/test_imgs_401.jpg 4 402 | CALTECH/test/4/test_imgs_402.jpg 4 403 | CALTECH/test/4/test_imgs_403.jpg 4 404 | CALTECH/test/4/test_imgs_404.jpg 4 405 | CALTECH/test/4/test_imgs_405.jpg 4 406 | CALTECH/test/4/test_imgs_406.jpg 4 407 | CALTECH/test/4/test_imgs_407.jpg 4 408 | CALTECH/test/4/test_imgs_408.jpg 4 409 | CALTECH/test/4/test_imgs_409.jpg 4 410 | CALTECH/test/4/test_imgs_410.jpg 4 411 | CALTECH/test/4/test_imgs_411.jpg 4 412 | CALTECH/test/4/test_imgs_412.jpg 4 413 | CALTECH/test/4/test_imgs_413.jpg 4 414 | CALTECH/test/4/test_imgs_414.jpg 4 415 | CALTECH/test/4/test_imgs_415.jpg 4 416 | CALTECH/test/4/test_imgs_416.jpg 4 417 | CALTECH/test/4/test_imgs_417.jpg 4 418 | CALTECH/test/4/test_imgs_418.jpg 4 419 | CALTECH/test/4/test_imgs_419.jpg 4 420 | CALTECH/test/4/test_imgs_420.jpg 4 421 | CALTECH/test/4/test_imgs_421.jpg 4 422 | CALTECH/test/4/test_imgs_422.jpg 4 423 | CALTECH/test/4/test_imgs_423.jpg 4 424 | CALTECH/test/4/test_imgs_424.jpg 4 425 | -------------------------------------------------------------------------------- /dataloader/text_lists/VLCS/classes.txt: 
import torch
import torch.nn as nn

from framework.registry import Models, Datasets, Backbones

__all__ = ['ERM']


@Models.register('ERM')
class ERM(nn.Module):
    """Empirical Risk Minimization baseline: a registry-selected backbone
    trained with a plain cross-entropy objective (the 'main' output entry)."""

    def __init__(self, num_classes, pretrained, args):
        super(ERM, self).__init__()
        self.args = args
        # Backbone is looked up by name from the registry (resnet18/50, ...).
        self.backbone = Backbones[args.backbone](num_classes, pretrained, args)
        self.num_classes = num_classes
        self.in_ch = self.backbone.in_ch  # feature dim, exposed for adaptor heads

    def load_pretrained(self, path=None, prefix=None, absolute=False):
        """Load a checkpoint, normalizing several historical key layouts.

        path     : checkpoint file, or a run root combined with args when
                   absolute=False ('<path>/<domain><time>/models/model_best.pt')
        prefix   : optional string prepended to every state-dict key
        absolute : skip the domain/time path derivation when True
        Returns the (possibly rewritten) state dict; loading is non-strict.
        """
        if path is None:
            path = self.args.path
        if not absolute:
            cur_domain = Datasets[self.args.dataset].Domains[self.args.exp_num[0]]
            path = str(path) + '/{}'.format(cur_domain) + str(self.args.time) + '/models/model_best.pt'
        state = torch.load(path, map_location='cpu')
        # Unwrap the three checkpoint container formats seen in practice.
        if 'model' in state.keys():
            state = state['model']
        elif 'state_dict' in state.keys():
            state = state['state_dict']
            state['backbone.fc.weight'] = state['classifier.weight']
        elif 'encoder_state_dict' in state.keys():
            backbone = {'backbone.'+k : v for k,v in state['encoder_state_dict'].items()}
            backbone['backbone.fc.weight'] = state['classifier_state_dict']['layers.weight']
            state = backbone

        keys = list(state.keys())
        # Rename DataParallel-style ('module.*') keys to 'backbone.*'.
        if 'module' in keys[0]:
            state = {k.replace('module', 'backbone'): state[k] for k in keys}

        # NOTE(review): `keys` is not refreshed after the rename above; if both
        # branches triggered for one checkpoint, state[k] would KeyError on the
        # stale keys — presumably the two layouts never co-occur; confirm.
        if 'resnet' in keys[len(keys)//2]:
            state = {k.replace('resnet', 'backbone'): state[k] for k in keys}

        if prefix is not None:
            new_state = {}
            for k, v in state.items():
                new_state.update({prefix + k: v})
            state = new_state
        ret = self.load_state_dict(state, strict=False)
        print('load from {}, state : {}'.format(path, ret))
        return state

    def forward(self, *args, **kwargs):
        # Delegate to step() so train/eval loops can call the model directly.
        return self.step(*args, **kwargs)

    def get_lr(self, fc_weight):
        """Collect (module, lr-multiplier) pairs for the optimizer builder:
        the backbone's own schedule plus every non-backbone child (children
        without a get_lr() of their own default to the fc multiplier)."""
        old_lr = self.backbone.get_lr(fc_weight)

        new_params = []
        for name, child in self.named_children():
            if 'backbone' not in name:
                if hasattr(child, 'get_lr'):
                    old_lr.extend(child.get_lr(fc_weight))
                else:
                    new_params.append([child, fc_weight])
        old_lr.extend(new_params)
        return old_lr

    def step(self, x, label, **kwargs):
        """Single forward pass; returns the loss/acc spec, raw logits and the
        globally pooled layer-4 features."""
        l4, final_logits = self.backbone(x)[-2:]
        return {
            'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': final_logits, 'target': label, },
            'logits': final_logits,
            'feat' : [l4.mean((2,3))]
        }
@Backbones.register('resnet50')
@Backbones.register('resnet18')
class Resnet(nn.Module):
    """ResNet-18/50 feature extractor with a bias-free linear classifier.

    forward() returns every stage's feature map so downstream adaptor heads
    can tap intermediate activations.
    """

    def __init__(self, num_classes, pretrained=False, args=None):
        super(Resnet, self).__init__()
        # args.backbone selects the depth ('resnet50' vs anything else -> 18).
        if '50' in args.backbone:
            print('Using resnet-50')
            resnet = resnet50(pretrained=pretrained)
            self.in_ch = 2048
        else:
            resnet = resnet18(pretrained=pretrained)
            self.in_ch = 512
        self.conv1 = resnet.conv1
        self.relu = resnet.relu
        self.bn1 = resnet.bn1
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool
        self.fc = nn.Linear(self.in_ch, num_classes, bias=False)
        if args.in_ch != 3:
            self.init_conv1(args.in_ch, pretrained)

    def init_conv1(self, in_ch, pretrained):
        """Replace the stem conv with an `in_ch`-channel version, replicating
        the pretrained RGB filters cyclically across the new input channels."""
        model_inplanes = 64
        conv1 = nn.Conv2d(in_ch, model_inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        old_weights = self.conv1.weight.data
        if pretrained:
            # Bug fix: copy into the NEW conv's weights. The original wrote
            # into self.conv1 (the old 3-channel conv), which was discarded on
            # the next line — and indexing channel i >= 3 of a 3-channel
            # weight tensor raises for in_ch > 3.
            for i in range(in_ch):
                conv1.weight.data[:, i, :, :] = old_weights[:, i % 3, :, :]
        self.conv1 = conv1

    def forward(self, x):
        """Return (stem, l1, l2, l3, l4, logits)."""
        net = self
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        l1 = net.layer1(x)
        l2 = net.layer2(l1)
        l3 = net.layer3(l2)
        l4 = net.layer4(l3)
        logits = self.fc(l4.mean((2, 3)))  # global average pool -> classifier
        return x, l1, l2, l3, l4, logits

    def get_lr(self, fc_weight):
        """Per-module lr multipliers: 1.0 for the trunk, fc_weight for the head."""
        lrs = [
            ([self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4], 1.0),
            (self.fc, fc_weight)
        ]
        return lrs
class Convolution(nn.Module):
    """3x3 same-padding conv -> BatchNorm (optionally AdaMixBN) -> ReLU."""

    def __init__(self, c_in, c_out, mixbn=False):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.bn = AdaMixBN(c_out) if mixbn else nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU(True)
        # Kept for checkpoint/state_dict compatibility; forward() applies the
        # layers individually rather than through this Sequential.
        self.seq = nn.Sequential(
            self.conv,
            self.bn,
            self.relu
        )

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
@TrainFuncs.register('meta')
@TrainFuncs.register('deepall')
def deepall_train(model, train_data, lr, epoch, args, engine, mode):
    """One epoch of plain ERM training over `train_data`.

    Returns (average-loss dict, average-accuracy dict) for the epoch.
    """
    loss_meter, acc_meter = AverageMeterDict(), AverageMeterDict()
    optimizers, device = engine.optimizers, engine.device

    model.train()
    n_batches = len(train_data)
    for batch_idx, batch in enumerate(train_data):
        batch = to(batch, device)
        outputs = model(**batch, epoch=epoch, step=n_batches * epoch + batch_idx, engine=engine, train_mode='train')
        loss = get_loss_and_acc(outputs, loss_meter, acc_meter)
        # A None loss means nothing to optimize this step; the optimizer is
        # still stepped/zeroed, matching the original control flow.
        if loss is not None:
            loss.backward()

        optimizers.step()
        optimizers.zero_grad()
    return loss_meter.get_average_dicts(), acc_meter.get_average_dicts()
@EvalFuncs.register('deepall')
def deepall_eval(model, eval_data, lr, epoch, args, engine, mode):
    """Evaluate `model` on `eval_data`.

    Returns (main accuracy, (loss_dict, acc_dict)); 0 when no 'main' metric
    was produced.  If the model carries an SWA shadow copy, that copy is
    evaluated instead.
    """
    running_loss, running_corrects = AverageMeterDict(), AverageMeterDict()
    device = engine.device

    if hasattr(model, 'swa'):
        print('Eval with swa : ', end='')
        model = model.swa.module

    # args.TN keeps the net in train() mode during evaluation (BN then uses
    # the test batch's statistics) — presumably test-time normalization;
    # confirm against the TTA scripts.
    model.train() if args.TN else model.eval()

    with torch.no_grad():
        for i, data_list in enumerate(eval_data):
            data_list = to(data_list, device)
            outputs = model(**data_list, epoch=epoch, step=len(eval_data) * epoch + i, engine=engine, train_mode='test')
            get_loss_and_acc(outputs, running_loss, running_corrects)
    loss, acc = running_loss.get_average_dicts(), running_corrects.get_average_dicts()
    if 'main' in acc:
        return acc['main'], (loss, acc)
    else:
        return 0, (loss, acc)


@TrainFuncs.register('mldg')
def mldg(meta_model, train_data, meta_lr, epoch, args, engine, mode):
    """MLDG meta-training: MAML-style inner steps on the meta-train split,
    then an outer update on meta-train + meta-val loss."""
    assert args.loader == 'meta'
    device, optimizers = engine.device, engine.optimizers
    running_loss, running_corrects = AverageMeterDict(), AverageMeterDict()
    print('Meta lr : {}, loops : {}'.format(meta_lr, len(train_data)))

    meta_model.train()
    for data_list in train_data:
        # Leave-one-domain-out split of the meta batch (loo=True).
        meta_train_data, meta_test_data = split_image_and_label(to(data_list, device), size=args.batch_size, loo=True)

        with meta_learning_MAML(meta_model) as fast_model:
            for j in range(args.meta_step):
                meta_train_loss = get_loss_and_acc(fast_model.step(**meta_train_data), running_loss, running_corrects)
                fast_model.meta_step(meta_train_loss, meta_lr, use_second_order=args.meta_second_order)

            # Held-out-domain loss after the inner adaptation — assumed to sit
            # inside the `with` but outside the inner loop; TODO confirm
            # against the original indentation.
            meta_val_loss = get_loss_and_acc(fast_model.step(**meta_test_data), running_loss, running_corrects)

        zero_and_update(optimizers, (meta_train_loss+meta_val_loss))
    return running_loss.get_average_dicts(), running_corrects.get_average_dicts()
def extract_parameters(model, lr):
    """Normalize `model` into a list of optimizer param-group dicts.

    model : a single nn.Parameter or nn.Module, or a list/tuple of either
            kind (the kind is inferred from the first element).
    lr    : learning rate assigned to every produced group.
    Returns a list of {'params': ..., 'lr': lr} dicts.
    Raises Exception for a list whose elements are neither Parameters nor
    Modules.
    """
    if isinstance(model, (list, tuple)):
        if isinstance(model[0], nn.Parameter):
            params = [{'params': param, 'lr': lr} for param in model]
        elif isinstance(model[0], nn.Module):
            params = [{'params': model_.parameters(), 'lr': lr} for model_ in model]
        else:
            # Fixed typo: 'Unkown' -> 'Unknown' (consistent with the
            # "Unknown optimizer" message in get_optimizers).
            raise Exception("Unknown models {}".format(type(model)))
    else:
        if isinstance(model, nn.Parameter):
            params = [{'params': model, 'lr': lr}]
        else:
            params = [{'params': model.parameters(), 'lr': lr}]
    return params
def get_scheduler(args, opt):
    """Build the LR scheduler named by args.scheduler for optimizer `opt`.

    Returns (num_epoch, scheduler).  Step decay fires at 80% of training
    (fixed at epoch 20 for digits_dg).
    """
    num_epoch = args.num_epoch
    if args.dataset == 'digits_dg':
        lr_step = 20
    else:
        lr_step = args.num_epoch * .8
    if args.scheduler == 'inv':
        # NOTE(review): no 'inv' scheduler registration is visible in
        # framework/__init__ — confirm it is registered elsewhere.
        schedulers = Schedulers[args.scheduler](optimizer=opt, alpha=10, beta=0.75, total_epoch=num_epoch)
    elif args.scheduler == 'step':
        schedulers = Schedulers[args.scheduler](optimizer=opt, step_size=lr_step, gamma=args.lr_decay_gamma)
    elif args.scheduler == 'cosine':
        # NOTE(review): the registry registers CosineAnnealingLR under 'cos',
        # not 'cosine', so this lookup would KeyError — verify an alias exists.
        schedulers = Schedulers[args.scheduler](optimizer=opt, T_max=10, eta_min=1e-5)
    else:
        raise ValueError('Name of scheduler unknown %s' % args.scheduler)
    return num_epoch, schedulers
    def get_loaders(self):
        """Build ((source_train, source_val, target_test), target_domains).

        Augmentation functions are taken from the model when it provides
        get_aug_funcs() (e.g. jigsaw/rotation TTA models).
        """
        # NOTE(review): Datasets[...] was already instantiated once in
        # __init__; rebuilding here duplicates that work — confirm the
        # constructor is side-effect free.
        data_config = Datasets[self.args.dataset](self.args)
        return data_config.get_loaders(self.model.get_aug_funcs() if hasattr(self.model, 'get_aug_funcs') else None), data_config.target_domains

    def set_seed(self, seed=None):
        """Seed python/numpy/torch RNGs and force deterministic cuDNN.

        A None seed leaves all RNG state untouched.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # Reproducibility over speed: deterministic kernels, no autotune.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    def save_model(self, name='model_best.pt'):
        """Checkpoint model and optimizer state to <save_path>/models/<name>."""
        # NOTE(review): when args.opt_split is set, get_optimizers returns a
        # *list* of optimizers and .state_dict() here would raise
        # AttributeError — confirm opt_split runs never checkpoint.
        save_dict = {
            'model': self.model.state_dict(),
            'opt': self.optimizers.state_dict()
        }
        torch.save(save_dict, os.path.join(self.path, 'models', name))

    def load_model(self, path=None):
        """Restore the best checkpoint (and optimizer state when present).

        path=None loads from this run's own models/ directory; otherwise the
        matching target-domain checkpoint under `path` is used.
        """
        if path is None:
            path = os.path.join(self.path, 'models', "model_best.pt")
        else:
            path = os.path.join(path, '_'.join(self.target_domain) + str(self.time), 'models', "model_best.pt")

        if os.path.exists(path):
            m = self.model.load_pretrained(path, absolute=True)
            if 'opt' in m:
                try:
                    ret1 = self.optimizers.load_state_dict(m['opt'])
                except Exception as e:
                    print(e)
                # NOTE(review): if load_state_dict raised above, ret1 is
                # unbound and this print raises NameError — confirm intended.
                print('Load optimizer from {}'.format(path), ret1)
        else:
            print('Model in {}, Not found !!!!'.format(path))
        return self.model

    def test(self, best_epoch=0, best_acc=0, loader=None):
        """Reload the best checkpoint and evaluate it on `loader`.

        Returns (test accuracy, accuracy dict) and logs the best-model line.
        """
        lr = self.optimizers[0].param_groups[0]['lr'] if isinstance(self.optimizers, (list, tuple)) else self.optimizers.param_groups[0]['lr']
        self.load_model()
        self.model = self.model.to(self.device)
        test_acc, (loss_dict, acc_dict) = EvalFuncs[self.args.eval](self.model, loader, lr, best_epoch, self.args, self, mode='test')
        self.logger.log_best(best_epoch, best_acc, loss_dict, acc_dict)
        return test_acc, acc_dict
def get_mean_dict(dict_list):
    """Key-wise mean over a list of dicts (keys taken from the first dict)."""
    return {key: np.mean([record[key] for record in dict_list]) for key in dict_list[0].keys()}
    def run(self):
        """Dispatch experiments according to args.exp_num.

        exp_num[0] == -2 : repeat the full domain sweep `times` times
        exp_num[0] == -1 : one full sweep over all target domains
        otherwise        : run exactly the listed domain indices
        """
        print()
        try:
            if self.args.do_train:
                self.backup(self.args)
            if self.exp_num[0] == -2:
                print('Run All Exp Many Times !!!')
                self.run_all_exp_many_times(self.times)
            elif self.exp_num[0] == -1:
                print('Run All Exp !!!')
                self.run_all_exp(self.start_time)
            else:
                for num in self.exp_num:
                    print('Run One Exp !!!')
                    assert num >= 0
                    self.run_one_exp(exp_idx=num, time=self.args.start_time)
        except Exception as e:
            # Deliberate top-level broad catch: any failure is printed and
            # persisted to error.txt instead of killing the sweep silently.
            import traceback
            traceback.print_exc()
            with open(self.save_path / 'error.txt', 'w') as f:
                traceback.print_exc(None, f)
    def run_all_exp(self, time):
        """Run one experiment per target domain and append results to all_exp.txt.

        Accuracies are scaled to percentages here.  Returns
        (per-domain accuracies with the mean appended,
         (metric keys, per-domain metric rows including the mean row)).
        """
        test_acc_list, test_acc_dict_list = [], []
        with open(str(self.save_path / 'all_exp.txt'), 'a') as f:
            print('------------- New Exp -------------')
            f.write('------------- New Exp -------------\n')
            for i, d in enumerate(self.domains):
                acc, acc_dict = self.run_one_exp(exp_idx=i, time=time)
                print(f'{d} : {acc:.2f} [' + ', '.join([f'{k} : {v:.4f}' for k, v in acc_dict.items()]) + ']\n')
                f.write(f'{d} : {acc:.2f} [' + ', '.join([f'{k} : {v:.4f}' for k, v in acc_dict.items()]) + ']\n')
                test_acc_list.append(acc * 100)
                test_acc_dict_list.append({k:v * 100 for k,v in acc_dict.items()})
            mean_acc = np.mean(test_acc_list)
            mean_acc_dict = get_mean_dict(test_acc_dict_list)
            print(f'Mean {mean_acc:.2f} [' + ', '.join([f'{k} : {v:.4f}' for k, v in mean_acc_dict.items()]) + ']\n')
            f.write(f'{mean_acc:.2f} [' + ', '.join([f'{k} : {v:.4f}' for k, v in mean_acc_dict.items()]) + ']\n')
        # The mean is appended as a final pseudo-domain row.
        test_acc_list.append(mean_acc)
        test_acc_dict_list.append(mean_acc_dict)
        keys = test_acc_dict_list[0].keys()
        return test_acc_list, (keys, [[test_acc_dict_list[i][k] for k in keys] for i in range(len(test_acc_list))])

    def run_one_exp(self, exp_idx=0, time=0):
        """Train/evaluate a single target-domain experiment in its own
        '<save_path>/<domain><time>' directory; returns (acc, acc_dict)."""
        args = deepcopy(self.args)
        args.exp_num = [exp_idx]
        args.save_path = str(Path(args.save_path) / '{}{}'.format(self.domains[exp_idx], time))
        engine = GenericEngine(args, time)
        test_acc, test_acc_dict = engine.train()
        return test_acc, test_acc_dict
# -------- framework/log.py --------
import time
from pathlib import Path
import numpy as np
import re
from tensorboardX import SummaryWriter
import os

from framework.registry import Datasets


class MyLogger:
    """Writes train/eval/test metrics to per-split text files and, optionally, TensorBoard."""

    def __init__(self, args, root, target_domain, enable_tensorboard=False,
                 source_train_path='source_train.txt', source_eval_path='source_eval.txt', target_test_path='target_test.txt'):
        self.args = args
        root = Path(root)

        src_train_path = root / source_train_path
        src_eval_path = root / source_eval_path
        target_test_path = root / target_test_path
        # 'EMA test' deliberately shares the target-test file
        self.paths = {
            'train': src_train_path,
            'eval': src_eval_path,
            'test': target_test_path,
            'EMA test': target_test_path
        }
        self.target_domain = target_domain
        self.enable_tensorboard = enable_tensorboard
        if self.enable_tensorboard:
            self.writer = SummaryWriter(str(root / 'tensorboard'), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S'))

    def log_seed(self, seed):
        """Append the RNG seed to the target-test log."""
        with open(self.paths['test'], 'a') as f:
            f.write('Current Seed {}\n'.format(seed))

    def log_best(self, epoch, best_val_acc, loss, acc):
        """Record the test metrics of the best-validation checkpoint."""
        with open(self.paths['test'], 'a') as f:
            f.write('Best test (domain : {}, bs : {}, val : {:.4f}): '.format(self.target_domain, self.args.batch_size, best_val_acc))
            print('Best test (domain : {}, bs : {}, val : {:.4f}): '.format(self.target_domain, self.args.batch_size, best_val_acc), end=' ')
            self.log_file('test', epoch, loss, acc)

    def log_str(self, mode, log):
        """Append a raw line to the log file associated with `mode`."""
        file = self.paths[mode]
        with open(file, 'a') as f:
            f.write(log + '\n')

    def log(self, mode, epoch, loss_dict, acc_dict):
        """Log one epoch of metrics to file and, if enabled, to TensorBoard."""
        self.log_file(mode, epoch, loss_dict, acc_dict)
        if self.enable_tensorboard:
            self.tf_log_file(mode, epoch, loss_dict, acc_dict)

    def log_file(self, mode, epoch, loss_dict, acc_dict):
        """Format metrics to stdout and to every log file whose key occurs in `mode`."""
        loss_str = ''.join([' {}: {:.4f}'.format(k, v) for k, v in loss_dict.items()])
        acc_str = ''.join([' {}: {:.4f}'.format(k, v) if isinstance(v, (float, int)) else ' {}: {}'.format(k, v) for k, v in acc_dict.items()])
        t = time.strftime('[%Y-%m-%d %H:%M:%S]')
        log = '{} {:5s}: Epoch: {}, Loss : {} \t Acc : {}'.format(t, mode, epoch, loss_str, acc_str)
        print(log)
        for k in self.paths.keys():
            if k in mode:
                self.log_str(k, log)

    def tf_log_file(self, mode, epoch, loss_dict, acc_dict):
        """Send metrics to TensorBoard under loss/... and acc/... tags."""
        assert self.enable_tensorboard, "Tensorboard not enabled!!!!"
        for k, v in loss_dict.items():
            self.writer.add_scalar('loss/{}/{}/{}'.format(self.target_domain, mode, k), v, epoch)
        for k, v in acc_dict.items():
            self.writer.add_scalar('acc/{}/{}/{}'.format(self.target_domain, mode, k), v, epoch)
        self.writer.flush()

    def log_output(self, log_dicts, iter, parent=''):
        """Write scalar (or histogram, for keys containing 'hist') outputs to TensorBoard."""
        for k, v in log_dicts.items():
            if 'hist' in k:
                self.writer.add_histogram(k, v, iter)
            else:
                self.writer.add_scalar(parent + '{}/{}'.format(self.target_domain, k), v, iter)
        self.writer.flush()


def generate_many_exp(path):
    """Aggregate `all_exp.txt` into per-domain mean+-std lines appended to `many_exp.txt`."""
    p = Path(path)
    acc_dict = {}
    with open(str(p / 'all_exp.txt'), 'r') as f:
        for line in f:
            res = re.match(r'(.*) : (.*)', line)
            if res is None:
                continue
            domain, acc = res.groups()
            acc = float(acc)
            if domain in acc_dict:
                acc_dict[domain].append(acc)
            else:
                acc_dict[domain] = [acc]

    with open(p / 'many_exp.txt', 'a') as f:
        for d, acc_list in acc_dict.items():
            m = np.mean(acc_list)
            s = np.std(acc_list)
            print('{} : {:.2f}+-{:.2f} \n'.format(d, m, s))
            f.write('{} : {:.2f}+-{:.2f} \n'.format(d, m, s))
        f.write('\n')


def get_acc_list(folder, line_idx, times, domains):
    """Parse the accuracy at `line_idx` of each run/domain target_test.txt.

    Returns a times x domains nested list of floats.
    """
    folder = Path(folder)
    accs_list = []
    for i in range(times):
        accs = []
        for d in domains:
            file = folder / (d + str(i)) / 'target_test.txt'
            with open(str(file), 'r') as f:
                last_line = f.readlines()[line_idx]
                last_acc = float(re.match(r'.* Acc :.*main: (.*)', last_line).groups()[0][:6])
                accs.append(last_acc)
        accs_list.append(accs)
    return accs_list


def generate_exp(folder, domains, times=5, type='all', with_ema=False):
    """Regenerate the all_exp / many_exp summary files from per-run logs.

    `type` selects which log line is parsed per run: 'last', 'all' or 'ema'.
    Raises ValueError for any other value.
    """
    if type == 'last':
        many_exp_file = 'last_many_exp.txt'
        all_exp_file = 'last_all_exp.txt'
        if with_ema:
            line_idx = -3
        else:
            line_idx = -2
    elif type == 'all':
        many_exp_file = 'many_exp.txt'
        all_exp_file = 'all_exp.txt'
        line_idx = -1
    elif type == 'ema':
        many_exp_file = 'ema_many_exp.txt'
        all_exp_file = 'ema_all_exp.txt'
        line_idx = -2
    else:
        # previously fell through with line_idx unbound -> NameError
        raise ValueError('Unknown type: {}'.format(type))

    folder = Path(folder)
    accs_list = get_acc_list(folder, line_idx, times, domains)
    accs_list = np.array(accs_list) * 100
    mean = accs_list.mean(0)
    std = accs_list.std(0)
    total_mean = accs_list.mean()
    total_std = accs_list.mean(1).std()

    with open(str(folder / all_exp_file), 'w') as f:
        for acc in accs_list:
            f.write('------------- New Exp -------------\n')
            for d, ac in zip(domains, acc):
                line = '{} : {:.2f}\n'.format(d, ac)
                f.write(line)
            line = '{} : {:.2f}\n'.format('Mean acc', acc.mean())
            f.write(line)
            f.write('\n')

    with open(str(folder / many_exp_file), 'w') as f:
        for d, m, s in zip(domains, mean, std):
            line = '{} : {:.2f}+-{:.2f}\n'.format(d, m, s)
            print(line)
            f.write(line)
        line = '{} : {:.2f}+-{:.2f}\n'.format('Mean', total_mean, total_std)
        print(line)
        f.write(line)
    print('Finished {}'.format(folder))


def delete(path):
    """Remove every saved model file under <run>/models for each run directory in `path`."""
    path = Path(path)
    for folder in path.iterdir():
        if folder.is_dir() and 'back' not in folder.name:
            for model in (folder / 'models').iterdir():
                os.remove(str(model))
                print(model)


def plot(keys, acc_list, title):
    """Plot several accuracy curves right-aligned to the longest one."""
    # local import: the module top never imported pyplot, so plot() raised NameError
    import matplotlib.pyplot as plt
    lengths = [len(acc) for acc in acc_list]
    max_len = np.max(lengths)
    start = [max_len - l for l in lengths]
    for i in range(len(keys)):
        plt.plot(range(start[i], lengths[i] + start[i]), acc_list[i], label=keys[i])
    plt.legend()
    plt.title(title)
    plt.show()


def get_last_mean_acc(domain_acc_list, keys):
    """Print the mean (over domains) of each task's final accuracy."""
    # domain -> task -> acc_list
    res = []
    for d in domain_acc_list:
        res.append([acc[-1] for acc in d])
    res = np.array(res)  # domain x task
    mean_acc = np.mean(res, 0)
    print(dict(zip(keys, mean_acc)))


def read_file(file, first_keys=None):
    """Extract one accuracy sequence per regex in `first_keys` from a log file."""
    if first_keys is None:
        first_keys = [r'^\[.*\] .*main: (.*)']  # raw string: '\[' is an invalid escape otherwise
    with open(file, 'r') as f:
        lines = f.readlines()[:-1]
    all_acc_list = []
    for k in first_keys:
        acc_list = []
        for line in lines:
            acc = re.match(k, line)
            if acc is not None:
                acc_list.append(float(acc.groups()[0]))
        all_acc_list.append(acc_list)
    return all_acc_list


def read_acc(folder):
    """Report last-epoch mean accuracies (test + EMA heads) over 5 PACS runs."""
    folder = Path(folder)
    domains = Datasets['PACS'].Domains
    times = 5
    for t in range(times):
        domain_acc_list = []
        for d in domains:
            # NOTE(review): these keys are used as regexes with no capture group —
            # verify they actually match the current log-line format
            first_keys = ['test'] + ['EMA{}'.format(i) for i in range(5)]
            p = folder / (d + str(t)) / 'target_test.txt'
            all_acc_list = read_file(p, first_keys)
            domain_acc_list.append(all_acc_list)
        get_last_mean_acc(domain_acc_list, first_keys)


if __name__ == '__main__':
    domains = Datasets['MDN'].Domains
    # for f in
Path('/data/zj/PycharmProjects/TTA/NEW').iterdir(): 231 | # if 'test' not in str(f.absolute()): 232 | # generate_exp(f, domains, times=5, type='last', ) 233 | # for folder in Path('/data/zj/PycharmProjects/DomainAdaptation/script/FOMAML_New').iterdir(): 234 | # delete(folder) 235 | generate_exp('/data/zj/PycharmProjects/TTA/AdaBN/meta_norm_MDN', domains, times=3, type='all') 236 | -------------------------------------------------------------------------------- /framework/meta_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import errno 3 | import functools 4 | import os 5 | import signal 6 | import types 7 | from contextlib import contextmanager 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from framework.loss_and_acc import get_loss_and_acc 15 | from utils.tensor_utils import to, zero_and_update 16 | 17 | 18 | def put_theta(model, theta): 19 | if theta is None: 20 | return model 21 | 22 | def k_param_fn(tmp_model, name=None): 23 | if len(theta) == 0: 24 | return 25 | 26 | if len(tmp_model._modules) != 0: 27 | for (k, v) in tmp_model._modules.items(): 28 | if name == '': 29 | k_param_fn(v, name=str(k)) 30 | else: 31 | k_param_fn(v, name=str(name + '.' + k)) 32 | 33 | # WARN : running_mean, 和 running_var 不是 parameter,所以在 new 中不会被更新 34 | for (k, v) in tmp_model._parameters.items(): 35 | if isinstance(v, torch.Tensor) and str(name + '.' + k) in theta.keys(): 36 | tmp_model._parameters[k] = theta[str(name + '.' + k)] 37 | # else: 38 | # print(name+'.'+k) 39 | # theta.pop(str(name + '.' + k)) 40 | 41 | for (k, v) in tmp_model._buffers.items(): 42 | if isinstance(v, torch.Tensor) and str(name + '.' + k) in theta.keys(): 43 | tmp_model._buffers[k] = theta[str(name + '.' + k)] 44 | # else: 45 | # print(k) 46 | # theta.pop(str(name + '.' 
+ k)) 47 | 48 | k_param_fn(model, name='') 49 | return model 50 | 51 | 52 | def get_parameters(model): 53 | # note : you can direct manipulate these data reference which is related to the original models 54 | parameters = dict(model.named_parameters()) 55 | states = dict(model.named_buffers()) 56 | return parameters, states 57 | 58 | 59 | def put_parameters(model, param, state): 60 | model = put_theta(model, param) 61 | model = put_theta(model, state) 62 | return model 63 | 64 | 65 | def update_parameters(loss, names_weights_dict, lr, use_second_order, retain_graph=True, grads=None, ignore_keys=None): 66 | def contains(key, target_keys): 67 | if isinstance(target_keys, (tuple, list)): 68 | for k in target_keys: 69 | if k in key: 70 | return True 71 | else: 72 | return key in target_keys 73 | 74 | new_dict = {} 75 | for name, p in names_weights_dict.items(): 76 | if p.requires_grad: 77 | new_dict[name] = p 78 | # else: 79 | # print(name) 80 | names_weights_dict = new_dict 81 | 82 | if grads is None: 83 | grads = torch.autograd.grad(loss, names_weights_dict.values(), create_graph=use_second_order, retain_graph=retain_graph, allow_unused=True) 84 | names_grads_wrt_params_dict = dict(zip(names_weights_dict.keys(), grads)) 85 | updated_names_weights_dict = dict() 86 | 87 | for key in names_grads_wrt_params_dict.keys(): 88 | if names_grads_wrt_params_dict[key] is None: 89 | continue # keep the original state unchanged 90 | 91 | if ignore_keys is not None and contains(key, ignore_keys): 92 | # print(f'ignore {key}' ) 93 | continue 94 | 95 | updated_names_weights_dict[key] = names_weights_dict[key] - lr * names_grads_wrt_params_dict[key] 96 | return updated_names_weights_dict 97 | 98 | 99 | def cat_meta_data(data_list): 100 | new_data = {} 101 | for k in data_list[0].keys(): 102 | l = [] 103 | for data in data_list: 104 | l.append(data[k]) 105 | new_data[k] = torch.cat(l, 0) 106 | return new_data 107 | 108 | 109 | def timeout(seconds=10, 
error_message=os.strerror(errno.ETIME)): 110 | def decorator(func): 111 | def _handle_timeout(signum, frame): 112 | raise TimeoutError(error_message) 113 | 114 | @functools.wraps(func) 115 | def wrapper(*args, **kwargs): 116 | signal.signal(signal.SIGALRM, _handle_timeout) 117 | signal.alarm(seconds) 118 | try: 119 | result = func(*args, **kwargs) 120 | finally: 121 | signal.alarm(0) 122 | return result 123 | 124 | return wrapper 125 | 126 | return decorator 127 | 128 | 129 | # @timeout(3) 130 | def get_image_and_label(loaders, idx_list, device): 131 | if not isinstance(idx_list, (list, tuple)): 132 | idx_list = [idx_list] 133 | 134 | data_lists = [] 135 | for i in idx_list: 136 | data = loaders[i].next() 137 | data = to(data, device) # , non_blocking=True) 138 | # data = loaders[i].next() 139 | data_lists.append(data) 140 | return cat_meta_data(data_lists) 141 | 142 | 143 | def split_image_and_label(data, size, loo=False): 144 | n_domains = list(data.values())[0].shape[0] // size 145 | idx_sequence = list(np.random.permutation(n_domains)) 146 | if loo: 147 | n_domains = 2 148 | res = [{} for _ in range(n_domains)] 149 | 150 | for k, v in data.items(): 151 | split_data = torch.split(v, size) 152 | if loo: # meta_train, meta_test 153 | res[0][k] = torch.cat([split_data[_] for _ in idx_sequence[:len(split_data) - 1]]) 154 | res[1][k] = split_data[idx_sequence[-1]] 155 | else: 156 | for i, d in enumerate(split_data): 157 | res[i][k] = d 158 | return res 159 | 160 | 161 | def new_split_image_and_label(data, size, loo=False): 162 | n_domains = list(data.values())[0].shape[0] // size 163 | if loo: 164 | n_domains = 2 165 | res = [{} for _ in range(n_domains)] 166 | 167 | for k, v in data.items(): 168 | split_data = torch.split(v, size) 169 | if loo: # meta_train, meta_test 170 | res[0][k] = torch.cat(split_data[:2]) 171 | res[1][k] = torch.cat(split_data[2:]) 172 | else: 173 | for i, d in enumerate(split_data): 174 | res[i][k] = d 175 | return res 176 | 177 | 178 | def 
def init_network(meta_model, meta_lr, previous_opt=None, momentum=0.9, Adam=False, beta1=0.9, beta2=0.999, device=None):
    """Clone `meta_model` into a trainable fast copy with a fresh SGD/Adam optimizer."""
    fast_model = copy.deepcopy(meta_model).train()
    if device is not None:
        fast_model.to(device)
    if Adam:
        fast_opts = torch.optim.Adam(fast_model.parameters(), lr=meta_lr, betas=(beta1, beta2), weight_decay=5e-4)
    else:
        fast_opts = torch.optim.SGD(fast_model.parameters(), lr=meta_lr, weight_decay=5e-4, momentum=momentum)

    # carry over optimizer state (momentum buffers etc.) from a previous inner loop
    if previous_opt is not None:
        fast_opts.load_state_dict(previous_opt.state_dict())
    return fast_model, fast_opts


def load_state(new_opts, old_opts):
    """Copy optimizer state from each of `new_opts` into the matching `old_opts`."""
    for old, new in zip(old_opts, new_opts):
        old.load_state_dict(new.state_dict())


def update_meta_model(meta_model, fast_param_list, optimizers, meta_lr=1):
    """Reptile-style outer update: step meta params toward the averaged fast params."""
    meta_params, meta_states = get_parameters(meta_model)

    optimizers.zero_grad()

    for k in meta_params.keys():
        old_v = meta_params[k]
        avg = 0
        for m in fast_param_list:
            avg += m[0][k]
        avg = avg / len(fast_param_list)
        # (old - new) / lr acts as a synthetic gradient for the outer optimizer
        meta_params[k].grad = ((old_v - avg) / meta_lr).data
    optimizers.step()


def avg_meta_model(meta_model, fast_param_list):
    """Overwrite meta params in place with the average of the fast param dicts."""
    meta_params, meta_states = get_parameters(meta_model)

    for k in meta_params.keys():
        avg = 0
        for m in fast_param_list:
            avg += m[k]
        avg = avg / len(fast_param_list)
        meta_params[k].data = avg.data


def add_grad(meta_model, fast_model, factor):
    """Accumulate (meta - fast) * factor into meta grads; returns the raw directions."""
    meta_params, meta_states = get_parameters(meta_model)
    fast_params, fast_states = get_parameters(fast_model)
    grads = []
    for k in meta_params.keys():
        new_v, old_v = fast_params[k], meta_params[k]
        # .data detaches: keeping the autograd graph here would leak memory
        delta = ((old_v - new_v) * factor).data
        if meta_params[k].grad is None:
            meta_params[k].grad = delta
        else:
            meta_params[k].grad += delta
        grads.append(old_v - new_v)
    return grads


def compare_two_dicts(d1, d2):
    """Return True when every tensor in d1/d2 agrees within 1e-7; print offenders."""
    same = True
    for k in d1.keys():
        if not ((d1[k] - d2[k]).abs().max() < 1e-7):
            print(k, (d1[k] - d2[k]).abs().max())
            same = False
    return same


def inner_loop(meta_model, meta_train_data, meta_test_data, steps, meta_lr, opt_states, running_loss, running_corrects,
               meta_test=True, train_aug=False, test_aug=False):
    """One inner loop: adapt a copy of meta_model on meta-train, optionally meta-test it."""
    fast_model = copy.deepcopy(meta_model).train()
    # only put BN buffers, for fair comparison
    put_parameters(fast_model, None, get_parameters(meta_model)[1])
    fast_opts = torch.optim.SGD(fast_model.parameters(), lr=meta_lr, weight_decay=5e-4, momentum=0.9)
    if opt_states is not None:
        fast_opts.load_state_dict(opt_states)

    # meta train
    if not train_aug:
        meta_train_data = copy.deepcopy(meta_train_data)
        meta_train_data['aug_x'] = None
    for _ in range(steps):
        out = fast_model(**meta_train_data, meta_train=True, do_aug=train_aug)
        train_loss = get_loss_and_acc(out, None, None)
        zero_and_update([fast_opts], train_loss)

    if not meta_test:
        return fast_model, fast_opts

    # meta test
    if not test_aug:
        meta_test_data = copy.deepcopy(meta_test_data)
        meta_test_data['aug_x'] = None
    out = fast_model(**meta_test_data, meta_train=False, do_aug=test_aug)
    val_loss = get_loss_and_acc(out, running_loss, running_corrects)
    zero_and_update(fast_opts, val_loss)
    return fast_model, fast_opts.state_dict(), out


def correlation(grad1, grad2, cos=False):
    """Mean pairwise similarity (dot product, or cosine with cos=True) of two gradient lists."""
    sims = []
    for g1, g2 in zip(grad1, grad2):
        if cos:
            sims.append(F.cosine_similarity(g1.view(-1), g2.view(-1), 0))
        else:
            sims.append((g1 * g2).sum())
    return torch.stack(sims).mean()


def regularize_params(meta_model, params, opts, weight):
    """Penalize disagreement between per-model update directions (1 - mean pairwise cosine)."""
    def get_direction(param1, param2):
        return [p2 - p1 for p1, p2 in zip(param1, param2)]

    def get_mean(dirs):
        mean_dir = []
        for ls in zip(*dirs):
            total = 0
            for m in ls:
                total += m
            mean_dir.append(total / len(ls))
        return mean_dir

    meta_param = get_parameters(meta_model)[0]

    # one update direction per fast model
    dirs = [get_direction(meta_param.values(), param.values()) for param in params]

    # NOTE(review): mean_dir is computed but never used below — confirm intent
    mean_dir = get_mean(dirs)

    # distance between every pair of directions
    dists = []
    for i in range(len(dirs)):
        for j in range(len(dirs)):
            if j > i:
                dists.append(correlation(dirs[i], dirs[j], cos=True))
    dists = 1 - torch.stack(dists).mean()
    zero_and_update(opts, dists * weight)  # w/o, w/
    return dists


def mixup_parameters(params, num=2, alpha=1):
    """Mix `num` randomly chosen parameter dicts with Dirichlet (alpha>0) or uniform weights."""
    assert num <= len(params)
    selected_list = np.random.permutation(len(params))[:num]
    if alpha > 0:
        ws = np.float32(np.random.dirichlet([alpha] * num))  # random mixup weights
    else:
        ws = [1 / num] * num  # plain average
    picked = [params[i] for i in selected_list]
    new_param = {}
    for name in picked[0].keys():
        acc = 0
        for w, p in zip(ws, picked):
            acc += w * p[name]
        new_param[name] = acc
    return new_param, selected_list


def average_models(models):
    """Return a new model whose parameters are the uniform average of `models`."""
    params = [get_parameters(m)[0] for m in models]
    new_param, _ = mixup_parameters(params, num=len(params), alpha=0)
    new_model = copy.deepcopy(models[0])
    return put_parameters(new_model, new_param, None)


def get_consistency_loss(logits_clean, logits_aug=None, T=4, weight=2):
    """Jensen-Shannon-style consistency loss between clean and augmented logits.

    When logits_aug is None the batch is assumed to hold [aug | clean] halves.
    """
    if logits_aug is None:
        length = len(logits_clean)
        logits_clean, logits_aug = logits_clean[length // 2:], logits_clean[:length // 2]
    logits_clean = logits_clean.detach()
    p_clean = (logits_clean / T).softmax(1)
    p_aug = (logits_aug / T).softmax(1)
    p_mixture = ((p_aug + p_clean) / 2).clamp(min=1e-7, max=1).log()
    return (F.kl_div(p_mixture, p_clean, reduction='batchmean')
            + F.kl_div(p_mixture, p_aug, reduction='batchmean')) * weight


class AveragedModel(nn.Module):
    """Maintains a (weighted) running average of a tracked model's parameters (SWA/EMA style)."""

    def __init__(self, start_epoch=0, device=None, lam=None, avg_fn=None):
        super(AveragedModel, self).__init__()
        self.device, self.start_epoch = device, start_epoch
        self.module = None
        self.lam = lam  # fixed mixing weight; None -> running average
        self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, lamd):
                return lamd * averaged_model_parameter + (1 - lamd) * model_parameter
        self.avg_fn = avg_fn

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def step(self, *args, **kwargs):
        return self.module.step(*args, **kwargs)

    def init_model(self, model, epoch):
        # lazily clone the tracked model the first time it is seen
        if self.module is None:
            self.module = copy.deepcopy(model)
            if self.device is not None:
                self.module = self.module.to(self.device)

    def update_parameters(self, model, epoch):
        """Fold `model`'s current parameters into the running average."""
        if epoch < self.start_epoch:
            return

        if self.module is None:
            self.module = copy.deepcopy(model)
            if self.device is not None:
                self.module = self.module.to(self.device)
            return

        if self.lam is None:
            lam = self.n_averaged.to(self.device) / (self.n_averaged.to(self.device) + 1)
        else:
            lam = self.lam
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, lam))
        self.n_averaged += 1

    def update_bn(self, loader, epoch, iters=None, model=None, meta=False):
        """Re-estimate BN running stats by forwarding `iters` (or all) batches."""
        model = self.module if model is None else model
        if epoch < self.start_epoch:
            return
        momenta = {}
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)
                momenta[module] = module.momentum

        if not momenta:
            return

        was_training = model.training
        model.train()
        for module in momenta.keys():
            module.momentum = None  # momentum=None -> cumulative moving average
            module.num_batches_tracked *= 0

        inner_loops = len(loader) if iters is None else iters
        with torch.no_grad():
            if meta:
                for _ in range(inner_loops):
                    data_list = get_image_and_label(loader, [0, 1, 2], device=self.device)
                    model.step(**data_list)
            else:
                loader = iter(loader)
                for _ in range(inner_loops):
                    data_list = to(next(loader), self.device)
                    model.step(**data_list)

        for bn_module in momenta.keys():
            bn_module.momentum = momenta[bn_module]
        model.train(was_training)


def freeze(model, name, freeze, reverse=False):
    """Set requires_grad=`freeze` on params whose name contains `name`
    (or does not contain it, when reverse=True)."""
    for n, p in model.named_parameters():
        matched = name in n
        if matched != reverse:
            p.requires_grad = freeze


@contextmanager
def meta_learning_MAML(meta_model):
    fast_model = copy.deepcopy(meta_model)
    params,
@contextmanager
def meta_learning_MAML(meta_model):
    """Yield a functional clone of `meta_model` with an attached `meta_step` (MAML inner update)."""
    fast_model = copy.deepcopy(meta_model)
    params, states = get_parameters(meta_model)
    fast_model = put_parameters(fast_model, params, states).train()

    def meta_step(self, meta_loss, meta_lr, use_second_order=False, ignore_keys=None):
        params = get_parameters(self)[0]
        params = update_parameters(meta_loss, params, meta_lr, use_second_order=use_second_order, ignore_keys=ignore_keys)
        put_parameters(self, params, None)

    fast_model.meta_step = types.MethodType(meta_step, fast_model)  # assign method to the instance
    yield fast_model
    del fast_model, params, states


# -------- framework/registry.py --------
"""
Copy from maskrcnn-benchmark
"""
import functools
import inspect


def _register_generic(module_dict, func_name, func, parse_name):
    """Store `func` under the lower-cased `func_name`; duplicate keys are rejected."""
    # keys are stored lower-cased, so the duplicate check must compare lower-cased too
    # (checking the raw name let case-variant duplicates silently overwrite entries)
    assert func_name.lower() not in module_dict, 'Key of "{}" from "{}" already defined in "{}"'.format(
        func_name, inspect.getsourcefile(func), module_dict.get_src_file(func_name))
    if parse_name:
        func = functools.partial(func, registed_name=func_name)
    module_dict[func_name.lower()] = func


class Registry(dict):
    '''
    A helper class for managing registering modules, it extends a dictionary
    and provides a register functions.

    Eg. creating a registry:
        some_registry = Registry({"default": default_module})

    There're two ways of registering new modules:
    1): normal way is just calling register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2): used as decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access of module is just like using a dictionary, eg:
        f = some_registry["foo_module"]
    '''

    def __init__(self, name, *args, **kwargs):
        super(Registry, self).__init__(*args, **kwargs)
        self.name = name

    def register(self, module_name, module=None, parse_name=False):
        """Register `module` directly, or return a decorator that does so."""
        # '-' is reserved as extra param for calling the function
        assert '-' not in module_name, "Function name should not contain '-'"

        # used as function call
        if module is not None:
            _register_generic(self, module_name, module, parse_name)
            return

        # used as decorator
        def register_fn(fn):
            _register_generic(self, module_name, fn, parse_name)
            return fn

        return register_fn

    def __getitem__(self, item):
        """
        If the key is used as 'func-xx', the part after the first '-'
        is passed to the function as the keyword `param`.
        """
        splits = item.split('-', 1)
        item = splits[0].lower()
        func = super(Registry, self).__getitem__(item)
        if len(splits) == 1:
            return func
        else:
            return functools.partial(func, param=splits[1])

    def __repr__(self):
        return "Registry-{}".format(self.name)

    def get_src_file(self, module_name):
        """Return the path of the source file that defined `module_name`."""
        module = self[module_name]
        src_file = inspect.getsourcefile(module)
        return src_file


Models = Registry('Models')
Datasets = Registry('Datasets')
LossFuncs = Registry('LossFuncs')
AccFuncs = Registry('AccFuncs')
EvalFuncs = Registry('EvalFuncs')
TrainFuncs = Registry('TrainFuncs')
Schedulers = Registry('Schedulers')
Backbones = Registry('Backbones')

Entries = [Models, Datasets, LossFuncs, AccFuncs, EvalFuncs, TrainFuncs, Schedulers, Backbones]


def show_entries_and_files():
    """Print every registered key and the source file it came from."""
    for entry in Entries:
        print(f'\n{entry.name} : [')
        for key in entry.keys():
            print(f'\t"{key}" from "{entry.get_src_file(key)}"')
        print(']')
# -------- main.py --------
import ast
import sys
import argparse

from framework.exp import Experiments


def get_default_parser():
    """Build the argparse parser holding every experiment option and its default."""
    dataset = 'PACS'
    parser = argparse.ArgumentParser()
    parser.add_argument('--show-entry', action='store_true')

    parser.add_argument('--data-root', default='/data/DataSets/')
    parser.add_argument('--dataset', default='{}'.format(dataset))
    parser.add_argument('--save-path', default='../script/{}_New/resnet_test'.format(dataset))
    parser.add_argument('--backbone', type=str, default='resnet18')
    parser.add_argument('--model', default='ERM')
    parser.add_argument('--train', default='deepall')
    parser.add_argument('--eval', default='deepall')

    parser.add_argument('--exp-num', nargs='+', type=int, default=[1],
                        help='num >= 0 select which domain to train, num == -1 to train all domains, num == -2 to train all domains multiple times. ')
    parser.add_argument('--start-time', type=int, default=0)
    parser.add_argument('--times', type=int, default=1)

    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--num-epoch', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=128)

    # almost no need to change
    parser.add_argument('--eval-step', type=int, default=1)
    parser.add_argument('--save-step', type=int, default=1000)  # Save steps
    parser.add_argument('--start-save-epoch', type=int, default=1000)  # Save steps
    parser.add_argument('--save-last', action='store_true')

    # scheduler
    parser.add_argument('--scheduler', default='step')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.1)

    # optimizer
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--fc-weight', type=float, default=10.0)
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--opt-split', action='store_true')
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--nesterov', type=ast.literal_eval, default=True)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)

    parser.add_argument('--in-ch', default=3, type=int)

    # dataset
    parser.add_argument('--loader', default='normal', choices=['normal', 'meta', 'original', 'interleaved'])
    parser.add_argument('--img-size', default=224, type=int)
    parser.add_argument('--color-jitter', type=ast.literal_eval, default=True)  # important
    parser.add_argument('--min-scale', type=float, default=0.8)
    parser.add_argument('--domain-label', action='store_true')
    parser.add_argument('--data-path', action='store_true')
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--src', nargs='+', type=int, default=[-1])
    parser.add_argument('--tgt', nargs='+', type=int, default=[-1])
    parser.add_argument('--do-train', type=ast.literal_eval, default=True)
    parser.add_argument('--do-not-transform', action='store_true')
    parser.add_argument('--load-path', type=str, default='')
    parser.add_argument('--shuffled', type=ast.literal_eval, default=True)
    parser.add_argument('--test-with-eval', action='store_true')
    parser.add_argument('--small-dataset', action='store_true')

    # ------ customized parameters ------
    parser.add_argument('--TN', action='store_true')
    parser.add_argument('--meta-step', default=1, type=int)
    parser.add_argument('--meta-lr', default=1e-3, type=float)
    parser.add_argument('--meta-lr-weight', default=1, type=float)
    parser.add_argument('--meta-second-order', type=ast.literal_eval, default=False)
    parser.add_argument('--meta-aug', default=1, type=float)

    parser.add_argument('--replace', action='store_true')

    parser.add_argument('--TTAug', action='store_true')
    parser.add_argument('--TTA-bs', default=3, type=int)
    parser.add_argument('--TTA-head', default='em')

    # augment data in dataset
    parser.add_argument('--jigsaw', action='store_true')
    parser.add_argument('--rot', action='store_true')

    # loss list
    parser.add_argument('--head', type=str, default='em', help='Classification for DomainAdaptor')
    parser.add_argument('--loss-names', nargs='+', type=str, default=['gem-t'])

    # AdaMixBN
    parser.add_argument('--AdaMixBN', action='store_true', default=True)
    parser.add_argument('--Transform', action='store_true', default=True)
    parser.add_argument('--mix-lambda', type=float, default=None)

    parser.add_argument('--LAME', action='store_true', default=False)
    parser.add_argument('--online', action='store_true', default=False)

    return parser
if __name__ == '__main__':
    parser = get_default_parser()
    args = parser.parse_args()
    exp = Experiments(args)
    exp.run()


# ---------------------------------------------------------------------------
# /models/AdaptorHeads.py
# ---------------------------------------------------------------------------
import torch
from torch import nn

from models.AdaptorHelper import get_new_optimizers


class Head(nn.Module):
    """Base class for test-time-adaptation heads.

    A head drives a backbone and returns a dict of named outputs that
    ``get_loss_and_acc`` consumes; each entry may carry ``loss_type``,
    ``acc_type``, ``pred``, ``target`` and an optional ``weight``/``loss``.
    """

    Replace = False   # whether this head replaces the normal forward pass
    ft_steps = 1      # number of fine-tuning steps per test batch

    def __init__(self, num_classes, in_ch, args):
        super(Head, self).__init__()
        self.args = args
        # in_ch: channel width of the backbone's last conv stage
        # (512 for resnet18/34, 2048 for resnet50).
        self.in_ch = in_ch
        self.num_classes = num_classes

    def forward(self, base_features, x, label, backbone, **kwargs):
        raise NotImplementedError()

    def setup(self, whole_model, online):
        """Switch the backbone to train mode and return SGD optimizers over BN params."""
        whole_model.backbone.train()
        lr = 0.05
        print(f'Learning rate : {lr}')
        return [
            get_new_optimizers(whole_model, lr=lr, names=['bn'], opt_type='sgd')
        ]

    def do_ft(self, backbone, x, label, **kwargs):
        # Default: a fine-tuning step is identical to a training step.
        return self.do_train(backbone, x, label, **kwargs)

    def do_test(self, backbone, x, label, **kwargs):
        # Default: a test step is identical to a training step.
        return self.do_train(backbone, x, label, **kwargs)

    def do_train(self, backbone, x, label, **kwargs):
        """Plain cross-entropy classification on the backbone's final logits."""
        base_features = backbone(x)
        class_dict = {'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': base_features[-1], 'target': label}}
        return class_dict


class RotationHead(Head):
    """Self-supervised rotation head: predicts which of 4 rotations was applied.

    Expects ``kwargs['rot_x']`` (rotated images) and ``kwargs['rot_label']``
    (rotation index in [0, 4)) alongside the normal classification batch.
    """

    KEY = 'rotation'

    # NOTE: setup() is inherited from Head; the previous override was a
    # byte-for-byte duplicate and has been removed.

    def __init__(self, num_classes, in_ch, args):
        super(RotationHead, self).__init__(num_classes, in_ch, args)
        self.shared = args.shared
        # FIX: was nn.Linear(512, 4, bias=False). Hard-coding 512 only matches
        # resnet18/34; the pooled layer-4 features this layer consumes are
        # in_ch-dimensional, so use in_ch to also support resnet50 (2048-d).
        self.rotation_fc = nn.Linear(in_ch, 4, bias=False)

    def do_ft(self, backbone, x, label, **kwargs):
        logits = backbone(x)[-1]

        rotated_x, rotation_label = kwargs['rot_x'], kwargs['rot_label']
        # Global-average-pooled layer-4 features of the rotated images.
        l4 = backbone(rotated_x)[-2].mean((-1, -2))
        rotation_logits = self.rotation_fc(l4)

        return {
            # 'main' carries no loss_type: during fine-tuning only the
            # rotation loss produces gradients; classification is monitored.
            'main': {'acc_type': 'acc', 'pred': logits, 'target': label},
            'rotation': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': rotation_logits, 'target': rotation_label}
        }

    def do_train(self, backbone, x, label, **kwargs):
        base_features = backbone(x)

        rotated_x, rotation_label = kwargs['rot_x'], kwargs['rot_label']
        l4 = backbone(rotated_x)[-2].mean((-1, -2))

        rotation_logits = self.rotation_fc(l4)

        class_dict = {
            'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': base_features[-1], 'target': label},
            # NOTE(review): weight 0.0 presumably keeps the rotation loss
            # logged while excluding it from the total — confirm in
            # get_loss_and_acc before relying on it.
            'rotation': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': rotation_logits, 'target': rotation_label, 'weight': 0.0}
        }
        return class_dict


class NormHead(Head):
    """Head whose auxiliary loss is the norm of a small MLP over the logits."""

    KEY = 'Norm'

    def __init__(self, num_classes, in_ch, args):
        super(NormHead, self).__init__(num_classes, in_ch, args)

        class MLP(nn.Module):
            # Small fully-connected net; with norm_reduce=True the output is
            # collapsed to a scalar via torch.norm.
            def __init__(self, in_size=10, out_size=1, hidden_dim=32, norm_reduce=False):
                super(MLP, self).__init__()
                self.norm_reduce = norm_reduce
                self.model = nn.Sequential(
                    nn.Linear(in_size, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, out_size),
                )

            def forward(self, x):
                out = self.model(x)
                if self.norm_reduce:
                    out = torch.norm(out)
                return out

        self.mlp = MLP(in_size=num_classes, norm_reduce=True)
def do_ft(self, backbone, x, label, **kwargs): 120 | base_features = backbone(x) 121 | feats = base_features[-1] 122 | normed_loss = self.mlp(feats) 123 | return { 124 | 'main': {'acc_type': 'acc', 'pred': base_features[-1], 'target': label}, 125 | 'norm_loss': {'loss': normed_loss}, 126 | } 127 | 128 | def do_train(self, backbone, x, label, **kwargs): 129 | base_features = backbone(x) 130 | feats = base_features[-1] 131 | normed_loss = self.mlp(feats) 132 | return { 133 | 'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': base_features[-1], 'target': label}, 134 | NormHead.KEY: {'loss': normed_loss} 135 | } 136 | 137 | 138 | class JigsawHead(Head): 139 | KEY = 'Jigsaw' 140 | 141 | def __init__(self, num_classes, in_ch, args): 142 | super(JigsawHead, self).__init__(num_classes, in_ch, args) 143 | jigsaw_classes = 32 144 | emb_dim = in_ch 145 | # self.jigsaw_classifier = nn.Linear(in_ch, jigsaw_classes) 146 | self.jigsaw_classifier = nn.Sequential( 147 | nn.Linear(in_ch, emb_dim), 148 | nn.ReLU(), 149 | nn.Linear(emb_dim, emb_dim), 150 | nn.ReLU(), 151 | nn.Linear(emb_dim, jigsaw_classes), 152 | ) 153 | self.i = 0 154 | 155 | def do_ft(self, backbone, x, label, **kwargs): 156 | base_features = backbone(x) 157 | logits = base_features[-1] 158 | 159 | jig_features = backbone(kwargs['jigsaw_x'])[-2] 160 | jig_features = jig_features.mean((-1, -2)) 161 | jig_logits = self.jigsaw_classifier(jig_features) 162 | return { 163 | # 'main': {'acc_type': 'acc', 'pred': logits, 'target': label}, 164 | 'jig': {'acc_type': 'acc', 'pred': jig_logits, 'target': kwargs['jigsaw_label'], 'loss_type': 'ce'}, 165 | } 166 | 167 | def train(self, mode=True): 168 | super(JigsawHead, self).train(mode) 169 | self.i = 0 170 | 171 | def do_train(self, backbone, x, label, **kwargs): 172 | base_features = backbone(x) 173 | logits = base_features[-1] 174 | ret = { 175 | 'main': {'acc_type': 'acc', 'pred': logits, 'target': label, 'loss_type': 'ce'}, 176 | } 177 | # if self.i == 0 or 
random.random() > 0.9: 178 | # self.i = 1 179 | if True: 180 | jig_features = backbone(kwargs['jigsaw_x']) 181 | jig_class_logits = jig_features[-1] 182 | jig_features = jig_features[-2].mean((-1, -2)) 183 | jig_logits = self.jigsaw_classifier(jig_features) 184 | ret.update({ 185 | 'jig': {'acc_type': 'acc', 'pred': jig_logits, 'target': kwargs['jigsaw_label'], 'loss_type': 'ce', 'weight':0.1}, 186 | # 'jig_cls': {'acc_type': 'acc', 'pred': jig_class_logits, 'target': label, 'loss_type': 'ce', 'weight':0.5}, 187 | }) 188 | return ret 189 | 190 | def setup(self, whole_model, online): 191 | whole_model.backbone.train() 192 | # online best : 0.01 193 | # not online : 0.02? 194 | lr = 0.01 # 0.0005 is better for MDN 195 | print(f"Learning rate : {lr} ") 196 | return get_new_optimizers(whole_model, lr=lr, names=['bn'], opt_type='sgd', momentum=online) 197 | 198 | 199 | class NoneHead(Head): 200 | def do_train(self, backbone, x, label, **kwargs): 201 | base_features = backbone(x) 202 | return { 203 | 'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': base_features[-1], 'target': label}, 204 | } 205 | 206 | -------------------------------------------------------------------------------- /models/AdaptorHelper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | 5 | def module_has_string(has_strings, x): 6 | # No string required, all layers can be converted 7 | if len(has_strings) == 0: 8 | return True 9 | 10 | # Only modules with names contain one string in has_strings can be converted 11 | for string in has_strings: 12 | if string in x: 13 | return True 14 | return False 15 | 16 | 17 | def has_string(has_strings, x): 18 | # No string required, all layers can be converted 19 | if len(has_strings) == 0: 20 | return False 21 | 22 | # Only modules with names contain one string in has_strings can be converted 23 | for string in has_strings: 24 | if string in x: 25 | return True 26 | 
return False 27 | 28 | 29 | def collect_module_params(model, module_class, has_strings=[]): 30 | params = [] 31 | names = [] 32 | nnn = [] 33 | for nm, m in model.named_modules(): 34 | if has_string(nnn, nm): 35 | continue 36 | if isinstance(m, module_class) and module_has_string(has_strings, nm): 37 | for np, p in m.named_parameters(): 38 | params.append(p) 39 | names.append(f"{nm}.{np}") 40 | return params, names 41 | 42 | 43 | def set_param_trainable(model, module_names, requires_grad): 44 | classes = { 45 | 'bn': nn.BatchNorm2d, 46 | 'conv': nn.Conv2d, 47 | 'fc': nn.Linear 48 | } 49 | set_names = [] 50 | for name in module_names: 51 | params, param_names = collect_module_params(model, classes[name]) 52 | for p in params: 53 | p.requires_grad = requires_grad 54 | set_names.extend(param_names) 55 | return set_names 56 | 57 | 58 | def remove_param_grad(model, module_names): 59 | classes = { 60 | 'bn': nn.BatchNorm2d, 61 | 'conv': nn.Conv2d, 62 | 'fc': nn.Linear 63 | } 64 | set_names = [] 65 | for name in module_names: 66 | params, param_names = collect_module_params(model, classes[name]) 67 | for p in params: 68 | p.grad = None 69 | set_names.extend(param_names) 70 | return set_names 71 | 72 | 73 | def get_optimizer(opt_dic, lr, opt_type='sgd', momentum=True): 74 | if opt_type == 'sgd': 75 | if momentum: 76 | opt = torch.optim.SGD(opt_dic, lr=lr, momentum=0.9) 77 | else: 78 | opt = torch.optim.SGD(opt_dic, lr=lr) 79 | else: 80 | opt = torch.optim.Adam(opt_dic, lr=lr, betas=(0.9, 0.999)) 81 | return opt 82 | 83 | 84 | def get_new_optimizers(model, lr=1e-4, names=('bn', 'conv', 'fc'), opt_type='sgd', momentum=False): 85 | optimizers, opt_names = [], [] 86 | classes = { 87 | 'bn': nn.BatchNorm2d, 88 | 'conv': nn.Conv2d, 89 | 'fc': nn.Linear 90 | } 91 | opt_dic = [] 92 | for name in names: 93 | name = name.lower() 94 | params, names = collect_module_params(model, module_class=classes[name]) 95 | for param in params: 96 | opt_dic.append({'params': param, 'lr': lr}) 97 
| opt = get_optimizer(opt_dic, lr, opt_type, momentum) 98 | # optimizers.append(opt) 99 | # opt_names.append(names) 100 | return opt 101 | 102 | 103 | def convert_to_target(net, norm, start=0, end=5, verbose=True, res50=False): 104 | def convert_norm(old_norm, new_norm, num_features, idx): 105 | norm_layer = new_norm(num_features, idx=idx).to(net.conv1.weight.device) 106 | if hasattr(norm_layer, 'load_old_dict'): 107 | info = 'Converted to : {}'.format(norm) 108 | norm_layer.load_old_dict(old_norm) 109 | elif hasattr(norm_layer, 'load_state_dict'): 110 | state_dict = old_norm.state_dict() 111 | info = norm_layer.load_state_dict(state_dict, strict=False) 112 | else: 113 | info = 'No load_old_dict() found!!!' 114 | if verbose: 115 | print(info) 116 | return norm_layer 117 | 118 | layers = [0, net.layer1, net.layer2, net.layer3, net.layer4] 119 | 120 | idx = 0 121 | converted_bns = {} 122 | for i, layer in enumerate(layers): 123 | if not (start <= i < end): 124 | continue 125 | if i == 0: 126 | net.bn1 = convert_norm(net.bn1, norm, net.bn1.num_features, idx) 127 | converted_bns['L0-BN0-0'] = net.bn1 128 | idx += 1 129 | else: 130 | for j, bottleneck in enumerate(layer): 131 | bottleneck.bn1 = convert_norm(bottleneck.bn1, norm, bottleneck.bn1.num_features, idx) 132 | converted_bns['L{}-BN{}-{}'.format(i, j, 0)] = bottleneck.bn1 133 | idx += 1 134 | bottleneck.bn2 = convert_norm(bottleneck.bn2, norm, bottleneck.bn2.num_features, idx) 135 | converted_bns['L{}-BN{}-{}'.format(i, j, 1)] = bottleneck.bn2 136 | idx += 1 137 | if res50: 138 | bottleneck.bn3 = convert_norm(bottleneck.bn3, norm, bottleneck.bn3.num_features, idx) 139 | converted_bns['L{}-BN{}-{}'.format(i, j, 3)] = bottleneck.bn3 140 | idx += 1 141 | if bottleneck.downsample is not None: 142 | bottleneck.downsample[1] = convert_norm(bottleneck.downsample[1], norm, bottleneck.downsample[1].num_features, idx) 143 | converted_bns['L{}-BN{}-{}'.format(i, j, 2)] = bottleneck.downsample[1] 144 | idx += 1 145 | return 
net, converted_bns 146 | -------------------------------------------------------------------------------- /models/DomainAdaptor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import warnings 4 | 5 | from framework.ERM import ERM 6 | from framework.loss_and_acc import * 7 | from framework.registry import EvalFuncs, Models 8 | from models.AdaptorHeads import RotationHead, NormHead, NoneHead, Head, JigsawHead 9 | from models.AdaptorHelper import get_new_optimizers, convert_to_target 10 | from models.LAME import laplacian_optimization, kNN_affinity 11 | from utils.tensor_utils import to, AverageMeterDict, zero_and_update 12 | 13 | 14 | warnings.filterwarnings("ignore") 15 | np.set_printoptions(edgeitems=30, linewidth=1000, formatter=dict(float=lambda x: "{:.4f}, ".format(x))) 16 | 17 | 18 | class AdaMixBN(nn.BatchNorm2d): 19 | # AdaMixBn cannot be applied in an online manner. 20 | def __init__(self, in_ch, lambd=None, transform=True, mix=True, idx=0): 21 | super(AdaMixBN, self).__init__(in_ch) 22 | self.lambd = lambd 23 | self.rectified_params = None 24 | self.transform = transform 25 | self.layer_idx = idx 26 | self.mix = mix 27 | 28 | def get_retified_gamma_beta(self, lambd, src_mu, src_var, cur_mu, cur_var): 29 | C = src_mu.shape[1] 30 | new_gamma = (cur_var + self.eps).sqrt() / (lambd * src_var + (1 - lambd) * cur_var + self.eps).sqrt() * self.weight.view(1, C, 1, 1) 31 | new_beta = lambd * (cur_mu - src_mu) / (cur_var + self.eps).sqrt() * new_gamma + self.bias.view(1, C, 1, 1) 32 | return new_gamma.view(-1), new_beta.view(-1) 33 | 34 | def get_lambd(self, x, src_mu, src_var, cur_mu, cur_var): 35 | instance_mu = x.mean((2, 3), keepdims=True) 36 | instance_std = x.std((2, 3), keepdims=True) 37 | 38 | it_dist = ((instance_mu - cur_mu) ** 2).mean(1, keepdims=True) + ((instance_std - cur_var.sqrt()) ** 2).mean(1, keepdims=True) 39 | is_dist = ((instance_mu - src_mu) ** 2).mean(1, 
keepdims=True) + ((instance_std - src_var.sqrt()) ** 2).mean(1, keepdims=True) 40 | st_dist = ((cur_mu - src_mu) ** 2).mean(1)[None] + ((cur_var.sqrt() - src_var.sqrt()) ** 2).mean(1)[None] 41 | 42 | src_lambd = 1 - (st_dist) / (st_dist + is_dist + it_dist) 43 | 44 | src_lambd = torch.clip(src_lambd, min=0, max=1) 45 | return src_lambd 46 | 47 | def get_mu_var(self, x): 48 | C = x.shape[1] 49 | src_mu = self.running_mean.view(1, C, 1, 1) 50 | src_var = self.running_var.view(1, C, 1, 1) 51 | cur_mu = x.mean((0, 2, 3), keepdims=True) 52 | cur_var = x.var((0, 2, 3), keepdims=True) 53 | 54 | lambd = self.get_lambd(x, src_mu, src_var, cur_mu, cur_var).mean(0, keepdims=True) 55 | 56 | if self.lambd is not None: 57 | lambd = self.lambd 58 | 59 | if self.transform: 60 | if self.rectified_params is None: 61 | new_gamma, new_beta = self.get_retified_gamma_beta(lambd, src_mu, src_var, cur_mu, cur_var) 62 | # self.test(x, lambd, src_mu, src_var, cur_mu, cur_var, new_gamma, new_beta) 63 | self.weight.data = new_gamma.data 64 | self.bias.data = new_beta.data 65 | self.rectified_params = new_gamma, new_beta 66 | return cur_mu, cur_var 67 | else: 68 | new_mu = lambd * src_mu + (1 - lambd) * cur_mu 69 | new_var = lambd * src_var + (1 - lambd) * cur_var 70 | return new_mu, new_var 71 | 72 | def forward(self, x): 73 | n, C, H, W = x.shape 74 | new_mu = x.mean((0, 2, 3), keepdims=True) 75 | new_var = x.var((0, 2, 3), keepdims=True) 76 | 77 | if self.training: 78 | if self.mix: 79 | new_mu, new_var = self.get_mu_var(x) 80 | 81 | # Normalization with new statistics 82 | inv_std = 1 / (new_var + self.eps).sqrt() 83 | new_x = (x - new_mu) * (inv_std * self.weight.view(1, C, 1, 1)) + self.bias.view(1, C, 1, 1) 84 | return new_x 85 | else: 86 | return super(AdaMixBN, self).forward(x) 87 | 88 | def reset(self): 89 | self.rectified_params = None 90 | 91 | def test_equivalence(self, x): 92 | C = x.shape[1] 93 | src_mu = self.running_mean.view(1, C, 1, 1) 94 | src_var = self.running_var.view(1, 
C, 1, 1) 95 | cur_mu = x.mean((0, 2, 3), keepdims=True) 96 | cur_var = x.var((0, 2, 3), keepdims=True) 97 | lambd = 0.9 98 | 99 | new_gamma, new_beta = self.get_retified_gamma_beta(x, lambd, src_mu, src_var, cur_mu, cur_var) 100 | inv_std = 1 / (cur_var + self.eps).sqrt() 101 | x_1 = (x - cur_mu) * (inv_std * new_gamma.view(1, C, 1, 1)) + new_beta.view(1, C, 1, 1) 102 | 103 | new_mu = lambd * src_mu + (1 - lambd) * cur_mu 104 | new_var = lambd * src_var + (1 - lambd) * cur_var 105 | inv_std = 1 / (new_var + self.eps).sqrt() 106 | x_2 = (x - new_mu) * (inv_std * self.weight.view(1, C, 1, 1)) + self.bias.view(1, C, 1, 1) 107 | assert (x_2 - x_1).abs().mean() < 1e-5 108 | return x_1, x_2 109 | 110 | 111 | class Losses(): 112 | def __init__(self): 113 | self.losses = { 114 | 'em': self.em, 115 | 'slr': self.slr, 116 | 'norm': self.norm, 117 | 'gem-t': self.GEM_T, 118 | 'gem-skd': self.GEM_SKD, 119 | 'gem-aug': self.GEM_Aug, 120 | } 121 | 122 | def GEM_T(self, logits, **kwargs): 123 | logits = logits - logits.mean(1, keepdim=True).detach() 124 | T = logits.std(1, keepdim=True).detach() * 2 125 | prob = (logits / T).softmax(1) 126 | loss = - ((prob * prob.log()).sum(1) * (T ** 2)).mean() 127 | return loss 128 | 129 | def GEM_SKD(self, logits, **kwargs): 130 | logits = logits - logits.mean(1, keepdim=True).detach() 131 | T = logits.std(1, keepdim=True).detach() * 2 132 | 133 | original_prob = logits.softmax(1) 134 | prob = (logits / T).softmax(1) 135 | 136 | loss = - ((original_prob.detach() * prob.log()).sum(1) * (T ** 2)).mean() 137 | return loss 138 | 139 | def GEM_Aug(self, logits, **kwargs): 140 | logits = logits - logits.mean(1, keepdim=True).detach() 141 | T = logits.std(1, keepdim=True).detach() * 2 142 | aug_logits = kwargs['aug_logits'] 143 | loss = - ((aug_logits.softmax(1).detach() * (logits / T).softmax(1).log()).sum(1) * (T ** 2)).mean() 144 | return loss 145 | 146 | def em(self, logits, **kwargs): 147 | prob = (logits).softmax(1) 148 | loss = (- prob * 
prob.log()).sum(1) 149 | return loss.mean() 150 | 151 | def slr(self, logits, **kwargs): 152 | prob = (logits).softmax(1) 153 | return -(prob * (1 / (1 - prob + 1e-8)).log()).sum(1).mean() # * 3 is enough = 82.7 154 | 155 | def norm(self, logits, **kwargs): 156 | return -logits.norm(dim=1).mean() * 2 157 | 158 | def get_loss(self, name, **kwargs): 159 | return {name: {'loss': self.losses[name.lower()](**kwargs)}} 160 | 161 | 162 | class EntropyMinimizationHead(Head): 163 | KEY = 'EM' 164 | ft_steps = 1 165 | 166 | def __init__(self, num_classes, in_ch, args): 167 | super(EntropyMinimizationHead, self).__init__(num_classes, in_ch, args) 168 | self.losses = Losses() 169 | 170 | def get_cos_logits(self, feats, backbone): 171 | w = backbone.fc.weight # c X C 172 | w, feats = F.normalize(w, dim=1), F.normalize(feats, dim=1) 173 | logits = (feats @ w.t()) # / 0.07 174 | return logits 175 | 176 | def label_rectify(self, feats, logits, thresh=0.95): 177 | # mask = self.get_confident_mask(logits, thresh=thresh) 178 | max_prob = logits.softmax(1).max(1)[0] 179 | normed_feats = feats / feats.norm(dim=1, keepdim=True) 180 | # N x N 181 | sim = (normed_feats @ normed_feats.t()) / 0.07 182 | # sim = feats @ feats.t() 183 | # select from high confident masks 184 | selected_sim = sim # * max_prob[None] 185 | # N x n @ n x C = N x C 186 | rectified_feats = (selected_sim.softmax(1) @ feats) 187 | return rectified_feats + feats 188 | 189 | def do_lame(self, feats, logits): 190 | prob = logits.softmax(1) 191 | unary = - torch.log(prob + 1e-10) # [N, K] 192 | 193 | feats = F.normalize(feats, p=2, dim=-1) # [N, d] 194 | kernel = kNN_affinity(5)(feats) # [N, N] 195 | 196 | kernel = 1 / 2 * (kernel + kernel.t()) 197 | 198 | # --- Perform optim --- 199 | Y = laplacian_optimization(unary, kernel) 200 | return Y 201 | 202 | def do_ft(self, backbone, x, label, loss_name=None, step=0, model=None, **kwargs): 203 | assert loss_name is not None 204 | 205 | if loss_name.lower() == 'gem-aug': 206 | 
with torch.no_grad(): 207 | aug_x = kwargs['tta'] 208 | n, N, C, H, W = aug_x.shape 209 | aug_x = aug_x.reshape(n * N, C, H, W) 210 | aug_logits = backbone(aug_x)[-1].view(n, N, -1).mean(1) 211 | else: 212 | aug_logits = None 213 | 214 | base_features = backbone(x) 215 | logits, feats = base_features[-1], base_features[-2].mean((2, 3)) 216 | ret = { 217 | 'main': {'acc_type': 'acc', 'pred': logits, 'target': label}, 218 | 'logits': logits.detach() 219 | } 220 | 221 | ret.update(self.losses.get_loss(loss_name, logits=logits, backbone=backbone, feats=feats, 222 | step=step, aug_logits=aug_logits)) 223 | return ret 224 | 225 | def do_train(self, backbone, x, label, **kwargs): 226 | base_features = backbone(x) 227 | logits, feats = base_features[-1], base_features[-2].mean((2, 3)) 228 | 229 | res = { 230 | 'main': {'loss_type': 'ce', 'acc_type': 'acc', 'pred': logits, 'target': label}, 231 | 'logits': logits.detach() 232 | } 233 | if self.args.LAME: 234 | res.update({'LAME': {'acc_type': 'acc', 'pred': self.do_lame(feats, logits), 'target': label}}) 235 | return res 236 | 237 | def setup(self, model, online): 238 | model.backbone.train() 239 | lr = self.args.lr 240 | print(f'Learning rate : {lr}') 241 | return [ 242 | get_new_optimizers(model, lr=lr, names=['bn'], opt_type='sgd', momentum=self.args.online), 243 | ] 244 | 245 | 246 | @Models.register('DomainAdaptor') 247 | class DomainAdaptor(ERM): 248 | def __init__(self, num_classes, pretrained=True, args=None): 249 | super(DomainAdaptor, self).__init__(num_classes, pretrained, args) 250 | heads = { 251 | 'em': EntropyMinimizationHead, 252 | 'rot': RotationHead, 253 | 'norm': NormHead, 254 | 'none': NoneHead, 255 | 'jigsaw': JigsawHead, 256 | } 257 | self.head = heads[args.TTA_head.lower()](num_classes, self.in_ch, args) 258 | 259 | if args.AdaMixBN: 260 | self.bns = list(convert_to_target(self.backbone, functools.partial(AdaMixBN, transform=args.Transform, lambd=args.mix_lambda), 261 | verbose=False, start=0, end=5, 
res50=args.backbone == 'resnet50')[-1].values()) 262 | 263 | def step(self, x, label, train_mode='test', **kwargs): 264 | if train_mode == 'train': 265 | res = self.head.do_train(self.backbone, x, label, model=self, **kwargs) 266 | elif train_mode == 'test': 267 | res = self.head.do_test(self.backbone, x, label, model=self, **kwargs) 268 | elif train_mode == 'ft': 269 | res = self.head.do_ft(self.backbone, x, label, model=self, **kwargs) 270 | else: 271 | raise Exception("Unexpected mode : {}".format(train_mode)) 272 | return res 273 | 274 | def finetune(self, data, optimizers, loss_name, running_loss=None, running_corrects=None): 275 | if hasattr(self, 'bns'): 276 | [bn.reset() for bn in self.bns] 277 | 278 | with torch.enable_grad(): 279 | res = None 280 | for i in range(self.head.ft_steps): 281 | o = self.step(**data, train_mode='ft', step=i, loss_name=loss_name) 282 | meta_train_loss = get_loss_and_acc(o, running_loss, running_corrects, prefix=f'A{i}_') 283 | zero_and_update(optimizers, meta_train_loss) 284 | if i == 0: 285 | res = o 286 | return res 287 | 288 | def forward(self, *args, **kwargs): 289 | return self.step(*args, **kwargs) 290 | 291 | def setup(self, online): 292 | return self.head.setup(self, online) 293 | 294 | 295 | @EvalFuncs.register('tta_ft') 296 | def test_time_adaption(model, eval_data, lr, epoch, args, engine, mode): 297 | device, optimizers = engine.device, engine.optimizers 298 | running_loss, running_corrects = AverageMeterDict(), AverageMeterDict() 299 | 300 | model.eval() 301 | model_to_ft = copy.deepcopy(model) 302 | original_state_dict = model.state_dict() 303 | 304 | online = args.online 305 | optimizers = model_to_ft.setup(online) 306 | 307 | loss_names = args.loss_names # 'gem-t', 'gem-skd', 'gem-tta'] 308 | 309 | with torch.no_grad(): 310 | for i, data in enumerate(eval_data): 311 | data = to(data, device) 312 | 313 | # Normal Test 314 | out = model(**data, train_mode='test') 315 | get_loss_and_acc(out, running_loss, 
running_corrects, prefix='original_') 316 | 317 | # test-time adaptation to a single batch 318 | for loss_name in loss_names: 319 | # recover to the original weight 320 | model_to_ft.load_state_dict(original_state_dict) if (not online) else "" 321 | 322 | # adapt to the current batch 323 | adapt_out = model_to_ft.finetune(data, optimizers, loss_name, running_loss, running_corrects) 324 | 325 | # get the adapted result 326 | cur_out = model_to_ft(**data, train_mode='test') 327 | 328 | get_loss_and_acc(cur_out, running_loss, running_corrects, prefix=f'{loss_name}_') 329 | if loss_name == loss_names[-1]: 330 | get_loss_and_acc(cur_out, running_loss, running_corrects) # the last one is recorded as the main result 331 | 332 | loss, acc = running_loss.get_average_dicts(), running_corrects.get_average_dicts() 333 | return acc['main'], (loss, acc) 334 | -------------------------------------------------------------------------------- /models/LAME.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.jit 3 | import logging 4 | from typing import List, Dict 5 | 6 | import time 7 | import torch.nn.functional as F 8 | 9 | 10 | __all__ = ["LAME"] 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class AffinityMatrix: 16 | 17 | def __call__(X, **kwargs): 18 | raise NotImplementedError 19 | 20 | def is_psd(self, mat): 21 | eigenvalues = torch.eig(mat)[0][:, 0].sort(descending=True)[0] 22 | return eigenvalues, float((mat == mat.t()).all() and (eigenvalues >= 0).all()) 23 | 24 | def symmetrize(self, mat): 25 | return 1 / 2 * (mat + mat.t()) 26 | 27 | 28 | class kNN_affinity(AffinityMatrix): 29 | def __init__(self, knn: int, **kwargs): 30 | self.knn = knn 31 | 32 | def __call__(self, X): 33 | N = X.size(0) 34 | dist = torch.norm(X.unsqueeze(0) - X.unsqueeze(1), dim=-1, p=2) # [N, N] 35 | n_neighbors = min(self.knn + 1, N) 36 | 37 | knn_index = dist.topk(n_neighbors, -1, largest=False).indices[:, 1:] # [N, 
knn] 38 | 39 | W = torch.zeros(N, N, device=X.device) 40 | W.scatter_(dim=-1, index=knn_index, value=1.0) 41 | 42 | return W 43 | 44 | 45 | class rbf_affinity(AffinityMatrix): 46 | def __init__(self, sigma: float, **kwargs): 47 | self.sigma = sigma 48 | self.k = kwargs['knn'] 49 | 50 | def __call__(self, X): 51 | 52 | N = X.size(0) 53 | dist = torch.norm(X.unsqueeze(0) - X.unsqueeze(1), dim=-1, p=2) # [N, N] 54 | n_neighbors = min(self.k, N) 55 | kth_dist = dist.topk(k=n_neighbors, dim=-1, largest=False).values[:, -1] # compute k^th distance for each point, [N, knn + 1] 56 | sigma = kth_dist.mean() 57 | rbf = torch.exp(- dist ** 2 / (2 * sigma ** 2)) 58 | # mask = torch.eye(X.size(0)).to(X.device) 59 | # rbf = rbf * (1 - mask) 60 | return rbf 61 | 62 | 63 | class linear_affinity(AffinityMatrix): 64 | 65 | def __call__(self, X: torch.Tensor): 66 | """ 67 | X: [N, d] 68 | """ 69 | return torch.matmul(X, X.t()) 70 | 71 | 72 | def laplacian_optimization(unary, kernel, bound_lambda=1, max_steps=100): 73 | 74 | E_list = [] 75 | oldE = float('inf') 76 | Y = (-unary).softmax(-1) # [N, K] 77 | for i in range(max_steps): 78 | pairwise = bound_lambda * kernel.matmul(Y) # [N, K] 79 | exponent = -unary + pairwise 80 | Y = exponent.softmax(-1) 81 | E = entropy_energy(Y, unary, pairwise, bound_lambda).item() 82 | E_list.append(E) 83 | 84 | if (i > 1 and (abs(E - oldE) <= 1e-8 * abs(oldE))): 85 | logger.info(f'Converged in {i} iterations') 86 | break 87 | else: 88 | oldE = E 89 | 90 | return Y 91 | 92 | 93 | def entropy_energy(Y, unary, pairwise, bound_lambda): 94 | E = (unary * Y - bound_lambda * pairwise * Y + Y * torch.log(Y.clip(1e-20))).sum() 95 | return E 96 | -------------------------------------------------------------------------------- /models/MetaModel.py: -------------------------------------------------------------------------------- 1 | from framework.loss_and_acc import * 2 | from framework.meta_util import split_image_and_label 3 | from framework.registry import 
EvalFuncs, TrainFuncs 4 | from models.AdaptorHelper import get_new_optimizers 5 | from utils.tensor_utils import to, AverageMeterDict 6 | 7 | """ 8 | ARM 9 | """ 10 | @TrainFuncs.register('tta_meta') 11 | def tta_meta_train2(meta_model, train_data, lr, epoch, args, engine, mode): 12 | import higher 13 | device, optimizers = engine.device, engine.optimizers 14 | running_loss, running_corrects = AverageMeterDict(), AverageMeterDict() 15 | 16 | inner_opt_conv = get_new_optimizers(meta_model, lr=args.meta_lr, names=['bn'], momentum=False) 17 | meta_model.train() 18 | print(f'Meta LR : {args.meta_lr}') 19 | 20 | for data_list in train_data: 21 | data_list = to(data_list, device) 22 | split_data = split_image_and_label(data_list, size=args.batch_size) 23 | 24 | for data in split_data: 25 | 26 | with higher.innerloop_ctx(meta_model, inner_opt_conv, copy_initial_weights=False, track_higher_grads=True) as (fnet, diffopt): 27 | for _ in range(args.meta_step): 28 | unsup_loss = get_loss_and_acc(fnet(**data, train_mode='ft', step=_), running_loss, running_corrects, prefix=f'adapt{_}_') 29 | diffopt.step(unsup_loss) 30 | main_loss, unsup_loss = get_loss_and_acc(fnet(**data, train_mode='train'), running_loss, running_corrects, reduction='none') 31 | (main_loss).backward() 32 | optimizers.step() 33 | optimizers.zero_grad() 34 | 35 | return running_loss.get_average_dicts(), running_corrects.get_average_dicts() 36 | 37 | 38 | @EvalFuncs.register('tta_meta') 39 | def tta_meta_test(meta_model, eval_data, lr, epoch, args, engine, mode): 40 | import higher 41 | device, optimizers = engine.device, engine.optimizers 42 | running_loss, running_corrects = AverageMeterDict(), AverageMeterDict() 43 | 44 | inner_opt = get_new_optimizers(meta_model, lr=args.meta_lr, names=['bn'], momentum=False) 45 | meta_model.eval() 46 | for data in eval_data: 47 | data = to(data, device) 48 | 49 | with torch.no_grad(): # Normal Test 50 | get_loss_and_acc(meta_model.step(**data, train_mode='test'), 
import copy
from pathlib import Path

import numpy as np
import torch.optim
from torch import nn
from tqdm import tqdm
import torch.nn.functional as F
from framework.engine import get_optimizers
from framework.registry import Datasets, EvalFuncs
from models.AdaptorHelper import get_new_optimizers
from utils.tensor_utils import AverageMeterDict, zero_and_update
from framework.loss_and_acc import get_loss_and_acc
from utils.tensor_utils import to


@EvalFuncs.register('tta')
def TTA_eval_model(model, eval_data, lr, epoch, args, engine, mode):
    """Evaluate a pretrained model with test-time augmentation (TTA).

    For every test batch, two predictions are scored:
      * the plain forward pass on the un-augmented image, and
      * an ensemble that averages the logits over the ``aug_n`` augmented
        views and adds the plain logits.

    Args:
        model:      model exposing ``step(x, label) -> {'logits': ...}`` and
                    ``load_pretrained(path, absolute=...)``.
        eval_data:  loader yielding dicts with keys 'x', 'label', 'tta'
                    (the 'tta' entry is (N, aug_n, C, H, W)).
        lr, epoch, args, engine, mode: framework plumbing; only ``args``
                    and ``engine`` are read here.

    Returns:
        (tta_accuracy, (loss_dict, acc_dict))
    """
    device = engine.device
    running_loss, running_corrects = AverageMeterDict(), AverageMeterDict()

    cur_domain = Datasets[args.dataset].Domains[args.exp_num[0]]
    model_path = args.TTA_model_path

    # NOTE(review): these checkpoint locations are machine-specific absolute
    # paths — consider moving them into the config / CLI args.
    ckpt_by_domain = {
        'art_painting': '/data/zj/PycharmProjects/FACT-main/output/PACS_ResNet50/art_painting/2022-05-29-14-06-46/best_model.tar',
        'cartoon': '/data/zj/PycharmProjects/FACT-main/output/PACS_ResNet50/cartoon/2022-05-29-14-06-40/best_model.tar',
        'photo': '/data/zj/PycharmProjects/FACT-main/output/PACS_ResNet50/photo/2022-05-29-13-59-08/best_model.tar',
        'sketch': '/data/zj/PycharmProjects/FACT-main/output/PACS_ResNet50/sketch/2022-05-29-14-06-52/best_model.tar',
    }
    target = engine.target_domain[0]
    if target not in ckpt_by_domain:
        raise Exception("? {}".format(target))
    model.load_pretrained(ckpt_by_domain[target], absolute=True)

    with torch.no_grad():
        model.eval()  # eval mode for the plain (normal) test pass
        for i, data_list in enumerate(tqdm(eval_data)):
            test_data, test_label, aug_data = data_list['x'], data_list['label'], data_list['tta']
            test_data, aug_data, test_label = to([test_data, aug_data, test_label], device)

            # Plain forward pass; accumulates the 'main' loss/accuracy.
            outputs = model.step(test_data, test_label)
            logits = outputs['logits']
            get_loss_and_acc(outputs, running_loss, running_corrects)

            # Flatten the augmented views into one batch:
            # (N, aug_n, C, H, W) -> (N * aug_n, C, H, W), sample-major order.
            N, aug_n, C, H, W = aug_data.shape
            aug_data = aug_data.reshape(-1, C, H, W)
            # BUG FIX: labels must be repeated per-sample (interleaved) to match
            # the flattened view order above. The previous `repeat(aug_n, 1)`
            # tiled the whole label batch instead, misaligning labels and images.
            aug_label = test_label.unsqueeze(1).repeat(1, aug_n).reshape(-1)
            logits2 = model.step(aug_data, aug_label)['logits']

            # TTA ensemble: per-sample mean over augmented logits + plain logits.
            mean_logits = logits2.reshape(N, aug_n, -1).mean(1) + logits

            tta_outputs = {'TTA': {'acc_type': 'acc', 'pred': mean_logits, 'target': test_label}}
            get_loss_and_acc(tta_outputs, running_loss, running_corrects)

    loss = running_loss.get_average_dicts()
    acc = running_corrects.get_average_dicts()
    # BUG FIX: the guard previously tested `'main' in acc` while returning
    # acc['TTA'], so a missing 'main' entry hid an available TTA accuracy.
    if 'TTA' in acc:
        return acc['TTA'], (loss, acc)
    else:
        return 0, (loss, acc)


@torch.enable_grad()
def finetune_entropy(model, img, optimizer, k=1):
    """Adapt ``model`` on one unlabeled batch by minimizing prediction entropy.

    Runs ``k`` gradient steps of the mean Shannon-entropy objective
    (TENT-style adaptation). Gradients are enabled even if the caller is
    inside ``torch.no_grad()`` thanks to the decorator.

    Args:
        model:     model exposing ``step(img, None) -> {'logits': ...}``.
        img:       input batch tensor.
        optimizer: optimizer (or list of optimizers) over the adapted params.
        k:         number of adaptation steps.

    Returns:
        List of the scalar loss value of each step.
    """
    # NOTE: earlier experimental terms (pseudo-label CE on confident samples,
    # entropy on unconfident ones) were dead code and have been removed;
    # only the plain entropy objective is optimized, as before.
    avg_loss = []
    for _ in range(k):
        logits = model.step(img, None)['logits']
        # Mean entropy of the softmax predictions: -sum p * log p.
        entropy_loss = - (logits.softmax(1) * logits.log_softmax(1)).sum(1).mean()
        zero_and_update(optimizer, entropy_loss)
        # BUG FIX: tensors have no `.distance_dict()`; record the scalar value.
        avg_loss.append(entropy_loss.item())
    return avg_loss


@torch.enable_grad()
def finetune_sep(model, test_data, logits, label, bn_optimizers, all_optimizers, k=3, threshold=0.95):
    """Adapt ``model`` for ``k`` steps, separating confident/unconfident samples.

    Samples whose initial max softmax probability exceeds ``threshold`` get a
    fixed pseudo-label from the initial ``logits``; the optimized objective,
    however, is currently the plain entropy loss over the whole batch.

    Args:
        model:          callable as ``model(test_data, None) -> {'logits': ...}``.
        test_data:      input batch.
        logits:         initial logits used to pick confident pseudo-labels.
        label:          ground-truth labels (unused in the current objective).
        bn_optimizers:  BN-only optimizers (unused in the current objective).
        all_optimizers: optimizers actually stepped each iteration.
        k:              number of adaptation steps.
        threshold:      confidence threshold for pseudo-labeling.

    Returns:
        The logits from the final adaptation step.
    """
    optimizers = all_optimizers
    avg_loss = []
    conf = logits.softmax(1).max(1)[0] > threshold
    confident_label = logits.argmax(1)[conf]
    for _ in range(k):
        logits = model(test_data, None)['logits']
        entropy_loss = - (logits.softmax(1) * logits.log_softmax(1)).sum(1).mean(0)

        confident_logits = logits[conf]
        # Supervised CE on confident pseudo-labels; computed but not used in
        # the update below (kept to preserve the original behavior).
        sup_loss = F.cross_entropy(confident_logits, confident_label)

        # First backward pass restricted to BN parameters.
        # NOTE(review): the `zero_grad` below clears these gradients again for
        # every optimizer-managed parameter, so this pass is effectively a
        # no-op for them — confirm this is intentional.
        entropy_loss.backward(retain_graph=True)
        for n, p in model.named_parameters():
            if 'bn' not in n:
                p.grad = None

        loss = entropy_loss
        [o.zero_grad() for o in optimizers]
        loss.backward()
        [o.step() for o in optimizers]
        # BUG FIX: `.distance_dict()` does not exist on tensors.
        avg_loss.append(loss.item())
    return logits
module == '__init__.py' or module[-3:] != '.py': 9 | continue 10 | importlib.import_module('.' + module[:-3], __package__) # '.' before module_name is required 11 | imported_modules.append(module) 12 | del module 13 | print('Successfully imported modules : ', imported_modules) 14 | 15 | import_all_modules_in_current_folders() -------------------------------------------------------------------------------- /script/TTA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python main.py \ 5 | --gpu=0 \ 6 | --load-path=pretrained_models/resnet18_PACS \ 7 | --save-path=test \ 8 | --do-train=False \ 9 | --dataset=PACS \ 10 | --loss-names=gem-t \ 11 | --TTAug \ 12 | --TTA-bs=3 \ 13 | --TTA-head=em \ 14 | --shuffled=True \ 15 | --eval=tta_ft \ 16 | --model=DomainAdaptor \ 17 | --backbone=resnet18 \ 18 | --batch-size=64 \ 19 | --exp-num=-2 \ 20 | --start-time=0 \ 21 | --times=5 \ 22 | 23 | -------------------------------------------------------------------------------- /script/TTA_jigsaw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python main.py \ 5 | --img-size=225 \ 6 | --jigsaw \ 7 | --dataset='MDN' \ 8 | --save-path='AdaBN/resnet18_jigsaw_MDN2' \ 9 | --gpu=0 \ 10 | --do-train=True \ 11 | --lr=1e-3 \ 12 | \ 13 | --TTA-head='jigsaw' \ 14 | --model='tta_model' \ 15 | --backbone='resnet18' \ 16 | --batch-size=128 \ 17 | --num-epoch=30 \ 18 | \ 19 | --exp-num=-2 \ 20 | --start-time=1 \ 21 | --times=4 \ 22 | --train='deepall' \ 23 | --eval='deepall' \ 24 | --loader='normal' -------------------------------------------------------------------------------- /script/TTA_rotation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # OfficeHome, PACS, VLCS, MiniDomainNet 3 | 4 | python main.py \ 5 | --rot \ 6 | --dataset='PACS' \ 7 | --save-path='pretrained_models/resnet18_rot_PACS' \ 8 | --gpu=0 \ 9 | 
--do-train=True \ 10 | --lr=1e-3 \ 11 | \ 12 | --model='tta_model' \ 13 | --TTA-head='rot' \ 14 | --batch-size=128 \ 15 | --num-epoch=30 \ 16 | \ 17 | --exp-num=-2 \ 18 | --start-time=1 \ 19 | --times=4 \ 20 | --fc-weight=10.0 \ 21 | --train='deepall' \ 22 | --eval='deepall' \ 23 | --loader='normal' \ 24 | -------------------------------------------------------------------------------- /script/deepall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | --dataset='PACS' \ 5 | --save-path='pretrained_models/resnet18_PACS' \ 6 | --gpu=0 \ 7 | --do-train=True \ 8 | --lr=1e-3 \ 9 | \ 10 | --model='erm' \ 11 | --backbone='resnet18' \ 12 | --batch-size=128 \ 13 | --num-epoch=30 \ 14 | \ 15 | --exp-num=-2 \ 16 | --start-time=0 \ 17 | --times=5 \ 18 | --train='deepall' \ 19 | --eval='deepall' \ 20 | --loader='normal' \ 21 | --eval-step=1 \ 22 | --scheduler='step' \ 23 | --lr-decay-gamma=0.1 \ 24 | -------------------------------------------------------------------------------- /script/meta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python main.py \ 5 | --dataset='PACS' \ 6 | --save-path='AdaBN/meta_norm_PACS' \ 7 | --gpu=0 \ 8 | --do-train=True \ 9 | --meta-lr=0.1 \ 10 | --lr=1e-3 \ 11 | \ 12 | --replace \ 13 | --meta-step=1 \ 14 | --meta-second-order=False \ 15 | --TTA-head='norm' \ 16 | --model='DomainAdaptor' \ 17 | --backbone='resnet50' \ 18 | --batch-size=64 \ 19 | --num-epoch=30 \ 20 | \ 21 | --exp-num -2 \ 22 | --start-time=0 \ 23 | --times=5 \ 24 | --fc-weight=10.0 \ 25 | --train='tta_meta' \ 26 | --eval='tta_meta' \ 27 | --loader='meta' \ 28 | -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | 7 
| __all__ = ['to', 'to_numpy', 'cat', 'zero_and_update', 'mkdir', 'Timer', 'AverageMeter', 'AverageMeterDict'] 8 | 9 | """ 10 | Tensor Utils 11 | """ 12 | 13 | 14 | def to(tensors, device, non_blocking=False): 15 | res = [] 16 | if isinstance(tensors, (list, tuple)): 17 | for t in tensors: 18 | res.append(to(t, device, non_blocking=non_blocking)) 19 | return res 20 | elif isinstance(tensors, (dict,)): 21 | res = {} 22 | for k, v in tensors.items(): 23 | res[k] = to(v, device, non_blocking=non_blocking) 24 | return res 25 | else: 26 | if isinstance(tensors, torch.Tensor): 27 | return tensors.to(device, non_blocking=non_blocking) 28 | else: 29 | return tensors 30 | 31 | 32 | def record_stream(tensors): 33 | if isinstance(tensors, (list, tuple)): 34 | for t in tensors: 35 | record_stream(t) 36 | elif isinstance(tensors, (dict,)): 37 | for k, v in tensors.items(): 38 | record_stream(v) 39 | else: 40 | if isinstance(tensors, torch.Tensor): 41 | tensors.record_stream(torch.cuda.current_stream()) 42 | 43 | 44 | def to_numpy(tensor): 45 | if tensor is None: 46 | return None 47 | elif isinstance(tensor, (tuple, list)): 48 | res = [] 49 | for t in tensor: 50 | res.append(to_numpy(t)) 51 | return res 52 | else: 53 | if isinstance(tensor, np.ndarray) or str(type(tensor))[8:13] == 'numpy': 54 | return tensor 55 | else: 56 | return tensor.detach().cpu().numpy() 57 | 58 | 59 | def cat(tensors, axis=0): 60 | res = [] 61 | for t in tensors: 62 | if t is None: 63 | res.append(None) 64 | elif isinstance(t[0], torch.Tensor): 65 | res.append(torch.cat(t, dim=axis)) 66 | else: 67 | res.append(np.concatenate(t, axis=axis)) 68 | return res 69 | 70 | 71 | def zero_and_update(optimizers, loss): 72 | if isinstance(optimizers, (list, tuple)): 73 | for optimizer in optimizers: 74 | optimizer.zero_grad() 75 | loss.backward() 76 | for opt in optimizers: 77 | opt.step() 78 | else: 79 | optimizers.zero_grad() 80 | loss.backward() 81 | optimizers.step() 82 | 83 | 84 | """ 85 | Output Utils 86 | """ 
class AverageMeter(object):
    """Computes and stores the average and current value.

    `update(val, n)` treats `val` as a total over `n` items: `avg` is
    sum(vals) / sum(ns). `min`/`max` are initialized but never maintained
    by `update` (kept for interface compatibility).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics; returns self for chaining."""
        self.val = 0      # most recent value passed to update()
        self.n = 0        # count associated with the most recent value
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.min = 1e8
        self.max = -1e8
        return self

    def update(self, val, n=1):
        """Accumulate `val` (a total over `n` items); returns self for chaining."""
        self.val = val
        self.n = n
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count
        return self


class AverageMeterDict(object):
    """A name -> AverageMeter mapping with dict-style batch updates."""

    def __init__(self):
        self.dict = {}

    def update(self, name, val, n=1):
        """Update (creating if needed) the meter registered under `name`."""
        if name not in self.dict:
            self.dict[name] = AverageMeter()
        self.dict[name].update(val, n)

    def update_dict(self, d, n=1):
        """Update every scalar entry of `d`; list/tuple values are skipped."""
        for name, val in d.items():
            if isinstance(val, (list, tuple)):
                continue
            self.update(name, val, n)

    def get_average_dicts(self):
        """Return {name: running average} over all updates."""
        return {k: v.avg for k, v in self.dict.items()}

    def get_current_dicts(self):
        """Return {name: last value / last count} (the most recent per-item value)."""
        return {k: v.val / v.n for k, v in self.dict.items()}

    def print(self, current=True, end='\n'):
        """Pretty-print either the current or the average values."""
        strs = []
        d = self.get_current_dicts() if current else self.get_average_dicts()
        for k, v in d.items():
            strs.append('{}: {:.4f}'.format(k, v))
        strs = '{' + ', '.join(strs) + '} '
        print(strs, end=end)


"""
Other Utils
"""


def mkdir(path, level=2, create_self=True):
    """ Make directory for this path,
    level is how many parent folders should be created.
    create_self is whether create path(if it is a file, it should not be created)

    e.g. : mkdir('/home/parent1/parent2/folder', level=3, create_self=True),
    it will first create parent1, then parent2, then folder.

    :param path: string
    :param level: int
    :param create_self: True or False
    :return:
    """
    p = Path(path)
    if create_self:
        paths = [p]
    else:
        paths = []
    level -= 1
    while level != 0:
        p = p.parent
        paths.append(p)
        level -= 1

    # Create outermost ancestors first so each mkdir has an existing parent.
    for p in paths[::-1]:
        p.mkdir(exist_ok=True)


class Timer(object):
    """Simple wall-clock timer, usable as a context manager.

    On `__exit__` it prints the formatted duration when `verbose` is set and
    the elapsed time exceeds `thresh` seconds.
    """

    def __init__(self, name='', thresh=0, verbose=True):
        self.start_time = time.time()
        self.verbose = verbose
        self.duration = 0
        self.thresh = thresh
        self.name = name

    def restart(self):
        """Reset the start time; returns the new start timestamp."""
        self.duration = self.start_time = time.time()
        return self.duration

    def stop(self):
        """Return seconds elapsed since start (does not store it)."""
        # BUG FIX: removed a stray `time.asctime()` call whose result was
        # discarded — it served no purpose.
        return time.time() - self.start_time

    def get_last_duration(self):
        return self.duration

    def get_formatted_duration(self, duration=None):
        """Format `duration` (default: last recorded) as 'h, m, s'."""
        def sec2time(seconds):
            s = seconds % 60
            seconds = seconds // 60
            m = seconds % 60
            seconds = seconds // 60
            # BUG FIX: hours were previously computed as `seconds % 60`,
            # silently wrapping any duration of 60+ hours.
            h = seconds
            return h, m, s

        if duration is None:
            duration = self.duration
        return '{} Time {:^.0f} h, {:^.0f} m, {:^.4f} s'.format(self.name, *sec2time(duration))

    def __enter__(self):
        self.restart()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.duration = self.stop()
        if self.verbose and self.duration > self.thresh:
            print(self.get_formatted_duration())