├── README.md
├── data
│   ├── JigsawLoader.py
│   ├── StandardDataset.py
│   ├── __init__.py
│   ├── concat_dataset.py
│   ├── data_helper.py
│   └── samplers.py
├── loss
│   └── KL_Loss.py
├── models
│   ├── FilterDropout.py
│   ├── LayerDiscriminator.py
│   ├── __init__.py
│   ├── model_factory.py
│   ├── model_utils.py
│   ├── resnet_domain.py
│   └── utils.py
├── optimizer
│   ├── __init__.py
│   └── optimizer_helper.py
├── train.sh
├── train_domain.py
└── utils
    ├── Logger.py
    ├── __init__.py
    ├── tf_logger.py
    ├── tools.py
    └── vis.py

/README.md:
--------------------------------------------------------------------------------
# DomainDrop: Suppressing Domain-Sensitive Channels for Domain Generalization

## Requirements

* Python == 3.7.3
* Pytorch == 1.8.1
* Cuda == 10.1
* Torchvision == 0.4.2
* Tensorflow == 1.14.0
* GPU == RTX 2080Ti

## DataSets
Please download the PACS dataset from [here](https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk?resourcekey=0-2fvpQY_QSyJf2uIECzqPuQ).
Make sure you use the official train/val/test split from the [PACS paper](https://openaccess.thecvf.com/content_iccv_2017/html/Li_Deeper_Broader_and_ICCV_2017_paper.html).
Taking `/data/DataSets/` as the saved directory, for example:
```
images -> /data/DataSets/PACS/kfold/art_painting/dog/pic_001.jpg, ...
splits -> /data/DataSets/PACS/pacs_label/art_painting_crossval_kfold.txt, ...
```
Then set `"data_root"` to `"/data/DataSets/"` and `"data"` to `"PACS"` in both `train_domain.py` and `train.sh`.

## Training
To train a model, set `"result_path"`, where the results will be saved, in both `train_domain.py` and `train.sh`.
Then simply run the following command to train a ResNet-18:
```
python train_domain.py --target [domain_index] --device [GPU_index]
```
Here `domain_index` denotes the index of the target domain, and `GPU_index` denotes the GPU device number.
```
domain_index: [0:'photo', 1:'art_painting', 2:'cartoon', 3:'sketch']
```
Or run `train.sh` directly.

## Evaluation

To evaluate the performance of the models, you can download the models trained on PACS below:

Target domain | Photo | Art | Cartoon | Sketch |
:----: | :----: | :----: | :----: | :----: |
Acc(%) | 96.71 | 84.91 | 80.72 | 84.32 |
models | [download](https://drive.google.com/drive/folders/1N63V8HxLXRl94GZgllQHTrxWrqH2-GDl?usp=sharing) | [download](https://drive.google.com/drive/folders/1zA9smbTRExm6FSu5WpfI0tmx93uonjuk?usp=sharing) | [download](https://drive.google.com/drive/folders/1jJW4q-aUVsNcUeiE8wKbv0zuzK5f3aJA?usp=sharing) | [download](https://drive.google.com/drive/folders/1x-33N1mtAJP08sT5dqZX53Y8B_8_Vify?usp=sharing) |

Please set `--eval 1` and set `--eval_model_path` to the saved path of the downloaded model, *e.g.*, `/trained/model/path/photo/model.pt`.
Then you can simply run:
```
python train_domain.py --target [domain_index] --device [GPU_index] --eval 1 --eval_model_path '/trained/model/path/photo/model.pt'
```

## Citations
```
@inproceedings{guo2023domaindrop,
  title={DomainDrop: Suppressing Domain-Sensitive Channels for Domain Generalization},
  author={Guo, Jintao and Qi, Lei and Shi, Yinghuan},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```

## Acknowledgement
Part of our code is derived from the following repository:
* [MMLD](https://github.com/mil-tokyo/dg_mmld): "Domain Generalization Using a Mixture of Multiple Latent Domains", AAAI 2020

We thank the authors for releasing their code. Please also consider citing their work.

--------------------------------------------------------------------------------
/data/JigsawLoader.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from random import sample, random
import sys
import os


def get_random_subset(names, labels, percent):
    """Randomly split (names, labels) into train and validation subsets.

    :param names: list of image paths
    :param labels: list of class labels
    :param percent: validation fraction, 0 < float < 1
    :return: name_train, name_val, labels_train, labels_val
    """
    samples = len(names)
    amount = int(samples * percent)
    random_index = sample(range(samples), amount)
    name_val = [names[k] for k in random_index]
    name_train = [v for k, v in enumerate(names) if k not in random_index]
    labels_val = [labels[k] for k in random_index]
    labels_train = [v for k, v in enumerate(labels) if k not in random_index]
    return name_train, name_val, labels_train, labels_val


def _dataset_info(txt_labels):
    # read from the official split txt
    file_names = []
    labels = []

    for row in open(txt_labels, 'r'):
        row = row.split(' ')
        file_names.append(row[0])
        labels.append(int(row[1]))

    return file_names, labels


def find_classes(dir_name):
    if sys.version_info >= (3, 5):
        # Faster and available in Python 3.5 and above
        classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
    else:
        classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
    classes.sort()
    # class indices start from 1; the datasets below subtract 1 when returning labels
    class_to_idx = {classes[i]: i + 1 for i in range(len(classes))}
    return classes, class_to_idx
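

# Note: each line of the official split txt files is expected to hold a
# relative image path and a 1-based class label separated by a single space,
# e.g. "art_painting/dog/pic_001.jpg 1". A minimal usage sketch (the file
# path below is an assumption following the layout described in the README):
#
#   names, labels = _dataset_info("/data/DataSets/PACS/pacs_label/art_painting_train_kfold.txt")
#   name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, percent=0.1)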


def get_split_domain_info_from_dir(domain_path, dataset_name=None, val_percentage=None, domain_label=None):
    # read from the directory
    domain_name = domain_path.split("/")[-1]
    if dataset_name == "VLCS":
        name_train, name_val, labels_train, labels_val = [], [], [], []
        classes, class_to_idx = find_classes(domain_path + "/full")
        # the "full" split serves as the training set
        for i, item in enumerate(classes):
            class_path = domain_path + "/" + "full" + "/" + item
            for root, _, fnames in sorted(os.walk(class_path)):
                for fname in sorted(fnames):
                    path = os.path.join(domain_name, "full", item, fname)
                    name_train.append(path)
                    labels_train.append(class_to_idx[item])
        # the "test" split serves as the validation set
        for i, item in enumerate(classes):
            class_path = domain_path + "/" + "test" + "/" + item
            for root, _, fnames in sorted(os.walk(class_path)):
                for fname in sorted(fnames):
                    path = os.path.join(domain_name, "test", item, fname)
                    name_val.append(path)
                    labels_val.append(class_to_idx[item])
        domain_label_train = [domain_label for i in range(len(labels_train))]
        domain_label_val = [domain_label for i in range(len(labels_val))]
        return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val

    elif dataset_name == "OfficeHome" or "PACS" in dataset_name:
        names, labels = [], []
        classes, class_to_idx = find_classes(domain_path)
        for i, item in enumerate(classes):
            class_path = domain_path + "/" + item
            for root, _, fnames in sorted(os.walk(class_path)):
                for fname in sorted(fnames):
                    path = os.path.join(domain_name, item, fname)
                    names.append(path)
                    labels.append(class_to_idx[item])
        name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, val_percentage)
        domain_label_train = [domain_label for i in range(len(labels_train))]
        domain_label_val = [domain_label for i in range(len(labels_val))]
        return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val

    else:
        raise ValueError("Unknown dataset: %s" % dataset_name)


def get_split_dataset_info_from_txt(txt_path, domain, domain_label, val_percentage=None):
    if "PACS" in txt_path:
        train_name = "_train_kfold.txt"
        val_name = "_crossval_kfold.txt"

        train_txt = txt_path + "/" + domain + train_name
        val_txt = txt_path + "/" + domain + val_name

        train_names, train_labels = _dataset_info(train_txt)
        val_names, val_labels = _dataset_info(val_txt)
        train_domain_labels = [domain_label for i in range(len(train_labels))]
        val_domain_labels = [domain_label for i in range(len(val_labels))]
        return train_names, val_names, train_labels, val_labels, train_domain_labels, val_domain_labels

    elif "miniDomainNet" in txt_path:
        # labels in these txt files start at 0; shift them to be 1-based like the other datasets
        train_name = "_train.txt"
        val_name = "_test.txt"
        train_txt = txt_path + "/" + domain + train_name
        val_txt = txt_path + "/" + domain + val_name

        train_names, train_labels = _dataset_info(train_txt)
        val_names, val_labels = _dataset_info(val_txt)
        train_labels = [label + 1 for label in train_labels]
        val_labels = [label + 1 for label in val_labels]

        names = train_names + val_names
        labels = train_labels + val_labels
        train_names, val_names, train_labels, val_labels = get_random_subset(names, labels, val_percentage)

        train_domain_labels = [domain_label for i in range(len(train_labels))]
        val_domain_labels = [domain_label for i in range(len(val_labels))]
        return train_names, val_names, train_labels, val_labels, train_domain_labels, val_domain_labels
    else:
        raise NotImplementedError


def get_split_dataset_info(txt_list, val_percentage):
    names, labels = _dataset_info(txt_list)
    return get_random_subset(names, labels, val_percentage)


class JigsawDataset(data.Dataset):
    def __init__(self, names, labels, jig_classes=100, img_transformer=None, tile_transformer=None, patches=True, bias_whole_image=None):
        self.data_path = ""
        self.names = names
        self.labels = labels

        self.N = len(self.names)
        self.permutations = self.__retrieve_permutations(jig_classes)
        self.grid_size = 3
        self.bias_whole_image = bias_whole_image
        if patches:
            self.patch_size = 64
        self._image_transformer = img_transformer
        self._augment_tile = tile_transformer
        if patches:
            self.returnFunc = lambda x: x
        else:
            def make_grid(x):
                return torchvision.utils.make_grid(x, self.grid_size, padding=0)
            self.returnFunc = make_grid

    def get_tile(self, img, n):
        w = float(img.size[0]) / self.grid_size
        y = int(n / self.grid_size)
        x = n % self.grid_size
        tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w])
        tile = self._augment_tile(tile)
        return tile

    def get_image(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img)

    def __getitem__(self, index):
        img = self.get_image(index)
        n_grids = self.grid_size ** 2
        tiles = [None] * n_grids
        for n in range(n_grids):
            tiles[n] = self.get_tile(img, n)

        order = np.random.randint(len(self.permutations) + 1)  # added 1 for class 0: unsorted
        if self.bias_whole_image:
            if self.bias_whole_image > random():
                order = 0
        if order == 0:
            data = tiles
        else:
            data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]

        data = torch.stack(data, 0)
        return self.returnFunc(data), int(order), int(self.labels[index])

    def __len__(self):
        return len(self.names)

    def __retrieve_permutations(self, classes):
        all_perm = np.load('permutations_%d.npy' % (classes))
        # from range [1,9] to [0,8]
        if all_perm.min() == 1:
            all_perm = all_perm - 1

        return all_perm
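

# A minimal sketch of the permutation labels above, with a toy permutation
# array standing in for the real "permutations_%d.npy" file: order 0 keeps
# the 9 tiles unshuffled, order k > 0 rearranges them with permutation k - 1,
# and k is used as the jigsaw classification target.
#
#   permutations = np.array([[2, 0, 1, 5, 3, 4, 8, 6, 7]])
#   tiles = list(range(9))  # stand-ins for image tiles
#   order = 1
#   data = [tiles[permutations[order - 1][t]] for t in range(9)]
#   # data == [2, 0, 1, 5, 3, 4, 8, 6, 7]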


class JigsawTestDataset(JigsawDataset):
    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)

    def __getitem__(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img), 0, int(self.labels[index])


class JigsawTestDatasetMultiple(JigsawDataset):
    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)
        self._image_transformer = transforms.Compose([
            transforms.Resize(255, Image.BILINEAR),
        ])
        self._image_transformer_full = transforms.Compose([
            transforms.Resize(225, Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self._augment_tile = transforms.Compose([
            transforms.Resize((75, 75), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        framename = self.data_path + '/' + self.names[index]
        _img = Image.open(framename).convert('RGB')
        img = self._image_transformer(_img)

        w = float(img.size[0]) / self.grid_size
        n_grids = self.grid_size ** 2
        images = []
        jig_labels = []
        tiles = [None] * n_grids
        for n in range(n_grids):
            y = int(n / self.grid_size)
            x = n % self.grid_size
            tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w])
            tile = self._augment_tile(tile)
            tiles[n] = tile
        for order in range(0, len(self.permutations) + 1, 3):
            if order == 0:
                data = tiles
            else:
                data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
            data = self.returnFunc(torch.stack(data, 0))
            images.append(data)
            jig_labels.append(order)
        images = torch.stack(images, 0)
        jig_labels = torch.LongTensor(jig_labels)
        return images, jig_labels, int(self.labels[index])


class JigsawNewDataset(data.Dataset):
    def __init__(self, names, labels, domain_labels, dataset_path, jig_classes=100, img_transformer=None,
                 tile_transformer=None, patches=True, bias_whole_image=None):
        self.data_path = dataset_path

        self.names = names
        self.labels = labels
        self.domain_labels = domain_labels

        self.domain = domain_labels[0] - 1

        self.N = len(self.names)
        self.grid_size = 3
        self.bias_whole_image = bias_whole_image
        if patches:
            self.patch_size = 64
        self._image_transformer = img_transformer
        self._augment_tile = tile_transformer
        if patches:
            self.returnFunc = lambda x: x
        else:
            def make_grid(x):
                return torchvision.utils.make_grid(x, self.grid_size, padding=0)

            self.returnFunc = make_grid

    def get_tile(self, img, n):
        w = float(img.size[0]) / self.grid_size
        y = int(n / self.grid_size)
        x = n % self.grid_size
        tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w])
        tile = self._augment_tile(tile)
        return tile

    def get_image(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img)

    def __getitem__(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img), int(self.labels[index] - 1), int(self.domain_labels[index] - 1)

    def __len__(self):
        return len(self.names)

    def __retrieve_permutations(self, classes):
        all_perm = np.load('permutations_%d.npy' % (classes))
        # from range [1,9] to [0,8]
        if all_perm.min() == 1:
            all_perm = all_perm - 1
        return all_perm


class JigsawTestNewDataset(JigsawNewDataset):
    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)

    def __getitem__(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img), int(self.labels[index] - 1), int(self.domain_labels[index] - 1)
--------------------------------------------------------------------------------
/data/StandardDataset.py:
--------------------------------------------------------------------------------
from torchvision import datasets
from torchvision import transforms


def get_dataset(path, mode, image_size):
    if mode == "train":
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ImageNet mean; note the non-standard std (the usual choice is [0.229, 0.224, 0.225])
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1/256., 1/256., 1/256.])
        ])
    else:
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            # transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], std=[1/256., 1/256., 1/256.])
        ])
    return datasets.ImageFolder(path, transform=img_transform)
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/data/__init__.py
--------------------------------------------------------------------------------
/data/concat_dataset.py:
--------------------------------------------------------------------------------
import bisect
import warnings

from torch.utils.data import Dataset

# This is a small variant of the ConcatDataset class, which also returns the dataset index
from data.JigsawLoader import JigsawTestDatasetMultiple


class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def isMulti(self):
        return isinstance(self.datasets[0], JigsawTestDatasetMultiple)

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx], dataset_idx

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
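

if __name__ == "__main__":
    # Minimal sketch (not part of the original file; assumes it is run from
    # the repository root): each item comes back as (sample, dataset_idx), so
    # downstream code can recover which source dataset an example came from.
    import torch
    from torch.utils.data import TensorDataset

    d1 = TensorDataset(torch.zeros(4, 3))
    d2 = TensorDataset(torch.ones(6, 3))
    concat = ConcatDataset([d1, d2])
    print(len(concat))               # 10
    sample, dataset_idx = concat[7]  # index 7 falls into the second dataset
    print(dataset_idx)               # 1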
--------------------------------------------------------------------------------
/data/data_helper.py:
--------------------------------------------------------------------------------
from os.path import join

import torch
from torchvision import transforms

from data.JigsawLoader import *
from data.concat_dataset import ConcatDataset

from data.JigsawLoader import JigsawNewDataset, JigsawTestNewDataset
from data.JigsawLoader import get_split_domain_info_from_dir, get_split_dataset_info_from_txt, _dataset_info

vlcs_datasets = ["CALTECH", "LABELME", "PASCAL", "SUN"]
pacs_datasets = ["art_painting", "cartoon", "photo", "sketch"]
officehome_datasets = ['Art', 'Clipart', 'Product', 'RealWorld']
available_datasets = officehome_datasets + pacs_datasets + vlcs_datasets


class Subset(torch.utils.data.Dataset):
    def __init__(self, dataset, limit):
        indices = torch.randperm(len(dataset))[:limit]
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    img_transformer_val = get_val_transformer(args)
    limit = args.limit_source

    if "PACS" in args.data_root:
        dataset_path = join(args.data_root, "kfold")
    elif args.data == "miniDomainNet":
        dataset_path = "/data/DataSets/" + "DomainNet"
    else:
        dataset_path = args.data_root

    for i, dname in enumerate(dataset_list):
        if args.data == "PACS":
            name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \
                get_split_dataset_info_from_txt(txt_path=join(args.data_root, "pacs_label"), domain=dname,
                                                domain_label=i + 1)

        elif args.data == "miniDomainNet":
            name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \
                get_split_dataset_info_from_txt(txt_path=args.data_root, domain=dname, domain_label=i + 1,
                                                val_percentage=args.val_size)
        else:
            name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \
                get_split_domain_info_from_dir(join(dataset_path, dname), dataset_name=args.data,
                                               val_percentage=args.val_size, domain_label=i + 1)

        train_dataset = JigsawNewDataset(name_train, labels_train, domain_labels_train,
                                         dataset_path=dataset_path, patches=patches,
                                         img_transformer=img_transformer, tile_transformer=tile_transformer,
                                         jig_classes=30, bias_whole_image=args.bias_whole_image)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(
            JigsawTestNewDataset(name_val, labels_val, domain_labels_val, dataset_path=dataset_path,
                                 img_transformer=img_transformer_val, patches=patches, jig_classes=30))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)

    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                         pin_memory=True, drop_last=False)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                             pin_memory=True, drop_last=False)
    return loader, val_loader


def get_val_dataloader(args, patches=False, tSNE_flag=0):
    if "PACS" in args.data_root:
        dataset_path = join(args.data_root, "kfold")
    elif args.data == "miniDomainNet":
        dataset_path = "/data/DataSets/" + "DomainNet"
    else:
        dataset_path = args.data_root

    if args.data == "miniDomainNet":
        name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = \
            get_split_dataset_info_from_txt(txt_path=args.data_root, domain=args.target, domain_label=0,
                                            val_percentage=args.val_size)
    else:
        name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = get_split_domain_info_from_dir(
            join(dataset_path, args.target), dataset_name=args.data, val_percentage=args.val_size, domain_label=0)

    if tSNE_flag == 0:
        names = name_train + name_val
        labels = labels_train + labels_val
        domain_label = domain_label_train + domain_label_val
    else:
        names = name_val
        labels = labels_val
        domain_label = domain_label_val

    img_tr = get_val_transformer(args)
    dataset_list = args.source
    val_dataset = JigsawTestNewDataset(names, labels, domain_label, dataset_path=dataset_path, patches=patches,
                                       img_transformer=img_tr, jig_classes=30)

    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)

    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                         pin_memory=True, drop_last=False)
    return loader


def get_train_transformers(args):
    img_tr = [transforms.RandomResizedCrop((int(args.image_size), int(args.image_size)), (args.min_scale, args.max_scale))]
    if args.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.random_horiz_flip))
    if args.jitter > 0.0:
        img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter)))

    # this is a special operation for JigenDG
    if args.gray_flag:
        img_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale))

    img_tr.append(transforms.ToTensor())
    img_tr.append(transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

    tile_tr = []
    if args.tile_random_grayscale:
        tile_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale))
    tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

    return transforms.Compose(img_tr), transforms.Compose(tile_tr)


def get_val_transformer(args):
    img_tr = [
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return transforms.Compose(img_tr)
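

if __name__ == "__main__":
    # Minimal sketch (not part of the original file): the concrete paths and
    # values below are assumptions mirroring train.sh; `data_root` here
    # includes the PACS folder so that the `"PACS" in args.data_root` checks
    # above resolve to the kfold/pacs_label layout described in the README.
    from argparse import Namespace

    args = Namespace(
        data="PACS", data_root="/data/DataSets/PACS",
        source=["art_painting", "cartoon", "sketch"], target="photo",
        batch_size=64, image_size=224, val_size=0.1,
        min_scale=0.8, max_scale=1.0, random_horiz_flip=0.5, jitter=0.4,
        gray_flag=1, tile_random_grayscale=0.1, bias_whole_image=0.9,
        limit_source=None, limit_target=None,
    )
    train_loader, val_loader = get_train_dataloader(args, patches=False)
    test_loader = get_val_dataloader(args, patches=False)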
--------------------------------------------------------------------------------
/data/samplers.py:
--------------------------------------------------------------------------------
import copy
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
import math


class BatchSchedulerSampler(Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.mini_batch_size = int(batch_size / self.number_of_datasets)
        # self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets])
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])

    def __len__(self):
        return self.mini_batch_size * math.ceil(self.largest_dataset_size / self.mini_batch_size) * len(self.dataset.datasets)

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size
        # we want to see all samples of the combined dataset, which forces us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(self.mini_batch_size):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # got to the end of the iterator - restart it and continue to get samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
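

# Minimal usage sketch (hypothetical datasets): with batch_size=6 and three
# concatenated datasets, every mini-batch interleaves 2 indices per dataset,
# resampling the smaller datasets until the largest one is exhausted.
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
#   sets = ConcatDataset([TensorDataset(torch.full((n, 1), float(i)))
#                         for i, n in enumerate([10, 20, 30])])
#   loader = DataLoader(sets, batch_size=6,
#                       sampler=BatchSchedulerSampler(sets, batch_size=6))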


class RandomDomainSampler(Sampler):
    """Randomly samples N domains each with K images
    to form a minibatch of size N*K.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_domain (int): number of domains to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())

        # Make sure each domain has an equal number of images
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes the number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)

            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class SeqDomainSampler(Sampler):
    """Sequential domain sampler, which randomly samples K
    images from each domain to form a minibatch.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())
        self.domains.sort()

        # Make sure each domain has an equal number of images
        n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes the number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            for domain in self.domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class RandomClassSampler(Sampler):
    """Randomly samples N classes each with K instances to
    form a minibatch of size N*K.
    Modified from https://github.com/KaiyangZhou/deep-person-reid.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_ins (int): number of instances per class to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_ins):
        if batch_size < n_ins:
            raise ValueError(
                "batch_size={} must be no less "
                "than n_ins={}".format(batch_size, n_ins)
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.n_ins = n_ins
        self.ncls_per_batch = self.batch_size // self.n_ins
        self.index_dic = defaultdict(list)
        for index, item in enumerate(data_source):
            self.index_dic[item.label].append(index)
        self.labels = list(self.index_dic.keys())
        assert len(self.labels) >= self.ncls_per_batch

        # estimate the number of images in an epoch
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        batch_idxs_dict = defaultdict(list)

        for label in self.labels:
            idxs = copy.deepcopy(self.index_dic[label])
            if len(idxs) < self.n_ins:
                idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.n_ins:
                    batch_idxs_dict[label].append(batch_idxs)
                    batch_idxs = []

        avai_labels = copy.deepcopy(self.labels)
        final_idxs = []

        while len(avai_labels) >= self.ncls_per_batch:
            selected_labels = random.sample(avai_labels, self.ncls_per_batch)
            for label in selected_labels:
                batch_idxs = batch_idxs_dict[label].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[label]) == 0:
                    avai_labels.remove(label)

        return iter(final_idxs)

    def __len__(self):
        return self.length


def build_sampler(
    sampler_type,
    cfg=None,
    data_source=None,
    batch_size=32,
    n_domain=0,
    n_ins=16
):
    if sampler_type == "RandomSampler":
        return RandomSampler(data_source)

    elif sampler_type == "SequentialSampler":
        return SequentialSampler(data_source)

    elif sampler_type == "RandomDomainSampler":
        return RandomDomainSampler(data_source, batch_size, n_domain)

    elif sampler_type == "SeqDomainSampler":
        return SeqDomainSampler(data_source, batch_size)

    elif sampler_type == "RandomClassSampler":
        return RandomClassSampler(data_source, batch_size, n_ins)

    else:
        raise ValueError("Unknown sampler type: {}".format(sampler_type))
--------------------------------------------------------------------------------
/loss/KL_Loss.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F


def compute_kl_loss(p, q, pad_mask=None, T=10):
    p_T = p / T
    q_T = q / T
    p_loss = F.kl_div(F.log_softmax(p_T, dim=-1), F.softmax(q_T, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q_T, dim=-1), F.softmax(p_T, dim=-1), reduction='none')

    # pad_mask is for seq-level tasks
    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)

    # choose between "sum" and "mean" reduction depending on your task
    # p_loss = p_loss.sum()
    # q_loss = q_loss.sum()

    p_loss = p_loss.mean()
    q_loss = q_loss.mean()

    loss = (p_loss + q_loss) / 2
    return loss
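

if __name__ == "__main__":
    # Minimal sketch (not part of the original file): p and q are two sets of
    # logits for the same batch, e.g. from two stochastic forward passes; the
    # loss is the symmetric, temperature-smoothed KL divergence between them.
    import torch

    torch.manual_seed(0)
    p = torch.randn(4, 7)              # logits from forward pass 1 (B x num_classes)
    q = p + 0.1 * torch.randn(4, 7)    # slightly perturbed logits from pass 2
    loss = compute_kl_loss(p, q, T=5)  # T=5 matches the train.sh default
    print(loss.item())                 # small positive value; 0 iff p == q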
--------------------------------------------------------------------------------
/models/FilterDropout.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import random


def mask_selection(scores, percent, wrs_flag):
    # input scores: BxN
    batch_size = scores.shape[0]
    num_neurons = scores.shape[1]
    drop_num = int(num_neurons * percent)

    if wrs_flag == 0:
        # deterministically select the top-`drop_num` scoring neurons
        threshold = torch.sort(scores, dim=1, descending=True)[0][:, drop_num]
        threshold_expand = threshold.view(batch_size, 1).expand(batch_size, num_neurons)
        mask_filters = torch.where(scores > threshold_expand, torch.tensor(1.).cuda(), torch.tensor(0.).cuda())
    else:
        # weighted random selection (WRS): a higher score means a higher selection probability
        score_max = scores.max(dim=1, keepdim=True)[0]
        score_min = scores.min(dim=1, keepdim=True)[0]
        scores = (scores - score_min) / (score_max - score_min)

        r = torch.rand(scores.shape).cuda()  # BxC
        key = r.pow(1. / scores)
        threshold = torch.sort(key, dim=1, descending=True)[0][:, drop_num]
        threshold_expand = threshold.view(batch_size, 1).expand(batch_size, num_neurons)
        mask_filters = torch.where(key > threshold_expand, torch.tensor(1.).cuda(), torch.tensor(0.).cuda())

    mask_filters = 1 - mask_filters  # BxN, 0 marks a dropped neuron
    return mask_filters


def filter_dropout_channel(scores, percent, wrs_flag):
    # scores: BxCxHxW
    batch_size, channel_num, H, W = scores.shape[0], scores.shape[1], scores.shape[2], scores.shape[3]
    channel_scores = nn.AdaptiveAvgPool2d((1, 1))(scores).view(batch_size, channel_num)
    # channel_scores = channel_scores / channel_scores.sum(dim=1, keepdim=True)
    mask = mask_selection(channel_scores, percent, wrs_flag)  # BxC
    mask_filters = mask.view(batch_size, channel_num, 1, 1)
    return mask_filters
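

if __name__ == "__main__":
    # Minimal sketch (not part of the original file; needs a GPU, since the
    # functions above call .cuda()): scores mimic a B x C x H x W map of
    # domain sensitivity. With percent=0.33 and C=8, exactly 2 channels per
    # sample are zeroed; wrs_flag=1 samples them via weighted random
    # selection instead of a hard top-k.
    torch.manual_seed(0)
    scores = torch.rand(2, 8, 4, 4).cuda()
    mask = filter_dropout_channel(scores, percent=0.33, wrs_flag=1)  # 2 x 8 x 1 x 1
    print(mask.view(2, 8))                  # 0 = dropped channel, 1 = kept
    print((mask == 0).sum(dim=1).view(-1))  # 2 dropped channels per sample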
--------------------------------------------------------------------------------
/models/LayerDiscriminator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .FilterDropout import filter_dropout_channel


class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd, reverse=True):
        ctx.lambd = lambd
        ctx.reverse = reverse
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.reverse:
            return (grad_output * -ctx.lambd), None, None
        else:
            return (grad_output * ctx.lambd), None, None


def grad_reverse(x, lambd=1.0, reverse=True):
    return GradReverse.apply(x, lambd, reverse)


class LayerDiscriminator(nn.Module):
    def __init__(self, num_channels, num_classes, grl=True, reverse=True, lambd=0.0, wrs_flag=1):
        super(LayerDiscriminator, self).__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model = nn.Linear(num_channels, num_classes)
        self.softmax = nn.Softmax(0)
        self.num_channels = num_channels

        self.grl = grl
        self.reverse = reverse
        self.lambd = lambd

        self.wrs_flag = wrs_flag

    def scores_dropout(self, scores, percent):
        mask_filters = filter_dropout_channel(scores=scores, percent=percent, wrs_flag=self.wrs_flag)
        mask_filters = mask_filters.cuda()  # BxCx1x1
        return mask_filters

    def norm_scores(self, scores):
        score_max = scores.max(dim=1, keepdim=True)[0]
        score_min = scores.min(dim=1, keepdim=True)[0]
        scores_norm = (scores - score_min) / (score_max - score_min)
        return scores_norm

    def get_scores(self, feature, labels, percent=0.33):
        weights = self.model.weight.clone().detach()  # num_domains x C
        domain_num, channel_num = weights.shape[0], weights.shape[1]
        batch_size, _, H, W = feature.shape[0], feature.shape[1], feature.shape[2], feature.shape[3]

        weight = weights[labels].view(batch_size, channel_num, 1).expand(batch_size, channel_num, H * W)\
            .view(batch_size, channel_num, H, W)

        right_score = torch.mul(feature, weight)
        right_score = self.norm_scores(right_score)

        # right_score_masks: BxCx1x1, broadcast over H and W
        right_score_masks = self.scores_dropout(right_score, percent=percent)
        return right_score_masks

    def forward(self, x, labels, percent=0.33):
        if self.grl:
            x = grad_reverse(x, self.lambd, self.reverse)

        feature = x.clone().detach()  # BxCxHxW
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # BxC
        y = self.model(x)

        # Compute the 0-1 mask that marks the channels carrying domain-related information.
        # mask_filters: {0 / 1}, BxCx1x1, broadcast over H and W
        mask_filters = self.get_scores(feature, labels, percent=percent)
        return y, mask_filters
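

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original file; needs a GPU
    # and should be run from the repo root, e.g. `python -m models.LayerDiscriminator`):
    # the discriminator predicts which of 3 source domains a feature map
    # comes from and returns a 0/1 channel mask that broadcasts over H x W.
    disc = LayerDiscriminator(num_channels=64, num_classes=3,
                              grl=True, reverse=True, lambd=0.25, wrs_flag=1).cuda()
    feats = torch.randn(8, 64, 28, 28).cuda()
    domain_labels = torch.randint(0, 3, (8,)).cuda()
    logits, mask = disc(feats, domain_labels, percent=0.33)
    print(logits.shape, mask.shape)  # torch.Size([8, 3]) torch.Size([8, 64, 1, 1])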
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/models/__init__.py
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
from models import resnet_domain as resnet

nets_map = {
    # 'alexnet': alexnet.alexnet,
    'resnet18': resnet.resnet18,
    'resnet50': resnet.resnet50,
}


def get_network(name):
    if name not in nets_map:
        raise ValueError('Name of network unknown %s' % name)

    def get_network_fn(**kwargs):
        return nets_map[name](**kwargs)

    return get_network_fn
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
from torch.autograd import Function


class GradientKillerLayer(Function):
    @staticmethod
    def forward(ctx, x, **kwargs):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return None, None


class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.lambda_val

        return output, None
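

if __name__ == "__main__":
    # Minimal sketch of gradient reversal (not part of the original file):
    # forward is the identity, backward multiplies the incoming gradient by
    # -lambda, which is what lets a domain discriminator train adversarially
    # against the feature extractor.
    import torch

    x = torch.ones(3, requires_grad=True)
    y = ReverseLayerF.apply(x, 0.25)
    y.sum().backward()
    print(x.grad)  # tensor([-0.2500, -0.2500, -0.2500])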
--------------------------------------------------------------------------------
/models/resnet_domain.py:
--------------------------------------------------------------------------------
from torch.utils import model_zoo
from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck
from torch import nn as nn
from .LayerDiscriminator import LayerDiscriminator
import random


class ResNet(nn.Module):
    def __init__(self, block, layers,
                 device,
                 classes=100,
                 domains=3,
                 network='resnet18',
                 domain_discriminator_flag=0,
                 grl=0,
                 lambd=0.,
                 drop_percent=0.33,
                 dropout_mode=0,
                 wrs_flag=0,
                 recover_flag=0,
                 layer_wise_flag=0,
                 ):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(512 * block.expansion, classes)

        if network == "resnet18":
            layer_channels = [64, 64, 128, 256, 512]
        else:
            layer_channels = [64, 256, 512, 1024, 2048]

        self.device = device
        self.domain_discriminator_flag = domain_discriminator_flag
        self.drop_percent = drop_percent
        self.dropout_mode = dropout_mode

        self.recover_flag = recover_flag
        self.layer_wise_flag = layer_wise_flag

        self.domain_discriminators = nn.ModuleList([
            LayerDiscriminator(
                num_channels=layer_channels[layer],
                num_classes=domains,
                grl=grl,
                reverse=True,
                lambd=lambd,
                wrs_flag=wrs_flag,
            )
            for i, layer in enumerate([0, 1, 2, 3, 4])])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def is_patch_based(self):
        return False

    def perform_dropout(self, feature, domain_labels, layer_index, layer_dropout_flag):
        domain_output = None
        if self.domain_discriminator_flag and self.training:
            index = layer_index
            percent = self.drop_percent
            domain_output, domain_mask = self.domain_discriminators[index](
                feature.clone(),
                domain_labels,
                percent=percent,
            )
            if self.recover_flag:
                domain_mask = domain_mask * domain_mask.numel() / domain_mask.sum()
            if layer_dropout_flag:
                feature = feature * domain_mask
        return feature, domain_output

    def forward(self, x, domain_labels=None, layer_drop_flag=None):
        if layer_drop_flag is None:
            layer_drop_flag = [0, 0, 0, 0]  # no stage applies dropout, e.g. at eval time
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        domain_outputs = []
        for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            x = layer(x)
            x, domain_output = self.perform_dropout(x, domain_labels, layer_index=i + 1,
                                                    layer_dropout_flag=layer_drop_flag[i])
            if domain_output is not None:
                domain_outputs.append(domain_output)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # B x C
        y = self.classifier(x)
        return y, domain_outputs


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model
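

if __name__ == "__main__":
    # Minimal sketch mirroring how train_domain.py builds the model (flag
    # values follow train.sh; pretrained=False only to skip the weight
    # download here; needs a GPU). layer_drop_flag selects which residual
    # stages apply channel dropout in this forward pass.
    import torch

    model = resnet18(pretrained=False, device="cuda", classes=7, domains=3,
                     network="resnet18", domain_discriminator_flag=1, grl=1,
                     lambd=0.25, drop_percent=0.33, wrs_flag=1, recover_flag=1).cuda()
    model.train()
    images = torch.randn(8, 3, 224, 224).cuda()
    domain_labels = torch.randint(0, 3, (8,)).cuda()
    logits, domain_logits = model(images, domain_labels, layer_drop_flag=[0, 1, 0, 0])
    print(logits.shape, len(domain_logits))  # torch.Size([8, 7]) 4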
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
import contextlib
import torch


def set_requires_grad(model, requires_grad):
    for param in model.parameters():
        param.requires_grad = requires_grad


def simple_transform(x, beta):
    x = 1 / torch.pow(torch.log(1/x + 1), beta)
    return x


@contextlib.contextmanager
def disable_tracking_bn_stats(model):
    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(switch_attr)
    yield
    model.apply(switch_attr)
--------------------------------------------------------------------------------
/optimizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/optimizer/__init__.py
--------------------------------------------------------------------------------
/optimizer/optimizer_helper.py:
--------------------------------------------------------------------------------
from torch import optim


def get_optim_and_scheduler(model, network, epochs, lr, train_all=True, nesterov=False):
    if train_all:
        params = model.parameters()
    else:
        params = model.get_params(lr)
    optimizer = optim.SGD(params, weight_decay=.0005, momentum=.9, nesterov=nesterov, lr=lr)
    step_size = int(epochs * .8)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
    print("Step size: %d" % step_size)
    return optimizer, scheduler


def get_optim_and_scheduler_style(style_net, epochs, lr, nesterov=False, step_radio=0.8):
    optimizer = optim.SGD(style_net, weight_decay=.0005, momentum=.9, nesterov=nesterov, lr=lr)
    step_size = int(epochs * step_radio)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)
    print("Step size: %d for style net" % step_size)
    return optimizer, scheduler


def get_optim_and_scheduler_layer_joint(style_net, epochs, lr, train_all=None, nesterov=False):
    optimizer = optim.SGD(style_net, weight_decay=.0005, momentum=.9, nesterov=nesterov, lr=lr)
    step_size = int(epochs * 1.)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)
    print("Step size: %d for style net" % step_size)
    return optimizer, scheduler


def get_model_lr(name, model, fc_weight=1.0):
    if 'resnet' in name:
        return [
            (model.conv1, 1.0),                     # 0
            (model.bn1, 1.0),                       # 1
            (model.layer1, 1.0),                    # 2
            (model.domain_discriminators[1], 1.0),  # 3
            (model.layer2, 1.0),                    # 4
            (model.domain_discriminators[2], 1.0),  # 5
            (model.layer3, 1.0),                    # 6
            (model.domain_discriminators[3], 1.0),  # 7
            (model.layer4, 1.0),                    # 8
            (model.domain_discriminators[4], 1.0),  # 9
            (model.classifier, 1.0 * fc_weight)     # 10
        ]
    elif name == 'alexnet':
        return [
            (model.layer0, 1.0),          # 0
            (model.layer1, 1.0),          # 1
            (model.layer2, 1.0),          # 2
            (model.feature_layers, 1.0),  # 3
            (model.fc, 1.0 * fc_weight),  # 4
        ]
    else:
        raise NotImplementedError


def get_optimizer(model, init_lr, momentum=.9, weight_decay=.0005, nesterov=False):
    optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum, weight_decay=weight_decay,
                          nesterov=nesterov)
    return optimizer


def get_optim_and_scheduler_scatter(model, network, epochs, lr, momentum=.9, weight_decay=.0005, nesterov=False, step_radio=0.8):
    model_lr = get_model_lr(name=network, model=model, fc_weight=1.0)
    optimizers = [get_optimizer(model_part, lr * alpha, momentum, weight_decay, nesterov)
                  for model_part, alpha in model_lr]
    step_size = int(epochs * step_radio)
    schedulers = [optim.lr_scheduler.StepLR(opt, step_size=step_size) for opt in optimizers]
    return optimizers, schedulers
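

if __name__ == "__main__":
    # Minimal sketch (not part of the original file; run from the repo root):
    # one SGD optimizer per model part, stepped together every iteration,
    # with per-part StepLR schedulers stepped once per epoch.
    from models.resnet_domain import resnet18

    model = resnet18(pretrained=False, device="cpu", classes=7, domains=3)
    optimizers, schedulers = get_optim_and_scheduler_scatter(
        model, network="resnet18", epochs=50, lr=0.002)
    # inside the training loop:
    #     for opt in optimizers: opt.zero_grad()
    #     loss.backward()
    #     for opt in optimizers: opt.step()
    # once per epoch:
    #     for sch in schedulers: sch.step()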
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash

device=0
data='PACS'
network='resnet18'

for t in `seq 0 4`
do
  for domain in `seq 0 3`
  do
    python train_domain.py \
      --target $domain \
      --device $device \
      --network $network \
      --time $t \
      --batch_size 64 \
      --data $data \
      --data_root "/data/DataSets/" \
      --result_path "/data/save/models/" \
      --KL_Loss 1 \
      --KL_Loss_weight 1.5 \
      --KL_Loss_T 5 \
      --layer_wise_prob 0.8 \
      --domain_discriminator_flag 1 \
      --domain_loss_flag 1 \
      --discriminator_layers 1 2 3 4 \
      --grl 1 \
      --lambd 0.25 \
      --drop_percent 0.33 \
      --recover_flag 1 \
      --epochs 50 \
      --learning_rate 0.002
  done
done
--------------------------------------------------------------------------------
/train_domain.py:
--------------------------------------------------------------------------------
# from torch.utils.tensorboard import SummaryWriter
import argparse
# import torch
from torch import nn
from data import data_helper
from models import model_factory
from optimizer.optimizer_helper import get_optim_and_scheduler, get_optim_and_scheduler_scatter
from utils.Logger import Logger
from models.resnet_domain import resnet18, resnet50
import os
import random
import time
from utils.tools import *
from loss.KL_Loss import compute_kl_loss
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | device=0
4 | data='PACS'
5 | network='resnet18'
6 | 
7 | for t in `seq 0 4`
8 | do
9 |   for domain in `seq 0 3`
10 |   do
11 |     python train_domain.py \
12 |       --target $domain \
13 |       --device $device \
14 |       --network $network \
15 |       --time $t \
16 |       --batch_size 64 \
17 |       --data $data \
18 |       --data_root "/data/DataSets/" \
19 |       --result_path "/data/save/models/" \
20 |       --KL_Loss 1 \
21 |       --KL_Loss_weight 1.5 \
22 |       --KL_Loss_T 5 \
23 |       --layer_wise_prob 0.8 \
24 |       --domain_discriminator_flag 1 \
25 |       --domain_loss_flag 1 \
26 |       --discriminator_layers 1 2 3 4 \
27 |       --grl 1 \
28 |       --lambd 0.25 \
29 |       --drop_percent 0.33 \
30 |       --recover_flag 1 \
31 |       --epochs 50 \
32 |       --learning_rate 0.002
33 |   done
34 | done
35 | 
--------------------------------------------------------------------------------
/train_domain.py:
--------------------------------------------------------------------------------
1 | # from torch.utils.tensorboard import SummaryWriter
2 | import argparse
3 | import torch
4 | from torch import nn
5 | from data import data_helper
6 | from models import model_factory
7 | from optimizer.optimizer_helper import get_optim_and_scheduler, get_optim_and_scheduler_scatter
8 | from utils.Logger import Logger
9 | from models.resnet_domain import resnet18, resnet50
10 | import os
11 | import random
12 | import time
13 | from utils.tools import *  # also provides np (numpy) via star import
14 | from loss.KL_Loss import compute_kl_loss
15 | 
16 | 
17 | def get_args():
18 |     parser = argparse.ArgumentParser(description="Script to launch DomainDrop training",
19 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20 |     parser.add_argument("--target", default=0, type=int, help="Index of the target domain")
21 |     parser.add_argument("--device", type=int, default=0, help="GPU num")
22 |     parser.add_argument("--time", default=0, type=int, help="Index of the training run; also used as the random seed")
23 | 
24 |     parser.add_argument("--eval", default=0, type=int, help="Eval trained models")
25 |     parser.add_argument("--eval_model_path", default="/model/path", help="Path of trained models")
26 | 
27 |     parser.add_argument("--batch_size", "-b", type=int, default=64, help="Batch size")
28 |     parser.add_argument("--image_size", type=int, default=224, help="Image size")
29 | 
30 |     parser.add_argument("--data", default="PACS")
31 |     parser.add_argument("--data_root", default="/data/DataSets/")
32 | 
33 |     parser.add_argument("--KL_Loss", default=1, type=int, help="whether to use the dropout-consistency (KL) loss")
34 |     parser.add_argument("--KL_Loss_weight", default=1.5, type=float, help="weight of KL_Loss")
35 |     parser.add_argument("--KL_Loss_T", default=5, type=float, help="temperature T of KL_Loss")
36 | 
37 |     parser.add_argument("--layer_wise_prob", default=0.8, type=float, help="probability of applying layer-wise dropout")
38 | 
39 |     parser.add_argument("--domain_discriminator_flag", default=1, type=int, help="whether to use the domain discriminators.")
40 |     parser.add_argument("--domain_loss_flag", default=1, type=int, help="whether to use the domain loss.")
41 |     parser.add_argument("--discriminator_layers", default=[1, 2, 3, 4], nargs="+", type=int, help="where to place discriminators")
42 |     parser.add_argument("--grl", default=1, type=int, help="whether to use the gradient reversal layer (GRL)")
43 |     parser.add_argument("--lambd", default=0.25, type=float, help="weight of the GRL")
44 | 
45 |     parser.add_argument("--drop_percent", default=0.33, type=float, help="percent of dropped filters")
46 |     parser.add_argument("--filter_WRS_flag", default=1, type=int, help="whether to use Weighted Random Selection of channels.")
47 |     parser.add_argument("--recover_flag", default=1, type=int)
48 |     parser.add_argument("--result_path", default="/data/DomainDropout_results/", help="where results are saved")
49 | 
50 |     # data aug stuff
51 |     parser.add_argument("--learning_rate", "-l", type=float, default=.002, help="Learning rate")
52 |     parser.add_argument("--epochs", "-e", type=int, default=50, help="Number of epochs")
53 |     parser.add_argument("--min_scale", default=0.8, type=float, help="Minimum scale percent")
54 |     parser.add_argument("--max_scale", default=1.0, type=float, help="Maximum scale percent")
55 |     parser.add_argument("--gray_flag", default=1, type=int, help="whether to use random grayscale")
56 |     parser.add_argument("--random_horiz_flip", default=0.5, type=float, help="Chance of random horizontal flip")
57 |     parser.add_argument("--jitter", default=0.4, type=float, help="Color jitter amount")
58 |     parser.add_argument("--tile_random_grayscale", default=0.1, type=float,
59 |                         help="Chance of randomly greyscaling a tile")
60 |     parser.add_argument("--limit_source", default=None, type=int,
61 |                         help="If set, it will limit the number of training samples")
62 |     parser.add_argument("--limit_target", default=None, type=int,
63 |                         help="If set, it will limit the number of testing samples")
64 |     parser.add_argument("--network", choices=model_factory.nets_map.keys(), help="Which network to use",
65 |                         default="resnet18")
66 |     parser.add_argument("--tf_logger", type=bool, default=True, help="If true will save tensorboard compatible logs")  # NB: type=bool flags cannot be disabled from the CLI; see the str2bool note after this function
67 |     parser.add_argument("--folder_name", default='test', help="Used by the logger to save logs")
68 |     parser.add_argument("--bias_whole_image", default=0.9, type=float,
69 |                         help="If set, will bias the training procedure to show more often the whole image")
70 |     parser.add_argument("--TTA", type=bool, default=False, help="Activate test time data augmentation")
71 |     parser.add_argument("--classify_only_sane", default=False, type=bool,
72 |                         help="If true, the network will only try to classify the non scrambled images")
73 |     parser.add_argument("--train_all", default=True, type=bool, help="If true, all network weights will be trained")
74 |     parser.add_argument("--suffix", default="", help="Suffix for the logger")
75 |     parser.add_argument("--nesterov", default=True, type=bool, help="Use nesterov")
76 | 
77 |     return parser.parse_args()
78 | 
79 | 
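Note that `argparse` with `type=bool` never parses `"False"` as false (any non-empty string is truthy), so the `bool`-typed flags above are effectively fixed at their defaults. If these flags ever need to be toggled from the command line, a common fix is a small converter like the following; it is not part of this repository:

```
import argparse

def str2bool(v):
    # accept the usual spellings of true/false on the command line
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

# usage: parser.add_argument("--nesterov", type=str2bool, default=True)
```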
image") 70 | parser.add_argument("--TTA", type=bool, default=False, help="Activate test time data augmentation") 71 | parser.add_argument("--classify_only_sane", default=False, type=bool, 72 | help="If true, the network will only try to classify the non scrambled images") 73 | parser.add_argument("--train_all", default=True, type=bool, help="If true, all network weights will be trained") 74 | parser.add_argument("--suffix", default="", help="Suffix for the logger") 75 | parser.add_argument("--nesterov", default=True, type=bool, help="Use nesterov") 76 | 77 | return parser.parse_args() 78 | 79 | 80 | def get_results_path(args): 81 | # Make the directory to store the experimental results 82 | base_result_path = args.result_path + "/" + args.data + "/" 83 | base_result_path += args.network 84 | 85 | if args.domain_discriminator_flag == 1: 86 | base_result_path += "_DomainDrop" 87 | 88 | base_result_path += "_layer_wise" + str(args.layer_wise_prob) 89 | 90 | if args.grl == 1: 91 | base_result_path += "_grl" + str(args.lambd) 92 | base_result_path += "_channel" 93 | 94 | base_result_path += "_L" 95 | for i, layer in enumerate(args.discriminator_layers): 96 | base_result_path += str(layer) 97 | base_result_path += "_dropP" + str(args.drop_percent) 98 | base_result_path += "_domain" 99 | if args.filter_WRS_flag == 1: 100 | base_result_path += "_WRS" 101 | 102 | if args.KL_Loss == 1: 103 | base_result_path += "_KL_" + str(args.KL_Loss_weight) + "_T" + str(args.KL_Loss_T) 104 | 105 | base_result_path += "_lr" + str(args.learning_rate) + "_B" + str(args.batch_size) 106 | base_result_path += "/" + args.target + str(args.time) + "/" 107 | if not os.path.exists(base_result_path): 108 | os.makedirs(base_result_path) 109 | return base_result_path 110 | 111 | 112 | class Trainer: 113 | def __init__(self, args, device): 114 | self.args = args 115 | self.device = device 116 | if args.network == 'resnet18': 117 | model = resnet18( 118 | pretrained=True, 119 | device=device, 120 | classes=args.n_classes, 121 | domains=args.n_domains, 122 | network=args.network, 123 | domain_discriminator_flag=args.domain_discriminator_flag, 124 | grl=args.grl, 125 | lambd=args.lambd, 126 | drop_percent=args.drop_percent, 127 | wrs_flag=args.filter_WRS_flag, 128 | recover_flag=args.recover_flag, 129 | ) 130 | elif args.network == 'resnet50': 131 | model = resnet50( 132 | pretrained=True, 133 | device=device, 134 | classes=args.n_classes, 135 | domains=args.n_domains, 136 | network=args.network, 137 | domain_discriminator_flag=args.domain_discriminator_flag, 138 | grl=args.grl, 139 | lambd=args.lambd, 140 | drop_percent=args.drop_percent, 141 | wrs_flag=args.filter_WRS_flag, 142 | recover_flag=args.recover_flag, 143 | ) 144 | else: 145 | raise NotImplementedError("Not Implemented Network.") 146 | 147 | self.model = model.to(device) 148 | self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, patches=model.is_patch_based()) 149 | self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based()) 150 | self.test_loaders = {"val": self.val_loader, "test": self.target_loader} 151 | self.len_dataloader = len(self.source_loader) 152 | print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), 153 | len(self.val_loader.dataset), 154 | len(self.target_loader.dataset))) 155 | 156 | self.optimizer_scatter, self.scheduler_scatter = get_optim_and_scheduler_scatter(model=model, 157 | network=args.network, 158 | epochs=args.epochs, 159 | lr=args.learning_rate, 160 | 
185 |     def _do_epoch(self, epoch=None):
186 |         self.model.train()
187 | 
188 |         CE_loss = 0.0
189 |         batch_num = 0.0
190 |         class_right = 0.0
191 |         class_total = 0.0
192 | 
193 |         CE_domain_loss = [0.0 for i in range(5)]
194 |         domain_right = [0.0 for i in range(5)]
195 |         CE_domain_losses_avg = 0.0
196 |         KL_loss = 0.0
197 | 
198 |         for it, ((data, class_l, domain_l), d_idx) in enumerate(self.source_loader):
199 |             if self.args.KL_Loss == 1:
200 |                 data = torch.cat((data, data)).to(self.device)  # duplicate the batch: two stochastic views for the consistency loss (see the sketch after this method)
201 |                 class_l = torch.cat((class_l, class_l)).to(self.device)
202 |                 domain_l = torch.cat((domain_l, domain_l)).to(self.device)
203 |             else:
204 |                 data = data.to(self.device)
205 |                 class_l = class_l.to(self.device)
206 |                 domain_l = domain_l.to(self.device)
207 | 
208 |             layer_drop_flag = self.select_layers(layer_wise_prob=self.layer_wise_prob)
209 |             optimizers = self.optimizer_scatter  # one optimizer per module group
210 | 
211 |             class_logit, domain_logit = self.model(x=data, domain_labels=domain_l, layer_drop_flag=layer_drop_flag)
212 |             class_loss = self.criterion(class_logit, class_l)
213 |             CE_loss += class_loss.item()  # .item() detaches the running total from the graph
214 |             domain_losses_avg = torch.tensor(0.0).to(device=self.device)
215 | 
216 |             if self.domain_discriminator_flag == 1:
217 |                 domain_losses = []
218 |                 for i, logit in enumerate(domain_logit):
219 |                     domain_loss = self.domain_criterion(logit, domain_l)
220 |                     domain_losses.append(domain_loss)
221 |                     CE_domain_loss[i] += domain_loss.item()
222 |                 domain_losses = torch.stack(domain_losses, dim=0)
223 |                 domain_losses_avg = domain_losses.mean(dim=0)
224 |                 CE_domain_losses_avg += domain_losses_avg.item()
225 | 
226 |             loss = 0.0
227 |             loss += class_loss
228 |             if self.domain_loss_flag == 1:
229 |                 loss += domain_losses_avg
230 |             if self.args.KL_Loss == 1:
231 |                 batch_size = int(class_logit.shape[0] / 2)
232 |                 class_logit_1 = class_logit[:batch_size]
233 |                 class_logit_2 = class_logit[batch_size:]
234 |                 kl_loss = compute_kl_loss(class_logit_1, class_logit_2, T=self.args.KL_Loss_T)
235 |                 loss += self.args.KL_Loss_weight * kl_loss
236 |                 KL_loss += kl_loss.item()
237 | 
238 |             for opt in optimizers:
239 |                 opt.zero_grad()
240 |             loss.backward()
241 |             for opt in optimizers:
242 |                 opt.step()
243 | 
244 |             _, class_pred = class_logit.max(dim=1)
245 |             class_right_batch = torch.sum(class_pred == class_l.data)
246 |             class_right += class_right_batch
247 | 
248 |             domain_right_batch = [torch.tensor(0.0).to(self.device) for i in range(5)]
249 |             if self.domain_discriminator_flag == 1:
250 |                 for i, logit in enumerate(domain_logit):
251 |                     _, domain_pred = logit.max(dim=1)
252 |                     domain_right_batch[i] = torch.sum(domain_pred == domain_l.data)
253 |                     domain_right[i] += domain_right_batch[i]
254 |             batch_num += 1
255 | 
256 |             data_shape = data.shape[0]
257 |             class_total += data_shape
258 | 
259 |             self.logger.log(it, len(self.source_loader),
260 |                             {
261 |                                 "class": class_loss.item(),
262 |                                 "domain": domain_losses_avg.item(),
263 |                                 "loss": loss.item(),
264 |                             },
265 |                             {
266 |                                 "class": class_right_batch,
267 |                             }, data_shape)
268 |         CE_loss = float(CE_loss) / batch_num
269 |         CE_domain_losses_avg = float(CE_domain_losses_avg / batch_num)
270 |         CE_domain_loss = [float(loss / batch_num) for loss in CE_domain_loss]
271 | 
272 |         class_acc = float(class_right) / class_total
273 |         domain_acc = [float(right / class_total) for right in domain_right]
274 | 
275 |         KL_loss = float(KL_loss / batch_num)
276 | 
277 |         result_domain_acc = ", Domain Acc"
278 |         result_domain_loss = ", Domain loss"
279 |         if self.domain_discriminator_flag == 1:
280 |             result_domain_loss += ", Avg: " + str(format(CE_domain_losses_avg, '.4f'))
281 |             for i in range(5):
282 |                 result_domain_acc += ", L" + str(i) + ": " + str(format(domain_acc[i], ".4f"))
283 |                 result_domain_loss += ", L" + str(i) + ": " + str(format(CE_domain_loss[i], '.4f'))
284 | 
285 |         result = "train" + ": Epoch: " + str(epoch) \
286 |                  + ", CELoss: " + str(format(CE_loss, '.4f')) \
287 |                  + ", ACC: " + str(format(class_acc, '.4f')) \
288 |                  + result_domain_loss \
289 |                  + result_domain_acc \
290 |                  + ", KL loss: " + str(format(KL_loss, '.4f')) \
291 |                  + '\n'
292 |         with open(self.base_result_path + "/" + "train" + ".txt", "a") as f:
293 |             f.write(result)
294 | 
295 |         self.model.eval()
296 |         with torch.no_grad():
297 |             val_test_acc = []
298 |             for phase, loader in self.test_loaders.items():
299 |                 class_acc, CE_loss = self.do_test(loader)
300 |                 val_test_acc.append(class_acc)
301 | 
302 |                 result = phase + ": Epoch: " + str(epoch) \
303 |                          + ", CELoss: " + str(format(CE_loss, '.4f')) \
304 |                          + ", ACC: " + str(format(class_acc, '.4f')) \
305 |                          + "\n"
306 |                 with open(self.base_result_path + "/" + phase + ".txt", "a") as f:
307 |                     f.write(result)
308 | 
309 |                 self.logger.log_test(phase, {"class": class_acc})
310 |                 self.results[phase][self.current_epoch] = class_acc
311 |         if val_test_acc[0] >= self.val_best:  # "val" comes first in test_loaders, so index 0 is source-domain validation
312 |             self.val_best = val_test_acc[0]
313 |             self.save_model(mode="best")
314 | 
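Because channel dropping is stochastic, the two halves of the duplicated batch above effectively pass through different sub-networks, and `compute_kl_loss` pulls their predictions together. The actual implementation lives in loss/KL_Loss.py earlier in this repository; a temperature-scaled symmetric KL of this general shape typically looks roughly like the following sketch (names here are illustrative, not the repo's code):

```
import torch.nn.functional as F

def symmetric_kl_sketch(logit_1, logit_2, T=5.0):
    # soften both predictions with temperature T, then average the two KL directions
    log_p = F.log_softmax(logit_1 / T, dim=-1)
    log_q = F.log_softmax(logit_2 / T, dim=-1)
    kl_pq = F.kl_div(log_p, log_q.exp(), reduction="batchmean")
    kl_qp = F.kl_div(log_q, log_p.exp(), reduction="batchmean")
    return (kl_pq + kl_qp) / 2
```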
315 |     def do_eval(self, model_path):
316 |         checkpoint = torch.load(model_path, map_location='cpu')
317 |         self.model.load_state_dict(checkpoint, strict=False)
318 |         self.model.eval()
319 |         with torch.no_grad():
320 |             for phase, loader in self.test_loaders.items():
321 |                 class_acc, CE_loss = self.do_test(loader)
322 |                 result = phase + ": CELoss: " + str(format(CE_loss, '.4f')) \
323 |                          + ", ACC: " + str(format(class_acc, '.4f'))
324 |                 print(result)
325 | 
326 |     def save_model(self, mode="best"):
327 |         model_path = self.base_result_path + "models/"
328 |         if not os.path.exists(model_path):
329 |             os.makedirs(model_path)
330 |         model_name = "model_" + mode + ".pt"
331 |         torch.save(self.model.state_dict(), os.path.join(model_path, model_name))
332 | 
333 |     def do_test(self, loader):
334 |         class_right = 0.0
335 |         CE_loss = 0.0
336 |         batch_num = 0
337 |         for it, ((data, class_l, domain_l), _) in enumerate(loader):
338 |             data, class_l = data.to(self.device), class_l.to(self.device)
339 |             class_logit, _ = self.model(x=data, layer_drop_flag=[0, 0, 0, 0])  # channel dropping disabled at test time
340 |             class_loss = self.criterion(class_logit, class_l)
341 |             _, cls_pred = class_logit.max(dim=1)
342 | 
343 |             CE_loss += class_loss.item()
344 |             class_right += torch.sum(cls_pred == class_l.data)
345 |             batch_num += 1
346 |         CE_loss = float(CE_loss) / batch_num
347 |         class_acc = float(class_right) / len(loader.dataset)
348 |         return class_acc, CE_loss
349 | 
350 |     def do_training(self):
351 |         self.logger = Logger(self.args, update_frequency=30)
352 |         self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
353 |         for self.current_epoch in range(self.args.epochs):
354 |             start_time = time.time()
355 |             self._do_epoch(self.current_epoch)
356 |             self.logger.new_epoch(self.scheduler_scatter[0].get_last_lr())
357 |             for scl in self.scheduler_scatter:
358 |                 scl.step()
359 |             end_time = time.time()
360 |             print("Time for one epoch is " + str(format(end_time - start_time, '.0f')) + "s")
361 |         self.save_model(mode="last")
362 |         val_res = self.results["val"]
363 |         test_res = self.results["test"]
364 |         idx_best = val_res.argmax()  # model selection by source-domain validation accuracy
365 |         line = "Best val %g, corresponding test %g, best val epoch: %g - best test: %g" % (
366 |             val_res.max(), test_res[idx_best], idx_best, test_res.max())
367 |         print(line)
368 |         with open(self.base_result_path + "test.txt", "a") as f:
369 |             f.write(line + "\n")
370 |         self.logger.save_best(test_res[idx_best], test_res.max())
371 |         return self.logger, self.model
372 | 
373 | 
374 | domain_map = {
375 |     'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
376 |     'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
377 |     'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
378 |     'VLCS': ["CALTECH", "LABELME", "PASCAL", "SUN"],
379 | }
380 | 
381 | classes_map = {
382 |     'PACS': 7,
383 |     'PACS_random_split': 7,
384 |     'OfficeHome': 65,
385 |     'VLCS': 5,
386 | }
387 | 
388 | val_size_map = {
389 |     'PACS': 0.1,
390 |     'PACS_random_split': 0.1,
391 |     'OfficeHome': 0.1,
392 |     'VLCS': 0.3,
393 | }
394 | 
395 | def setup_seed(seed):
396 |     torch.manual_seed(seed)
397 |     torch.cuda.manual_seed_all(seed)
398 |     torch.cuda.manual_seed(seed)
399 |     np.random.seed(seed)
400 |     random.seed(seed)
401 |     # torch.backends.cudnn.deterministic = True
402 | 
403 | def get_domain(name):
404 |     if name not in domain_map:
405 |         raise ValueError('Name of dataset unknown %s' % name)
406 |     return domain_map[name]
407 | 
408 | def main():
409 |     args = get_args()
410 | 
411 |     domain = get_domain(args.data)
412 |     args.target = domain.pop(args.target)  # target index -> domain name; the remaining domains are the sources
413 |     args.source = domain
414 |     print("Target domain: {}".format(args.target))
415 |     args.data_root = os.path.join(args.data_root, "PACS") if "PACS" in args.data else os.path.join(args.data_root,
416 |                                                                                                    args.data)
417 |     args.n_classes = classes_map[args.data]
418 |     args.n_domains = len(domain)
419 |     args.val_size = val_size_map[args.data]
420 |     setup_seed(args.time)  # the run index doubles as the random seed
421 | 
422 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
423 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
424 | 
425 |     trainer = Trainer(args, device)
426 |     if args.eval:
427 |         model_path = args.eval_model_path
428 |         trainer.do_eval(model_path=model_path)
429 |         return
430 |     trainer.do_training()
431 | 
432 | 
433 | if __name__ == "__main__":
434 |     torch.backends.cudnn.benchmark = True
435 |     main()
436 | 
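The `--grl` flag above enables a gradient reversal layer in front of each domain discriminator; the repository's implementation lives in models/resnet_domain.py earlier in this document. For reference, a standard GRL is just an identity in the forward pass that flips and scales gradients on the way back; a minimal sketch, with `lambd` playing the role of the `--lambd` argument (this is the textbook construction, not necessarily this repo's exact code):

```
import torch

class GradReverse(torch.autograd.Function):
    # identity forward, -lambd * grad backward
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambd, None

def grad_reverse(x, lambd=0.25):
    return GradReverse.apply(x, lambd)
```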
--------------------------------------------------------------------------------
/utils/Logger.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | 
3 | from os.path import join, dirname
4 | 
5 | # from .tf_logger import TFLogger
6 | 
7 | _log_path = join(dirname(__file__), '../logs')
8 | 
9 | 
10 | # high level wrapper for tf_logger.TFLogger
11 | class Logger:
12 |     def __init__(self, args, update_frequency=10):
13 |         self.current_epoch = 0
14 |         self.max_epochs = args.epochs
15 |         self.last_update = time()
16 |         self.start_time = time()
17 |         self._clean_epoch_stats()
18 |         self.update_f = update_frequency
19 |         folder, logname = self.get_name_from_args(args)
20 |         log_path = join(_log_path, folder, logname)
21 |         if args.tf_logger:
22 |             pass  # TensorBoard logging is currently disabled; see utils/tf_logger.py
23 |             # self.tf_logger = TFLogger(log_path)
24 |             # print("Saving to %s" % log_path)
25 |         else:
26 |             pass
27 |             # self.tf_logger = None
28 |         self.current_iter = 0
29 | 
30 |     def new_epoch(self, learning_rates):
31 |         self.current_epoch += 1
32 |         self.last_update = time()
33 |         self.lrs = learning_rates
34 |         print("New epoch - lr: %s" % ", ".join([str(lr) for lr in self.lrs]))
35 |         self._clean_epoch_stats()
36 |         # if self.tf_logger:
37 |         #     for n, v in enumerate(self.lrs):
38 |         #         self.tf_logger.scalar_summary("aux/lr%d" % n, v, self.current_iter)
39 | 
40 |     def log(self, it, iters, losses, samples_right, total_samples):
41 |         self.current_iter += 1
42 |         loss_string = ", ".join(["%s : %.3f" % (k, v) for k, v in losses.items()])
43 |         for k, v in samples_right.items():
44 |             past = self.epoch_stats.get(k, 0.0)
45 |             self.epoch_stats[k] = past + v
46 |         self.total += total_samples
47 |         acc_string = ", ".join(["%s : %.2f" % (k, 100 * (v / total_samples)) for k, v in samples_right.items()])
48 |         if it % self.update_f == 0:
49 |             print("%d/%d of epoch %d/%d %s - acc %s [bs:%d]" % (it, iters, self.current_epoch, self.max_epochs, loss_string,
50 |                                                                 acc_string, total_samples))
51 |         # update tf log
52 |         # if self.tf_logger:
53 |         #     for k, v in losses.items(): self.tf_logger.scalar_summary("train/loss_%s" % k, v, self.current_iter)
54 | 
55 |     def _clean_epoch_stats(self):
56 |         self.epoch_stats = {}
57 |         self.total = 0
58 | 
59 |     def log_test(self, phase, accuracies):
60 |         print("Accuracies on %s: " % phase + ", ".join(["%s : %.2f" % (k, v * 100) for k, v in accuracies.items()]))
61 |         # if self.tf_logger:
62 |         #     for k, v in accuracies.items(): self.tf_logger.scalar_summary("%s/acc_%s" % (phase, k), v, self.current_iter)
63 | 
64 |     def save_best(self, val_test, best_test):
65 |         print("It took %g seconds" % (time() - self.start_time))
66 |         # if self.tf_logger:
67 |         #     for x in range(10):
68 |         #         self.tf_logger.scalar_summary("best/from_val_test", val_test, x)
69 |         #         self.tf_logger.scalar_summary("best/max_test", best_test, x)
70 | 
71 |     @staticmethod
72 |     def get_name_from_args(args):
73 |         folder_name = "%s_to_%s" % ("-".join(sorted(args.source)), args.target)
74 |         if args.folder_name:
75 |             folder_name = join(args.folder_name, folder_name)
76 |         name = "eps%d_bs%d_lr%g_class%d_jigClass%d_jigWeight%g" % (args.epochs, args.batch_size, args.learning_rate, args.n_classes,
77 |                                                                    30, 0.7)  # hard-coded legacy jigsaw fields, kept for log-path compatibility
78 |         # if args.ooo_weight > 0:
79 |         #     name += "_oooW%g" % args.ooo_weight
80 |         if args.train_all:
81 |             name += "_TAll"
82 |         if args.bias_whole_image:
83 |             name += "_bias%g" % args.bias_whole_image
84 |         if args.classify_only_sane:
85 |             name += "_classifyOnlySane"
86 |         if args.TTA:
87 |             name += "_TTA"
88 |         try:
89 |             name += "_entropy%g_jig_tW%g" % (args.entropy_weight, args.target_weight)
90 |         except AttributeError:
91 |             pass
92 |         if args.suffix:
93 |             name += "_%s" % args.suffix
94 |         name += "_%d" % int(time() % 1000)
95 |         return folder_name, name
96 | 
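Condensed from train_domain.py above, the Trainer drives this Logger roughly as follows (loader, loss, and counters are placeholders):

```
logger = Logger(args, update_frequency=30)
for epoch in range(args.epochs):
    for it, batch in enumerate(source_loader):
        # ... forward/backward pass ...
        logger.log(it, len(source_loader),
                   losses={"class": class_loss.item()},
                   samples_right={"class": n_correct},
                   total_samples=batch_size)
    logger.new_epoch(scheduler.get_last_lr())
logger.save_best(val_selected_test_acc, best_test_acc)
```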
--------------------------------------------------------------------------------
/utils/__pycache__/Logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/utils/__pycache__/Logger.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tf_logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingeringlight/DomainDrop/a2d9607d4f1fb9ba35a9fa48fbceff82fd67d15f/utils/__pycache__/tf_logger.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/tf_logger.py:
--------------------------------------------------------------------------------
1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
2 | import tensorflow as tf
3 | import numpy as np
4 | import scipy.misc
5 | try:
6 |     from StringIO import StringIO  # Python 2.7
7 | except ImportError:
8 |     from io import BytesIO  # Python 3.x
9 | 
10 | 
11 | class TFLogger(object):
12 | 
13 |     def __init__(self, log_dir):
14 |         """Create a summary writer logging to log_dir."""
15 |         self.writer = tf.compat.v1.summary.FileWriter(log_dir)
16 | 
17 |     def scalar_summary(self, tag, value, step):
18 |         """Log a scalar variable."""
19 |         summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
20 |         self.writer.add_summary(summary, step)
21 | 
22 |     def image_summary(self, tag, images, step):
23 |         """Log a list of images."""
24 | 
25 |         img_summaries = []
26 |         for i, img in enumerate(images):
27 |             # Write the image to a string
28 |             try:
29 |                 s = StringIO()
30 |             except NameError:  # StringIO only exists on Python 2
31 |                 s = BytesIO()
32 |             scipy.misc.toimage(img).save(s, format="png")  # requires scipy < 1.2 (toimage was removed later)
33 | 
34 |             # Create an Image object
35 |             img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(),
36 |                                                  height=img.shape[0],
37 |                                                  width=img.shape[1])
38 |             # Create a Summary value
39 |             img_summaries.append(tf.compat.v1.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
40 | 
41 |         # Create and write Summary
42 |         summary = tf.compat.v1.Summary(value=img_summaries)
43 |         self.writer.add_summary(summary, step)
44 | 
45 |     def histo_summary(self, tag, values, step, bins=1000):
46 |         """Log a histogram of the tensor of values."""
47 | 
48 |         # Create a histogram using numpy
49 |         counts, bin_edges = np.histogram(values, bins=bins)
50 | 
51 |         # Fill the fields of the histogram proto
52 |         hist = tf.compat.v1.HistogramProto()
53 |         hist.min = float(np.min(values))
54 |         hist.max = float(np.max(values))
55 |         hist.num = int(np.prod(values.shape))
56 |         hist.sum = float(np.sum(values))
57 |         hist.sum_squares = float(np.sum(values**2))
58 | 
59 |         # Drop the start of the first bin
60 |         bin_edges = bin_edges[1:]
61 | 
62 |         # Add bin edges and counts
63 |         for edge in bin_edges:
64 |             hist.bucket_limit.append(edge)
65 |         for c in counts:
66 |             hist.bucket.append(c)
67 | 
68 |         # Create and write Summary
69 |         summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)])
70 |         self.writer.add_summary(summary, step)
71 |         self.writer.flush()
72 | 
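TFLogger is currently disabled in Logger above, but it can be exercised on its own against the TensorFlow 1.x API pinned in the requirements; the log directory below is an arbitrary example:

```
import numpy as np
from utils.tf_logger import TFLogger

logger = TFLogger("./logs/demo")
for step in range(100):
    logger.scalar_summary("train/loss", 1.0 / (step + 1), step)
logger.histo_summary("weights/fc", np.random.randn(1000), step=100)
```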
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.utils import save_image
3 | import numpy as np
4 | 
5 | def denorm(tensor, device, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
6 |     std = torch.Tensor(std).reshape(-1, 1, 1).to(device)
7 |     mean = torch.Tensor(mean).reshape(-1, 1, 1).to(device)
8 |     res = torch.clamp(tensor * std + mean, 0, 1)
9 |     return res
10 | 
11 | 
12 | def save_image_from_tensor_batch(batch, column, path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], device='cpu'):
13 |     batch = denorm(batch, device, mean, std)
14 |     save_image(batch, path, nrow=column)
15 | 
16 | 
17 | def mean_teacher(model, teacher, momentum=0.9995):
18 |     model_dict = model.state_dict()
19 |     teacher_dict = teacher.state_dict()
20 |     for k, v in teacher_dict.items():
21 |         teacher_dict[k] = v * momentum + (1 - momentum) * model_dict[k]  # EMA over the state_dict, so buffers are averaged too
22 | 
23 |     teacher.load_state_dict(teacher_dict)
24 | 
25 | 
26 | def update_teacher(model, teacher, momentum=0.9995):
27 |     for ema_param, param in zip(teacher.parameters(), model.parameters()):
28 |         ema_param.data.mul_(momentum).add_(param.data, alpha=1 - momentum)
29 | 
30 | 
31 | def warm_update_teacher(model, teacher, momentum=0.9995, global_step=2000):
32 |     momentum = min(1 - 1 / (global_step + 1), momentum)  # shorter EMA horizon early in training
33 |     for ema_param, param in zip(teacher.parameters(), model.parameters()):
34 |         ema_param.data.mul_(momentum).add_(param.data, alpha=1 - momentum)
35 | 
36 | 
37 | def preprocess_teacher(model, teacher):
38 |     for param_m, param_t in zip(model.parameters(), teacher.parameters()):
39 |         param_t.data.copy_(param_m.data)  # initialize
40 |         param_t.requires_grad = False  # not updated by gradient
41 | 
42 | 
43 | def calculate_correct(scores, labels):
44 |     assert scores.size(0) == labels.size(0)
45 |     _, pred = scores.max(dim=1)
46 |     correct = torch.sum(pred.eq(labels)).item()
47 |     return correct
48 | 
49 | 
50 | def sigmoid_rampup(current, rampup_length):
51 |     """Exponential rampup from https://arxiv.org/abs/1610.02242"""
52 |     if rampup_length == 0:
53 |         return 1.0
54 |     else:
55 |         current = np.clip(current, 0.0, rampup_length)
56 |         phase = 1.0 - current / rampup_length
57 |         return float(np.exp(-5.0 * phase * phase))
58 | 
59 | 
60 | def linear_rampup(current, rampup_length):
61 |     """Linear rampup"""
62 |     assert current >= 0 and rampup_length >= 0
63 |     if current >= rampup_length:
64 |         return 1.0
65 |     else:
66 |         return current / rampup_length
67 | 
68 | 
69 | def step_rampup(current, rampup_length):
70 |     assert current >= 0 and rampup_length >= 0
71 |     if current >= rampup_length:
72 |         return 1.0
73 |     else:
74 |         return 0.0
75 | 
76 | 
77 | def get_current_consistency_weight(epoch, weight, rampup_length, rampup_type='step'):
78 |     if rampup_type == 'step':
79 |         rampup_func = step_rampup
80 |     elif rampup_type == 'linear':
81 |         rampup_func = linear_rampup
82 |     elif rampup_type == 'sigmoid':
83 |         rampup_func = sigmoid_rampup
84 |     else:
85 |         raise ValueError("Rampup schedule not implemented")
86 | 
87 |     return weight * rampup_func(epoch, rampup_length)
88 | 
89 | 
90 | def update_bn(loader, model, device=None, swa_bn_domaindrop=0):
91 |     r"""Updates BatchNorm running_mean, running_var buffers in the model.
92 | 
93 |     It performs one pass over data in `loader` to estimate the activation
94 |     statistics for BatchNorm layers in the model.
95 |     Args:
96 |         loader (torch.utils.data.DataLoader): dataset loader to compute the
97 |             activation statistics on. Each data batch should be either a
98 |             tensor, or a list/tuple whose first element is a tensor
99 |             containing data.
100 |         model (torch.nn.Module): model for which we seek to update BatchNorm
101 |             statistics.
102 |         device (torch.device, optional): If set, data will be transferred to
103 |             :attr:`device` before being passed into :attr:`model`.
104 | 
105 |     Example:
106 |         >>> loader, model = ...
107 |         >>> torch.optim.swa_utils.update_bn(loader, model)
108 | 
109 |     .. note::
110 |         The `update_bn` utility assumes that each data batch in :attr:`loader`
111 |         is either a tensor or a list or tuple of tensors; in the latter case it
112 |         is assumed that :meth:`model.forward()` should be called on the first
113 |         element of the list or tuple corresponding to the data batch.
114 |     """
115 |     momenta = {}
116 |     for module in model.modules():
117 |         if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
118 |             module.running_mean = torch.zeros_like(module.running_mean)
119 |             module.running_var = torch.ones_like(module.running_var)
120 |             momenta[module] = module.momentum
121 | 
122 |     if not momenta:
123 |         return
124 | 
125 |     was_training = model.training
126 |     model.train()
127 |     for module in momenta.keys():
128 |         module.momentum = None  # momentum=None makes BN use a cumulative moving average
129 |         module.num_batches_tracked *= 0
130 | 
131 |     for batch in loader:  # unlike torch.optim.swa_utils.update_bn, this variant assumes the repo's loader format: ((data, ..., class_label, domain_label), domain_idx)
132 |         if isinstance(batch, (list, tuple)):
133 |             data = batch[0][0]
134 |             class_l = batch[0][2]
135 |             domain_labels = batch[0][3]
136 |             if device is not None:
137 |                 data = data.to(device)
138 |                 class_l = class_l.to(device)
139 |                 domain_labels = domain_labels.to(device)
140 | 
141 |             model(data, gt=class_l, domain_labels=domain_labels, swa_bn_domaindrop=swa_bn_domaindrop)
142 | 
143 |     for bn_module in momenta.keys():
144 |         bn_module.momentum = momenta[bn_module]
145 |     model.train(was_training)
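The three ramp-up schedules differ only in how the consistency weight approaches its ceiling; for a 10-epoch ramp, for example:

```
for epoch in (0, 5, 10):
    print(epoch,
          get_current_consistency_weight(epoch, weight=1.0, rampup_length=10, rampup_type="step"),
          get_current_consistency_weight(epoch, weight=1.0, rampup_length=10, rampup_type="linear"),
          get_current_consistency_weight(epoch, weight=1.0, rampup_length=10, rampup_type="sigmoid"))
# step:    0.0 until epoch 10, then 1.0
# linear:  epoch / 10, capped at 1.0
# sigmoid: exp(-5 * (1 - epoch/10)^2), reaching 1.0 at epoch 10
```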
--------------------------------------------------------------------------------
/utils/vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | def view_training(logger, title):  # expects logger.losses / logger.val_acc: dicts of per-iteration / per-evaluation histories
4 |     fig, ax1 = plt.subplots()
5 |     for k, v in logger.losses.items():
6 |         ax1.plot(v, label=k)
7 |         l = len(v)
8 |     updates = l / len(logger.val_acc["class"])
9 |     plt.legend()
10 |     ax2 = ax1.twinx()
11 |     for k, v in logger.val_acc.items():
12 |         ax2.plot(range(0, l, int(updates)), v, label="Test %s" % k)
13 |     plt.legend()
14 |     plt.title(title + " last acc: %.2f" % logger.val_acc["class"][-1])
15 |     plt.show()
--------------------------------------------------------------------------------
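Note that the Logger class in utils/Logger.py does not record `losses` or `val_acc` histories, so `view_training` only works with an object that does. A duck-typed stand-in illustrates the expected interface (the numbers are made up):

```
from types import SimpleNamespace
from utils.vis import view_training

history = SimpleNamespace(
    losses={"class": [0.9, 0.7, 0.5, 0.4]},  # one entry per logged iteration
    val_acc={"class": [0.60, 0.72]},         # one entry per evaluation
)
view_training(history, title="demo run")
```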