├── ALOFT-E.sh ├── ALOFT-S.sh ├── ReadMe.md ├── data ├── FourierTransform.py ├── JigsawLoader.py ├── concat_dataset.py ├── data_helper.py └── samplers.py ├── engine.py ├── gfnet.py ├── losses.py ├── main_gfnet.py ├── samplers.py └── utils.py /ALOFT-E.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | device=0 4 | data='PACS' 5 | 6 | for t in `seq 0 4` 7 | do 8 | for domain in `seq 0 3` 9 | do 10 | python main_gfnet.py \ 11 | --target $domain \ 12 | --device $device \ 13 | --seed $t \ 14 | --batch-size 64 \ 15 | --data $data \ 16 | --epochs 50 \ 17 | --lr 0.0005 \ 18 | --data_root '/data/DataSets/' \ 19 | --noise_mode 1 \ 20 | --uncertainty_model 2 \ 21 | --uncertainty_factor 1.0 \ 22 | --mask_radio 0.5 \ 23 | --eval 0 \ 24 | --resume '' 25 | done 26 | done 27 | 28 | 29 | -------------------------------------------------------------------------------- /ALOFT-S.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | device=0 4 | data='PACS' 5 | 6 | for t in `seq 0 4` 7 | do 8 | for domain in `seq 0 3` 9 | do 10 | python main_gfnet.py \ 11 | --target $domain \ 12 | --device $device \ 13 | --seed $t \ 14 | --batch-size 64 \ 15 | --data $data \ 16 | --epochs 50 \ 17 | --lr 0.0005 \ 18 | --data_root '/data/DataSets/' \ 19 | --noise_mode 1 \ 20 | --uncertainty_model 1 \ 21 | --uncertainty_factor 0.9 \ 22 | --mask_radio 0.5 \ 23 | --eval 0 \ 24 | --resume '' 25 | done 26 | done 27 | 28 | 29 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | ## ALOFT: A Lightweight MLP-like Architecture with Dynamic Low-frequency Transform for Domain Generalization 2 | 3 | ### Requirements 4 | 5 | * Python == 3.7.3 6 | * torch >= 1.8 7 | * Cuda == 10.1 8 | * Torchvision == 0.4.2 9 | * timm == 0.4.12 10 | 11 | ### DataSets 12 | Please download the PACS dataset from [here](https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk?resourcekey=0-2fvpQY_QSyJf2uIECzqPuQ). 13 | Make sure you use the official train/val/test split in the [PACS paper](https://openaccess.thecvf.com/content_iccv_2017/html/Li_Deeper_Broader_and_ICCV_2017_paper.html). 14 | Take `/data/DataSets/` as the save directory, for example: 15 | ``` 16 | images -> /data/DataSets/PACS/kfold/art_painting/dog/pic_001.jpg, ... 17 | splits -> /data/DataSets/PACS/pacs_label/art_painting_crossval_kfold.txt, ... 18 | ``` 19 | Then set `"data_root"` to `"/data/DataSets/"` and `"data"` to `"PACS"` in `"main_gfnet.py"`. 20 | 21 | You can also directly set the `"data_root"` and `"data"` in `"ALOFT-E.sh"`/`"ALOFT-S.sh"` for training the model. 22 | 23 | ### Evaluation 24 | 25 | To evaluate the performance of the models, you can download the models trained on PACS below: 26 | 27 | | Methods | Acc (%) | models | 28 | | :-------------: | :-----: | :----------------------------------------------------------: | 29 | | Strong Baseline | 87.76 | [download](https://drive.google.com/drive/folders/1DJfGRSpFPmm1FD-sZRZK3ZObOE_-7Aaq?usp=share_link) | 30 | | ALOFT-S | 90.88 | [download](https://drive.google.com/drive/folders/1r2HXwe1O54GfQ9R3H-wL2xyR36YAqcpN?usp=share_link) | 31 | | ALOFT-E | 91.58 | [download](https://drive.google.com/drive/folders/1K80RPvOyw25bnAd5EGothqMTBL-YDCdm?usp=share_link) | 32 | 33 | Please set `--eval` to `1` and `--resume` to the saved path of the downloaded model, 
*e.g.*, `/trained/model/path/photo/checkpoint.pth`. Then you can simply run: 34 | 35 | ``` 36 | python main_gfnet.py --target $domain --data 'PACS' --device $device --eval 1 --resume '/trained/model/path/photo/checkpoint.pth' 37 | ``` 38 | 39 | ### Training 40 | 41 | First, download the GFNet-H-Ti model pretrained on ImageNet from [here](https://drive.google.com/file/d/1Nrq5sfHD9RklCMl6WkcVrAWI5vSVzwSm/view) and save it to `/pretrained_model`. To run ALOFT-E, run the following command. Note that the `--data_root` argument needs to be changed according to your folder. 42 | 43 | ``` 44 | bash ALOFT-E.sh 45 | ``` 46 | 47 | You can also train the ALOFT-S model by running the following command: 48 | 49 | ``` 50 | bash ALOFT-S.sh 51 | ``` 52 | 53 | ### Acknowledgement 54 | Part of our code is borrowed from the following repositories. 55 | * [GFNet](https://github.com/raoyongming/GFNet): "Global Filter Networks for Image Classification", NeurIPS 2021 56 | * [DSU](https://github.com/lixiaotong97/DSU): "Uncertainty modeling for out-of-distribution generalization", ICLR 2022 57 | 58 | We thank the authors for releasing their code. Please also consider citing their works. 59 | -------------------------------------------------------------------------------- /data/FourierTransform.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import random 3 | import torch 4 | import numpy as np 5 | from math import sqrt 6 | import os 7 | 8 | from PIL import Image 9 | 10 | 11 | 12 | def colorful_spectrum_mix(img1, img2, alpha, ratio=1.0, Fourier_swap=0, Fourier_phase=0): 13 | """Input image size: ndarray of [H, W, C]""" 14 | if Fourier_swap == 1: 15 | lam = 1.0 16 | else: 17 | lam = np.random.uniform(0, alpha) 18 | 19 | img1 = np.array(img1, dtype=float) 20 | img2 = np.array(img2, dtype=float) 21 | assert img1.shape == img2.shape 22 | h, w, c = img1.shape 23 | h_crop = int(h * sqrt(ratio)) 24 | w_crop = int(w * sqrt(ratio)) 25 | h_start = h // 2 - h_crop // 2 26 | w_start = w // 2 - w_crop // 2 27 | 28 | img1_fft = np.fft.fft2(img1, axes=(0, 1)) 29 | img2_fft = np.fft.fft2(img2, axes=(0, 1)) 30 | img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft) 31 | img2_abs, img2_pha = np.abs(img2_fft), np.angle(img2_fft) 32 | 33 | if Fourier_phase == 1: 34 | cont_abs_1 = img1_abs.mean() 35 | cont_abs_2 = img2_abs.mean() 36 | img1_pha = cont_abs_1 * (np.e ** (1j * img1_pha)) 37 | img2_pha = cont_abs_2 * (np.e ** (1j * img2_pha)) 38 | img1_pha = np.real(np.fft.ifft2(img1_pha, axes=(0, 1))) 39 | img2_pha = np.real(np.fft.ifft2(img2_pha, axes=(0, 1))) 40 | img1_pha = np.uint8(np.clip(img1_pha, 0, 255)) 41 | img2_pha = np.uint8(np.clip(img2_pha, 0, 255)) 42 | return img1_pha, img2_pha, lam 43 | # img1_pha = np.uint8(np.clip(img1_pha, 0, 255)) 44 | # img2_pha = np.uint8(np.clip(img2_pha, 0, 255)) 45 | # return img1_pha, img2_pha, lam 46 | 47 | img1_abs = np.fft.fftshift(img1_abs, axes=(0, 1)) 48 | img2_abs = np.fft.fftshift(img2_abs, axes=(0, 1)) 49 | 50 | img1_abs_ = np.copy(img1_abs) 51 | img2_abs_ = np.copy(img2_abs) 52 | img1_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \ 53 | lam * img2_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + (1 - lam) * img1_abs_[ 54 | h_start:h_start + h_crop, 55 | w_start:w_start + w_crop] 56 | img2_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \ 57 | lam * img1_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + (1 - lam) * img2_abs_[ 58 | h_start:h_start + h_crop, 
59 | w_start:w_start + w_crop] 60 | 61 | img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1)) 62 | img2_abs = np.fft.ifftshift(img2_abs, axes=(0, 1)) 63 | 64 | img21 = img1_abs * (np.e ** (1j * img1_pha)) 65 | img12 = img2_abs * (np.e ** (1j * img2_pha)) 66 | img21 = np.real(np.fft.ifft2(img21, axes=(0, 1))) 67 | img12 = np.real(np.fft.ifft2(img12, axes=(0, 1))) 68 | img21 = np.uint8(np.clip(img21, 0, 255)) 69 | img12 = np.uint8(np.clip(img12, 0, 255)) 70 | 71 | return img21, img12, lam 72 | 73 | 74 | def filter_pass(img1, S=10, high_or_low=0): 75 | """Input image size: ndarray of [H, W, C]""" 76 | 77 | img1 = np.array(img1, dtype=float) 78 | h, w, c = img1.shape 79 | h_start = h // 2 - S // 2 80 | w_start = w // 2 - S // 2 81 | 82 | img1_fft = np.fft.fft2(img1, axes=(0, 1)) 83 | 84 | img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft) 85 | img1_abs = np.fft.fftshift(img1_abs, axes=(0, 1)) 86 | if high_or_low == 1: 87 | # low-pass 88 | # masks = torch.zeros_like(img1_abs) 89 | masks = np.zeros_like(img1_abs) 90 | h_start = h // 2 - S // 2 91 | w_start = w // 2 - S // 2 92 | masks[h_start:h_start + S, w_start:w_start + S, :] = 1 93 | img1_abs = img1_abs * masks 94 | else: 95 | # high-pass 96 | # masks = torch.ones_like(img1_abs) 97 | masks = np.ones_like(img1_abs) 98 | h_start = S // 2 99 | w_start = S // 2 100 | masks[h_start:(h_start + h - S), w_start:w_start + h - S, :] = 0 101 | img1_abs = img1_abs * masks 102 | img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1)) 103 | img1 = img1_abs * (np.e ** (1j * img1_pha)) 104 | 105 | img1 = np.real(np.fft.ifft2(img1, axes=(0, 1))) 106 | img1 = np.uint8(np.clip(img1, 0, 255)) 107 | return img1 108 | 109 | 110 | class FourierTransform: 111 | def __init__(self, args, alpha=1.0, dataset_list=None, base_dir=None): 112 | self.alpha = alpha 113 | # self.from_domain = args.from_domain 114 | # self.Fourier_swap = args.Fourier_swap 115 | # self.Fourier_phase = args.Fourier_phase 116 | self.dataset_list = dataset_list 117 | self.base_dir = base_dir 118 | self.dataset = args.data 119 | 120 | self.filter_flag = args.freq_analyse 121 | self.filter_S = args.freq_analyse_S 122 | self.high_or_low = args.freq_analyse_high_or_low 123 | 124 | domain_path = os.path.join(self.base_dir, self.dataset_list[0]) 125 | if self.dataset == "VLCS": 126 | domain_path = os.path.join(domain_path, "full") 127 | self.class_names = sorted(os.listdir(domain_path)) 128 | 129 | self.pre_transform = transforms.Compose( 130 | [transforms.RandomResizedCrop(args.image_size, scale=(args.min_scale, 1.0)), 131 | transforms.RandomHorizontalFlip(args.random_horiz_flip), 132 | transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, 133 | hue=min(0.5, args.jitter)) 134 | ] 135 | ) 136 | self.post_transform = transforms.Compose([ 137 | transforms.ToTensor(), 138 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 139 | ]) 140 | 141 | def __call__(self, img, domain_label): 142 | img_o = self.pre_transform(img) 143 | if self.filter_flag == 1: 144 | img_s2o = filter_pass(img_o, S=self.filter_S, high_or_low=self.high_or_low) 145 | domain_s = None 146 | lam = None 147 | else: 148 | img_s, label_s, domain_s = self.sample_image(domain_label) 149 | img_s2o, img_o2s, lam = colorful_spectrum_mix(img_o, img_s, alpha=self.alpha, Fourier_swap=self.Fourier_swap, 150 | Fourier_phase=self.Fourier_phase) 151 | img_s2o = self.post_transform(img_s2o) 152 | 153 | return img_s2o, domain_s, lam 154 | 155 | def sample_image(self, domain_label): 156 | if 
self.from_domain == 'all': 157 | domain_idx = random.randint(0, len(self.dataset_list) - 1) 158 | elif self.from_domain == 'inter': 159 | domains = list(range(len(self.dataset_list))) 160 | domains.remove(domain_label) 161 | domain_idx = random.sample(domains, 1)[0] 162 | elif self.from_domain == 'intra': 163 | domain_idx = domain_label 164 | else: 165 | raise ValueError("Not implemented") 166 | other_domain_name = self.dataset_list[domain_idx] 167 | class_idx = random.randint(0, len(self.class_names)-1) 168 | other_class_name = self.class_names[class_idx] 169 | base_dir_domain = os.path.join(self.base_dir, other_domain_name) 170 | if self.dataset == "VLCS": 171 | base_dir_domain = os.path.join(base_dir_domain, "full") 172 | base_dir_domain_class = os.path.join(base_dir_domain, other_class_name) 173 | other_id = np.random.choice(os.listdir(base_dir_domain_class)) 174 | other_img = Image.open(os.path.join(base_dir_domain_class, other_id)).convert('RGB') 175 | 176 | return self.pre_transform(other_img), class_idx, domain_idx -------------------------------------------------------------------------------- /data/JigsawLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from random import sample, random 8 | import sys 9 | import os 10 | 11 | 12 | def get_random_subset(names, labels, percent): 13 | """ 14 | 15 | :param names: list of names 16 | :param labels: list of labels 17 | :param percent: 0 < float < 1 18 | :return: 19 | """ 20 | samples = len(names) 21 | amount = int(samples * percent) 22 | random_index = sample(range(samples), amount) 23 | name_val = [names[k] for k in random_index] 24 | name_train = [v for k, v in enumerate(names) if k not in random_index] 25 | labels_val = [labels[k] for k in random_index] 26 | labels_train = [v for k, v in enumerate(labels) if k not in random_index] 27 | return name_train, name_val, labels_train, labels_val 28 | 29 | 30 | def _dataset_info(txt_labels): 31 | # read from the official split txt 32 | file_names = [] 33 | labels = [] 34 | 35 | # with open(txt_labels, 'r') as f: 36 | # images_list = f.readlines() 37 | # for row in images_list: 38 | # row = row.split(' ') 39 | # file_names.append(row[0]) 40 | # labels.append(int(row[1])) 41 | 42 | for row in open(txt_labels, 'r'): 43 | row = row.split(' ') 44 | file_names.append(row[0]) 45 | labels.append(int(row[1])) 46 | 47 | return file_names, labels 48 | 49 | 50 | def find_classes(dir_name): 51 | if sys.version_info >= (3, 5): 52 | # Faster and available in Python 3.5 and above 53 | classes = [d.name for d in os.scandir(dir_name) if d.is_dir()] 54 | else: 55 | classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))] 56 | classes.sort() 57 | class_to_idx = {classes[i]: i+1 for i in range(len(classes))} 58 | return classes, class_to_idx 59 | 60 | 61 | def get_split_domain_info_from_dir(domain_path, dataset_name=None, val_percentage=None, domain_label=None): 62 | # read from the directory 63 | domain_name = domain_path.split("/")[-1] 64 | if dataset_name == "VLCS": 65 | name_train, name_val, labels_train, labels_val = [], [], [], [] 66 | classes, class_to_idx = find_classes(domain_path + "/full") 67 | # full为train 68 | for i, item in enumerate(classes): 69 | class_path = domain_path + "/" + "full" + "/" + item 70 | for root, _, fnames in 
sorted(os.walk(class_path)): 71 | for fname in sorted(fnames): 72 | path = os.path.join(domain_name, "full", item, fname) 73 | name_train.append(path) 74 | labels_train.append(class_to_idx[item]) 75 | # test为val 76 | for i, item in enumerate(classes): 77 | class_path = domain_path + "/" + "test" + "/" + item 78 | for root, _, fnames in sorted(os.walk(class_path)): 79 | for fname in sorted(fnames): 80 | path = os.path.join(domain_name, "test", item, fname) 81 | name_val.append(path) 82 | labels_val.append(class_to_idx[item]) 83 | domain_label_train = [domain_label for i in range(len(labels_train))] 84 | domain_label_val = [domain_label for i in range(len(labels_val))] 85 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 86 | 87 | elif dataset_name == "digits_dg": 88 | name_train, name_val, labels_train, labels_val = [], [], [], [] 89 | classes, class_to_idx = find_classes(domain_path + "/train") 90 | # train 91 | for i, item in enumerate(classes): 92 | class_path = domain_path + "/" + "train" + "/" + item 93 | for root, _, fnames in sorted(os.walk(class_path)): 94 | for fname in sorted(fnames): 95 | path = os.path.join(domain_name, "train", item, fname) 96 | name_train.append(path) 97 | labels_train.append(class_to_idx[item]) 98 | # val 99 | for i, item in enumerate(classes): 100 | class_path = domain_path + "/" + "val" + "/" + item 101 | for root, _, fnames in sorted(os.walk(class_path)): 102 | for fname in sorted(fnames): 103 | path = os.path.join(domain_name, "val", item, fname) 104 | name_val.append(path) 105 | labels_val.append(class_to_idx[item]) 106 | 107 | # names = name_train + name_val 108 | # labels = labels_train + labels_val 109 | # name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, val_percentage) 110 | 111 | domain_label_train = [domain_label for i in range(len(labels_train))] 112 | domain_label_val = [domain_label for i in range(len(labels_val))] 113 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 114 | 115 | elif dataset_name == "OfficeHome" or "PACS" in dataset_name: 116 | names, labels = [], [] 117 | classes, class_to_idx = find_classes(domain_path) 118 | for i, item in enumerate(classes): 119 | class_path = domain_path + "/" + item 120 | for root, _, fnames in sorted(os.walk(class_path)): 121 | for fname in sorted(fnames): 122 | path = os.path.join(domain_name, item, fname) 123 | names.append(path) 124 | labels.append(class_to_idx[item]) 125 | name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, val_percentage) 126 | domain_label_train = [domain_label for i in range(len(labels_train))] 127 | domain_label_val = [domain_label for i in range(len(labels_val))] 128 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 129 | 130 | else: 131 | raise ValueError("dataset is wrong.") 132 | 133 | 134 | def get_split_dataset_info_from_txt(txt_path, domain, domain_label, val_percentage=None): 135 | if "PACS" in txt_path: 136 | train_name = "_train_kfold.txt" 137 | val_name = "_crossval_kfold.txt" 138 | 139 | train_txt = txt_path + "/" + domain + train_name 140 | val_txt = txt_path + "/" + domain + val_name 141 | 142 | train_names, train_labels = _dataset_info(train_txt) 143 | val_names, val_labels = _dataset_info(val_txt) 144 | train_domain_labels = [domain_label for i in range(len(train_labels))] 145 | val_domain_labels = [domain_label for i in range(len(val_labels))] 146 | return train_names, 
val_names, train_labels, val_labels, train_domain_labels, val_domain_labels 147 | 148 | elif "miniDomainNet" in txt_path: 149 | # begin at 0, need to add 1 150 | train_name = "_train.txt" 151 | val_name = "_test.txt" 152 | train_txt = txt_path + "/" + domain + train_name 153 | val_txt = txt_path + "/" + domain + val_name 154 | 155 | train_names, train_labels = _dataset_info(train_txt) 156 | val_names, val_labels = _dataset_info(val_txt) 157 | train_labels = [label + 1 for label in train_labels] 158 | val_labels = [label + 1 for label in val_labels] 159 | 160 | names = train_names + val_names 161 | labels = train_labels + val_labels 162 | train_names, val_names, train_labels, val_labels = get_random_subset(names, labels, val_percentage) 163 | 164 | train_domain_labels = [domain_label for i in range(len(train_labels))] 165 | val_domain_labels = [domain_label for i in range(len(val_labels))] 166 | return train_names, val_names, train_labels, val_labels, train_domain_labels, val_domain_labels 167 | else: 168 | raise NotImplementedError 169 | 170 | 171 | def get_split_dataset_info(txt_list, val_percentage): 172 | names, labels = _dataset_info(txt_list) 173 | return get_random_subset(names, labels, val_percentage) 174 | 175 | 176 | # 原始Jigsaw 177 | class JigsawDataset(data.Dataset): 178 | def __init__(self, names, labels, jig_classes=100, img_transformer=None, tile_transformer=None, patches=True, bias_whole_image=None): 179 | self.data_path = "" 180 | self.names = names 181 | self.labels = labels 182 | 183 | self.N = len(self.names) 184 | self.permutations = self.__retrieve_permutations(jig_classes) 185 | self.grid_size = 3 186 | self.bias_whole_image = bias_whole_image 187 | if patches: 188 | self.patch_size = 64 189 | self._image_transformer = img_transformer 190 | self._augment_tile = tile_transformer 191 | if patches: 192 | self.returnFunc = lambda x: x 193 | else: 194 | def make_grid(x): 195 | return torchvision.utils.make_grid(x, self.grid_size, padding=0) 196 | self.returnFunc = make_grid 197 | 198 | def get_tile(self, img, n): 199 | w = float(img.size[0]) / self.grid_size 200 | y = int(n / self.grid_size) 201 | x = n % self.grid_size 202 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 203 | tile = self._augment_tile(tile) 204 | return tile 205 | 206 | def get_image(self, index): 207 | framename = self.data_path + '/' + self.names[index] 208 | img = Image.open(framename).convert('RGB') 209 | return self._image_transformer(img) 210 | 211 | def __getitem__(self, index): 212 | img = self.get_image(index) 213 | n_grids = self.grid_size ** 2 214 | tiles = [None] * n_grids 215 | for n in range(n_grids): 216 | tiles[n] = self.get_tile(img, n) 217 | 218 | order = np.random.randint(len(self.permutations) + 1) # added 1 for class 0: unsorted 219 | if self.bias_whole_image: 220 | if self.bias_whole_image > random(): 221 | order = 0 222 | if order == 0: 223 | data = tiles 224 | else: 225 | data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)] 226 | 227 | data = torch.stack(data, 0) 228 | return self.returnFunc(data), int(order), int(self.labels[index]) 229 | 230 | def __len__(self): 231 | return len(self.names) 232 | 233 | def __retrieve_permutations(self, classes): 234 | all_perm = np.load('permutations_%d.npy' % (classes)) 235 | # from range [1,9] to [0,8] 236 | if all_perm.min() == 1: 237 | all_perm = all_perm - 1 238 | 239 | return all_perm 240 | 241 | 242 | class JigsawTestDataset(JigsawDataset): 243 | def __init__(self, *args, **xargs): 244 | 
super().__init__(*args, **xargs) 245 | 246 | def __getitem__(self, index): 247 | framename = self.data_path + '/' + self.names[index] 248 | img = Image.open(framename).convert('RGB') 249 | return self._image_transformer(img), 0, int(self.labels[index]) 250 | 251 | 252 | class JigsawTestDatasetMultiple(JigsawDataset): 253 | def __init__(self, *args, **xargs): 254 | super().__init__(*args, **xargs) 255 | self._image_transformer = transforms.Compose([ 256 | transforms.Resize(255, Image.BILINEAR), 257 | ]) 258 | self._image_transformer_full = transforms.Compose([ 259 | transforms.Resize(225, Image.BILINEAR), 260 | transforms.ToTensor(), 261 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 262 | ]) 263 | self._augment_tile = transforms.Compose([ 264 | transforms.Resize((75, 75), Image.BILINEAR), 265 | transforms.ToTensor(), 266 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 267 | ]) 268 | 269 | def __getitem__(self, index): 270 | framename = self.data_path + '/' + self.names[index] 271 | _img = Image.open(framename).convert('RGB') 272 | img = self._image_transformer(_img) 273 | 274 | w = float(img.size[0]) / self.grid_size 275 | n_grids = self.grid_size ** 2 276 | images = [] 277 | jig_labels = [] 278 | tiles = [None] * n_grids 279 | for n in range(n_grids): 280 | y = int(n / self.grid_size) 281 | x = n % self.grid_size 282 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 283 | tile = self._augment_tile(tile) 284 | tiles[n] = tile 285 | for order in range(0, len(self.permutations)+1, 3): 286 | if order==0: 287 | data = tiles 288 | else: 289 | data = [tiles[self.permutations[order-1][t]] for t in range(n_grids)] 290 | data = self.returnFunc(torch.stack(data, 0)) 291 | images.append(data) 292 | jig_labels.append(order) 293 | images = torch.stack(images, 0) 294 | jig_labels = torch.LongTensor(jig_labels) 295 | return images, jig_labels, int(self.labels[index]) 296 | 297 | 298 | class JigsawNewDataset(data.Dataset): 299 | def __init__(self, names, labels, domain_labels, dataset_path, jig_classes=100, img_transformer=None, 300 | tile_transformer=None, patches=True, bias_whole_image=None): 301 | self.data_path = dataset_path 302 | 303 | self.names = names 304 | self.labels = labels 305 | self.domain_labels = domain_labels 306 | 307 | self.N = len(self.names) 308 | # self.permutations = self.__retrieve_permutations(jig_classes) 309 | self.grid_size = 3 310 | self.bias_whole_image = bias_whole_image 311 | if patches: 312 | self.patch_size = 64 313 | self._image_transformer = img_transformer 314 | self._augment_tile = tile_transformer 315 | if patches: 316 | self.returnFunc = lambda x: x 317 | else: 318 | def make_grid(x): 319 | return torchvision.utils.make_grid(x, self.grid_size, padding=0) 320 | 321 | self.returnFunc = make_grid 322 | 323 | def get_tile(self, img, n): 324 | w = float(img.size[0]) / self.grid_size 325 | y = int(n / self.grid_size) 326 | x = n % self.grid_size 327 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 328 | tile = self._augment_tile(tile) 329 | return tile 330 | 331 | def get_image(self, index): 332 | framename = self.data_path + '/' + self.names[index] 333 | img = Image.open(framename).convert('RGB') 334 | return self._image_transformer(img) 335 | 336 | def __getitem__(self, index): 337 | framename = self.data_path + '/' + self.names[index] 338 | img = Image.open(framename).convert('RGB') 339 | # image, image_randaug, label, domain 340 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), 
int(self.domain_labels[index] - 1) 341 | return self._image_transformer(img), int(self.labels[index] - 1) 342 | 343 | def __len__(self): 344 | return len(self.names) 345 | 346 | def __retrieve_permutations(self, classes): 347 | all_perm = np.load('permutations_%d.npy' % (classes)) 348 | # from range [1,9] to [0,8] 349 | if all_perm.min() == 1: 350 | all_perm = all_perm - 1 351 | return all_perm 352 | 353 | 354 | class JigsawTestNewDataset(JigsawNewDataset): 355 | def __init__(self, *args, **xargs): 356 | super().__init__(*args, **xargs) 357 | 358 | def __getitem__(self, index): 359 | framename = self.data_path + '/' + self.names[index] 360 | img = Image.open(framename).convert('RGB') 361 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 362 | return self._image_transformer(img), int(self.labels[index] - 1) 363 | 364 | 365 | # from .randaug import RandAugment 366 | # 367 | # class JigsawDatasetRandAug(data.Dataset): 368 | # def __init__(self, names, labels, domain_labels, patches=True, dataset_path=None, img_transformer=None, 369 | # bias_whole_image=None, args=None): 370 | # self.data_path = dataset_path 371 | # 372 | # self.names = names 373 | # self.labels = labels 374 | # self.domain_labels = domain_labels 375 | # 376 | # self.N = len(self.names) 377 | # self.grid_size = 3 378 | # self.bias_whole_image = bias_whole_image 379 | # if patches: 380 | # self.patch_size = 64 381 | # # self._image_transformer = img_transformer 382 | # self._image_transformer = img_transformer 383 | # self._image_transformer_aug = RandAugment(args) 384 | # # self._image_transformer_val = img_transformer_val 385 | # 386 | # # def get_image(self, index): 387 | # # framename = self.data_path + '/' + self.names[index] 388 | # # img = Image.open(framename).convert('RGB') 389 | # # return self._image_transformer(img), self._image_transformer_aug(img) 390 | # 391 | # def __getitem__(self, index): 392 | # framename = self.data_path + '/' + self.names[index] 393 | # img = Image.open(framename).convert('RGB') 394 | # img_randaug, _ = self._image_transformer_aug(img) 395 | # # img_randaug = self._image_transformer_val(img_randaug) 396 | # return self._image_transformer(img), img_randaug, int(self.labels[index] - 1), int(self.domain_labels[index]-1) 397 | # 398 | # def __len__(self): 399 | # return len(self.names) 400 | # 401 | # class JigsawTestDatasetRandAug(JigsawDatasetRandAug): 402 | # def __init__(self, *args, **xargs): 403 | # super().__init__(*args, **xargs) 404 | # 405 | # def __getitem__(self, index): 406 | # framename = self.data_path + '/' + self.names[index] 407 | # img = Image.open(framename).convert('RGB') 408 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 409 | # 410 | # 411 | from .FourierTransform import FourierTransform 412 | 413 | class JigsawDatasetFourier(data.Dataset): 414 | def __init__(self, names, labels, domain_labels, dataset_path=None, img_transformer=None, args=None, 415 | dataset_list=None): 416 | self.data_path = dataset_path 417 | 418 | self.names = names 419 | self.labels = labels 420 | self.domain_labels = domain_labels 421 | 422 | self._image_transformer = img_transformer 423 | self._image_transformer_aug = FourierTransform(args, dataset_list=dataset_list, base_dir=dataset_path) 424 | self.Fourier_swap = args.Fourier_swap 425 | # self._image_transformer_val = img_transformer_val 426 | 427 | def __getitem__(self, index): 428 | framename = self.data_path + '/' + 
self.names[index] 429 | img = Image.open(framename).convert('RGB') 430 | class_label = int(self.labels[index] - 1) 431 | domain_label = int(self.domain_labels[index]-1) 432 | img_randaug, domain_s, lam = self._image_transformer_aug(img, domain_label) 433 | # img_randaug = self._image_transformer_val(img_randaug) 434 | if self.Fourier_swap == 1: 435 | domain_label = [domain_s, domain_label] 436 | else: 437 | domain_label = [domain_label, domain_label] 438 | return self._image_transformer(img), img_randaug, class_label, domain_label 439 | 440 | def __len__(self): 441 | return len(self.names) 442 | 443 | class JigsawTestDatasetFourier(JigsawDatasetFourier): 444 | def __init__(self, *args, **xargs): 445 | super().__init__(*args, **xargs) 446 | 447 | def __getitem__(self, index): 448 | framename = self.data_path + '/' + self.names[index] 449 | img = Image.open(framename).convert('RGB') 450 | return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 451 | 452 | class JigsawTestDatasetFreqAnalyse(JigsawDatasetFourier): 453 | def __init__(self, *args, **xargs): 454 | super().__init__(*args, **xargs) 455 | 456 | def __getitem__(self, index): 457 | framename = self.data_path + '/' + self.names[index] 458 | img = Image.open(framename).convert('RGB') 459 | class_label = int(self.labels[index] - 1) 460 | domain_label = int(self.domain_labels[index]-1) 461 | img_aug, domain_s, lam = self._image_transformer_aug(img, domain_label) 462 | return img_aug, class_label 463 | # 464 | # 465 | # from .Tobias import Tobias 466 | # 467 | # class JigsawDatasetTobias(data.Dataset): 468 | # def __init__(self, names, labels, domain_labels, dataset_path=None, img_transformer=None, args=None, 469 | # dataset_list=None): 470 | # self.data_path = dataset_path 471 | # 472 | # self.names = names 473 | # self.labels = labels 474 | # self.domain_labels = domain_labels 475 | # 476 | # self._image_transformer = img_transformer 477 | # self._image_transformer_aug = Tobias(args, dataset_list=dataset_list, base_dir=dataset_path) 478 | # 479 | # def __getitem__(self, index): 480 | # framename = self.data_path + '/' + self.names[index] 481 | # img = Image.open(framename).convert('RGB') 482 | # class_label = int(self.labels[index] - 1) 483 | # domain_label = int(self.domain_labels[index]-1) 484 | # img_randaug, domain_s = self._image_transformer_aug(img=img, img_name=self.names[index], 485 | # domain_label=domain_label) 486 | # return self._image_transformer(img), img_randaug, class_label, domain_label 487 | # 488 | # def __len__(self): 489 | # return len(self.names) 490 | # 491 | # 492 | # class JigsawTestDatasetTobias(JigsawDatasetTobias): 493 | # def __init__(self, *args, **xargs): 494 | # super().__init__(*args, **xargs) 495 | # 496 | # def __getitem__(self, index): 497 | # framename = self.data_path + '/' + self.names[index] 498 | # img = Image.open(framename).convert('RGB') 499 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 500 | -------------------------------------------------------------------------------- /data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch.utils.data import Dataset 5 | 6 | # This is a small variant of the ConcatDataset class, which also returns dataset index 7 | from data.JigsawLoader import JigsawTestDatasetMultiple 8 | 9 | 10 | class ConcatDataset(Dataset): 11 | """ 12 | Dataset to 
concatenate multiple datasets. 13 | Purpose: useful to assemble different existing datasets, possibly 14 | large-scale datasets as the concatenation operation is done in an 15 | on-the-fly manner. 16 | 17 | Arguments: 18 | datasets (sequence): List of datasets to be concatenated 19 | """ 20 | 21 | @staticmethod 22 | def cumsum(sequence): 23 | r, s = [], 0 24 | for e in sequence: 25 | l = len(e) 26 | r.append(l + s) 27 | s += l 28 | return r 29 | 30 | def isMulti(self): 31 | return isinstance(self.datasets[0], JigsawTestDatasetMultiple) 32 | 33 | def __init__(self, datasets): 34 | super(ConcatDataset, self).__init__() 35 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 36 | self.datasets = list(datasets) 37 | self.cumulative_sizes = self.cumsum(self.datasets) 38 | 39 | def __len__(self): 40 | return self.cumulative_sizes[-1] 41 | 42 | def __getitem__(self, idx): 43 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 44 | if dataset_idx == 0: 45 | sample_idx = idx 46 | else: 47 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 48 | return self.datasets[dataset_idx][sample_idx], dataset_idx 49 | 50 | @property 51 | def cummulative_sizes(self): 52 | warnings.warn("cummulative_sizes attribute is renamed to " 53 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 54 | return self.cumulative_sizes 55 | -------------------------------------------------------------------------------- /data/data_helper.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from data.JigsawLoader import * 6 | from data.concat_dataset import ConcatDataset 7 | from data.JigsawLoader import JigsawNewDataset, JigsawTestNewDataset 8 | 9 | 10 | class Subset(torch.utils.data.Dataset): 11 | def __init__(self, dataset, limit): 12 | indices = torch.randperm(len(dataset))[:limit] 13 | self.dataset = dataset 14 | self.indices = indices 15 | 16 | def __getitem__(self, idx): 17 | return self.dataset[self.indices[idx]] 18 | 19 | def __len__(self): 20 | return len(self.indices) 21 | 22 | 23 | def get_train_dataloader(args, patches): 24 | dataset_list = args.source 25 | assert isinstance(dataset_list, list) 26 | datasets = [] 27 | val_datasets = [] 28 | 29 | img_transformer, tile_transformer = get_train_transformers(args) 30 | img_transformer_val = get_val_transformer(args) 31 | 32 | limit = None 33 | 34 | if "PACS" in args.data_root: 35 | dataset_path = join(args.data_root, "kfold") 36 | elif args.data == "miniDomainNet": 37 | dataset_path = "/data/DataSets/" + "DomainNet" 38 | else: 39 | dataset_path = args.data_root 40 | 41 | for i, dname in enumerate(dataset_list): 42 | if args.data == "PACS": 43 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 44 | get_split_dataset_info_from_txt(txt_path=join(args.data_root, "pacs_label"), domain=dname, 45 | domain_label=i+1) 46 | # get_split_dataset_info_from_txt(txt_path=join(args.data_root, "splits"), domain=dname, 47 | # domain_label=i + 1) 48 | elif args.data == "miniDomainNet": 49 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 50 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=dname, domain_label=i+1, 51 | val_percentage=args.val_size) 52 | else: 53 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 54 | get_split_domain_info_from_dir(join(dataset_path, dname), 
dataset_name=args.data, 55 | val_percentage=args.val_size, domain_label=i+1) 56 | 57 | # if args.RandAug_flag == 1: 58 | # train_dataset = JigsawDatasetRandAug(name_train, labels_train, domain_labels_train, 59 | # dataset_path=dataset_path, patches=patches, 60 | # img_transformer=img_transformer, 61 | # bias_whole_image=args.bias_whole_image, args=args) 62 | # elif args.Fourier_flag == 1: 63 | # train_dataset = JigsawDatasetFourier(name_train, labels_train, domain_labels_train, 64 | # dataset_path=dataset_path, img_transformer=img_transformer, args=args, 65 | # dataset_list=dataset_list) 66 | # elif args.tobias_flag == 1: 67 | # train_dataset = JigsawDatasetTobias(name_train, labels_train, domain_labels_train, 68 | # dataset_path=dataset_path, img_transformer=img_transformer, args=args, 69 | # dataset_list=dataset_list) 70 | # else: 71 | train_dataset = JigsawNewDataset(name_train, labels_train, domain_labels_train, 72 | dataset_path=dataset_path, patches=patches, 73 | img_transformer=img_transformer, tile_transformer=tile_transformer, 74 | jig_classes=30) 75 | if limit: 76 | train_dataset = Subset(train_dataset, limit) 77 | datasets.append(train_dataset) 78 | if args.freq_analyse == 1: 79 | val_datasets.append( 80 | JigsawTestDatasetFreqAnalyse(name_val, labels_val, domain_labels_val, dataset_path=dataset_path, 81 | img_transformer=img_transformer_val, args=args, dataset_list=dataset_list)) 82 | else: 83 | val_datasets.append( 84 | JigsawTestNewDataset(name_val, labels_val, domain_labels_val, dataset_path=dataset_path, 85 | img_transformer=img_transformer_val, patches=patches, jig_classes=30)) 86 | dataset = ConcatDataset(datasets) 87 | val_dataset = ConcatDataset(val_datasets) 88 | 89 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, 90 | pin_memory=True, drop_last=True) 91 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, 92 | pin_memory=True, drop_last=False) 93 | return loader, val_loader 94 | 95 | 96 | def get_val_dataloader(args, patches=False, tSNE_flag=0): 97 | if "PACS" in args.data_root: 98 | dataset_path = join(args.data_root, "kfold") 99 | elif args.data == "miniDomainNet": 100 | dataset_path = "/data/DataSets/" + "DomainNet" 101 | else: 102 | dataset_path = args.data_root 103 | 104 | if args.data == "miniDomainNet": 105 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = \ 106 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=args.target, domain_label=0, 107 | val_percentage=args.val_size) 108 | else: 109 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = get_split_domain_info_from_dir( 110 | join(dataset_path, args.target), dataset_name=args.data, val_percentage=args.val_size, domain_label=0) 111 | 112 | if tSNE_flag == 0: 113 | names = name_train + name_val 114 | labels = labels_train + labels_val 115 | domain_label = domain_label_train + domain_label_val 116 | else: 117 | names = name_val 118 | labels = labels_val 119 | domain_label = domain_label_val 120 | 121 | img_tr = get_val_transformer(args) 122 | dataset_list = args.source 123 | if args.freq_analyse == 1: 124 | val_dataset = JigsawTestDatasetFreqAnalyse(names, labels, domain_label, dataset_path=dataset_path, 125 | img_transformer=img_tr, args=args, dataset_list=dataset_list) 126 | else: 127 | val_dataset = JigsawTestNewDataset(names, labels, domain_label, dataset_path=dataset_path, patches=patches, 128 | 
img_transformer=img_tr, jig_classes=30) 129 | 130 | # if args.limit_target and len(val_dataset) > args.limit_target: 131 | # val_dataset = Subset(val_dataset, args.limit_target) 132 | # print("Using %d subset of val dataset" % args.limit_target) 133 | 134 | dataset = ConcatDataset([val_dataset]) 135 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, 136 | pin_memory=True, drop_last=False) 137 | return loader 138 | 139 | 140 | def get_train_transformers(args): 141 | 142 | img_tr = [transforms.RandomResizedCrop((int(args.image_size), int(args.image_size)), (args.min_scale, args.max_scale))] 143 | if args.random_horiz_flip > 0.0: 144 | img_tr.append(transforms.RandomHorizontalFlip(args.random_horiz_flip)) 145 | if args.jitter > 0.0: 146 | img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter))) 147 | 148 | # this is special operation for JigenDG 149 | if args.gray_flag: 150 | img_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale)) 151 | 152 | img_tr.append(transforms.ToTensor()) 153 | img_tr.append(transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) 154 | 155 | tile_tr = [] 156 | if args.tile_random_grayscale: 157 | tile_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale)) 158 | tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 159 | 160 | return transforms.Compose(img_tr), transforms.Compose(tile_tr) 161 | 162 | 163 | def get_val_transformer(args): 164 | img_tr = [ 165 | transforms.Resize((args.image_size, args.image_size)), 166 | transforms.ToTensor(), 167 | transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 168 | ] 169 | return transforms.Compose(img_tr) 170 | 171 | 172 | # def get_train_dataloader_RandAug(args, patches): 173 | # dataset_list = args.source 174 | # assert isinstance(dataset_list, list) 175 | # datasets = [] 176 | # val_datasets = [] 177 | # img_transformer, tile_transformer = get_train_transformers(args) 178 | # limit = args.limit_source 179 | # for dname in dataset_list: 180 | # name_train, labels_train = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_train_kfold.txt' % dname)) 181 | # name_val, labels_val = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_crossval_kfold.txt' % dname)) 182 | # 183 | # train_dataset = JigsawDatasetRandAug(name_train, labels_train, patches=patches, img_transformer=img_transformer, 184 | # bias_whole_image=args.bias_whole_image, args=args) 185 | # if limit: 186 | # train_dataset = Subset(train_dataset, limit) 187 | # datasets.append(train_dataset) 188 | # val_datasets.append( 189 | # JigsawTestDatasetRandAug(name_val, labels_val, img_transformer=get_val_transformer(args), 190 | # patches=patches, args=args)) 191 | # dataset = ConcatDataset(datasets) 192 | # val_dataset = ConcatDataset(val_datasets) 193 | # loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 194 | # val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False) 195 | # return loader, val_loader 196 | 197 | 198 | # def get_val_dataloader_RandAug(args, patches=False): 199 | # names, labels = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_test_kfold.txt' % args.target)) 200 | # img_tr = get_val_transformer(args) 201 | # 
val_dataset = JigsawTestDatasetRandAug(names, labels, patches=patches, img_transformer=img_tr, args=args) 202 | # if args.limit_target and len(val_dataset) > args.limit_target: 203 | # val_dataset = Subset(val_dataset, args.limit_target) 204 | # print("Using %d subset of val dataset" % args.limit_target) 205 | # dataset = ConcatDataset([val_dataset]) 206 | # loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False) 207 | # return loader 208 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import random 4 | from collections import defaultdict 5 | from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler 6 | import math 7 | 8 | 9 | 10 | class BatchSchedulerSampler(Sampler): 11 | """ 12 | iterate over tasks and provide a random batch per task in each mini-batch 13 | """ 14 | def __init__(self, dataset, batch_size): 15 | self.dataset = dataset 16 | self.batch_size = batch_size 17 | self.number_of_datasets = len(dataset.datasets) 18 | self.mini_batch_size = int(batch_size / self.number_of_datasets) 19 | # self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets]) 20 | self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets]) 21 | 22 | def __len__(self): 23 | return self.mini_batch_size * math.ceil(self.largest_dataset_size / self.mini_batch_size) * len(self.dataset.datasets) 24 | 25 | def __iter__(self): 26 | samplers_list = [] 27 | sampler_iterators = [] 28 | for dataset_idx in range(self.number_of_datasets): 29 | cur_dataset = self.dataset.datasets[dataset_idx] 30 | sampler = RandomSampler(cur_dataset) 31 | samplers_list.append(sampler) 32 | cur_sampler_iterator = sampler.__iter__() 33 | sampler_iterators.append(cur_sampler_iterator) 34 | 35 | push_index_val = [0] + self.dataset.cumulative_sizes[:-1] 36 | step = self.batch_size 37 | # for this case we want to get all samples in dataset, this force us to resample from the smaller datasets 38 | epoch_samples = self.largest_dataset_size * self.number_of_datasets 39 | 40 | final_samples_list = [] # this is a list of indexes from the combined dataset 41 | for _ in range(0, epoch_samples, step): 42 | for i in range(self.number_of_datasets): 43 | cur_batch_sampler = sampler_iterators[i] 44 | cur_samples = [] 45 | for _ in range(self.mini_batch_size): 46 | try: 47 | cur_sample_org = cur_batch_sampler.__next__() 48 | cur_sample = cur_sample_org + push_index_val[i] 49 | cur_samples.append(cur_sample) 50 | except StopIteration: 51 | # got to the end of iterator - restart the iterator and continue to get samples 52 | # until reaching "epoch_samples" 53 | sampler_iterators[i] = samplers_list[i].__iter__() 54 | cur_batch_sampler = sampler_iterators[i] 55 | cur_sample_org = cur_batch_sampler.__next__() 56 | cur_sample = cur_sample_org + push_index_val[i] 57 | cur_samples.append(cur_sample) 58 | final_samples_list.extend(cur_samples) 59 | 60 | return iter(final_samples_list) 61 | 62 | 63 | class RandomDomainSampler(Sampler): 64 | """Randomly samples N domains each with K images 65 | to form a minibatch of size N*K. 66 | Args: 67 | data_source (list): list of Datums. 68 | batch_size (int): batch size. 69 | n_domain (int): number of domains to sample in a minibatch. 
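    Example (an illustrative sketch only, not code from this repository; it assumes each
    entry of data_source is a Datum-like object exposing a .domain attribute, and that a
    matching dataset object is available):

        sampler = RandomDomainSampler(data_source, batch_size=96, n_domain=3)
        loader = torch.utils.data.DataLoader(dataset, batch_size=96,
                                             sampler=sampler, drop_last=True)
        # each batch then holds 96 / 3 = 32 images from each of the 3 sampled domains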
70 | """ 71 | 72 | def __init__(self, data_source, batch_size, n_domain): 73 | self.data_source = data_source 74 | 75 | # Keep track of image indices for each domain 76 | self.domain_dict = defaultdict(list) 77 | for i, item in enumerate(data_source): 78 | self.domain_dict[item.domain].append(i) 79 | self.domains = list(self.domain_dict.keys()) 80 | 81 | # Make sure each domain has equal number of images 82 | if n_domain is None or n_domain <= 0: 83 | n_domain = len(self.domains) 84 | assert batch_size % n_domain == 0 85 | self.n_img_per_domain = batch_size // n_domain 86 | 87 | self.batch_size = batch_size 88 | # n_domain denotes number of domains sampled in a minibatch 89 | self.n_domain = n_domain 90 | self.length = len(list(self.__iter__())) 91 | 92 | def __iter__(self): 93 | domain_dict = copy.deepcopy(self.domain_dict) 94 | final_idxs = [] 95 | stop_sampling = False 96 | 97 | while not stop_sampling: 98 | selected_domains = random.sample(self.domains, self.n_domain) 99 | 100 | for domain in selected_domains: 101 | idxs = domain_dict[domain] 102 | selected_idxs = random.sample(idxs, self.n_img_per_domain) 103 | final_idxs.extend(selected_idxs) 104 | 105 | for idx in selected_idxs: 106 | domain_dict[domain].remove(idx) 107 | 108 | remaining = len(domain_dict[domain]) 109 | if remaining < self.n_img_per_domain: 110 | stop_sampling = True 111 | 112 | return iter(final_idxs) 113 | 114 | def __len__(self): 115 | return self.length 116 | 117 | 118 | class SeqDomainSampler(Sampler): 119 | """Sequential domain sampler, which randomly samples K 120 | images from each domain to form a minibatch. 121 | Args: 122 | data_source (list): list of Datums. 123 | batch_size (int): batch size. 124 | """ 125 | 126 | def __init__(self, data_source, batch_size): 127 | self.data_source = data_source 128 | 129 | # Keep track of image indices for each domain 130 | self.domain_dict = defaultdict(list) 131 | for i, item in enumerate(data_source): 132 | self.domain_dict[item.domain].append(i) 133 | self.domains = list(self.domain_dict.keys()) 134 | self.domains.sort() 135 | 136 | # Make sure each domain has equal number of images 137 | n_domain = len(self.domains) 138 | assert batch_size % n_domain == 0 139 | self.n_img_per_domain = batch_size // n_domain 140 | 141 | self.batch_size = batch_size 142 | # n_domain denotes number of domains sampled in a minibatch 143 | self.n_domain = n_domain 144 | self.length = len(list(self.__iter__())) 145 | 146 | def __iter__(self): 147 | domain_dict = copy.deepcopy(self.domain_dict) 148 | final_idxs = [] 149 | stop_sampling = False 150 | 151 | while not stop_sampling: 152 | for domain in self.domains: 153 | idxs = domain_dict[domain] 154 | selected_idxs = random.sample(idxs, self.n_img_per_domain) 155 | final_idxs.extend(selected_idxs) 156 | 157 | for idx in selected_idxs: 158 | domain_dict[domain].remove(idx) 159 | 160 | remaining = len(domain_dict[domain]) 161 | if remaining < self.n_img_per_domain: 162 | stop_sampling = True 163 | 164 | return iter(final_idxs) 165 | 166 | def __len__(self): 167 | return self.length 168 | 169 | 170 | class RandomClassSampler(Sampler): 171 | """Randomly samples N classes each with K instances to 172 | form a minibatch of size N*K. 173 | Modified from https://github.com/KaiyangZhou/deep-person-reid. 174 | Args: 175 | data_source (list): list of Datums. 176 | batch_size (int): batch size. 177 | n_ins (int): number of instances per class to sample in a minibatch. 
178 | """ 179 | 180 | def __init__(self, data_source, batch_size, n_ins): 181 | if batch_size < n_ins: 182 | raise ValueError( 183 | "batch_size={} must be no less " 184 | "than n_ins={}".format(batch_size, n_ins) 185 | ) 186 | 187 | self.data_source = data_source 188 | self.batch_size = batch_size 189 | self.n_ins = n_ins 190 | self.ncls_per_batch = self.batch_size // self.n_ins 191 | self.index_dic = defaultdict(list) 192 | for index, item in enumerate(data_source): 193 | self.index_dic[item.label].append(index) 194 | self.labels = list(self.index_dic.keys()) 195 | assert len(self.labels) >= self.ncls_per_batch 196 | 197 | # estimate number of images in an epoch 198 | self.length = len(list(self.__iter__())) 199 | 200 | def __iter__(self): 201 | batch_idxs_dict = defaultdict(list) 202 | 203 | for label in self.labels: 204 | idxs = copy.deepcopy(self.index_dic[label]) 205 | if len(idxs) < self.n_ins: 206 | idxs = np.random.choice(idxs, size=self.n_ins, replace=True) 207 | random.shuffle(idxs) 208 | batch_idxs = [] 209 | for idx in idxs: 210 | batch_idxs.append(idx) 211 | if len(batch_idxs) == self.n_ins: 212 | batch_idxs_dict[label].append(batch_idxs) 213 | batch_idxs = [] 214 | 215 | avai_labels = copy.deepcopy(self.labels) 216 | final_idxs = [] 217 | 218 | while len(avai_labels) >= self.ncls_per_batch: 219 | selected_labels = random.sample(avai_labels, self.ncls_per_batch) 220 | for label in selected_labels: 221 | batch_idxs = batch_idxs_dict[label].pop(0) 222 | final_idxs.extend(batch_idxs) 223 | if len(batch_idxs_dict[label]) == 0: 224 | avai_labels.remove(label) 225 | 226 | return iter(final_idxs) 227 | 228 | def __len__(self): 229 | return self.length 230 | 231 | 232 | def build_sampler( 233 | sampler_type, 234 | cfg=None, 235 | data_source=None, 236 | batch_size=32, 237 | n_domain=0, 238 | n_ins=16 239 | ): 240 | if sampler_type == "RandomSampler": 241 | return RandomSampler(data_source) 242 | 243 | elif sampler_type == "SequentialSampler": 244 | return SequentialSampler(data_source) 245 | 246 | elif sampler_type == "RandomDomainSampler": 247 | return RandomDomainSampler(data_source, batch_size, n_domain) 248 | 249 | elif sampler_type == "SeqDomainSampler": 250 | return SeqDomainSampler(data_source, batch_size) 251 | 252 | elif sampler_type == "RandomClassSampler": 253 | return RandomClassSampler(data_source, batch_size, n_ins) 254 | 255 | else: 256 | raise ValueError("Unknown sampler type: {}".format(sampler_type)) 257 | 258 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
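Since main_gfnet.py is not part of this excerpt, the following is a minimal, hypothetical sketch of how the `train_one_epoch` and `evaluate` functions defined in this file are typically driven. The model, criterion, optimizer settings, and the loaders (assumed to come from `data_helper.get_train_dataloader` / `get_val_dataloader`, which yield `((images, labels), dataset_idx)` batches) are placeholders, not the repository's actual main script.

```python
# Hypothetical driver sketch; the lr value mirrors ALOFT-E.sh, everything else is illustrative.
import torch
from timm.utils import NativeScaler

from engine import train_one_epoch, evaluate


def fit(model, criterion, train_loader, val_loader, device, epochs=50, lr=5e-4):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    loss_scaler = NativeScaler()  # wraps torch.cuda.amp.GradScaler, as expected by train_one_epoch
    best_acc = 0.0
    for epoch in range(epochs):
        train_stats = train_one_epoch(model, criterion, train_loader,
                                      optimizer, device, epoch, loss_scaler)
        test_stats = evaluate(val_loader, model, device)
        best_acc = max(best_acc, test_stats["acc1"])
        print(f"epoch {epoch}: train loss {train_stats['loss']:.4f}, "
              f"val acc@1 {test_stats['acc1']:.2f} (best {best_acc:.2f})")
    return best_acc
```

Note that `train_one_epoch` calls `criterion(samples, outputs, targets)`, i.e. the `DistillationLoss` interface from `losses.py`, so a plain `nn.CrossEntropyLoss` cannot be passed directly as the training criterion.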
3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | from losses import DistillationLoss 16 | import utils 17 | 18 | import random 19 | import torch.nn.functional as F 20 | 21 | 22 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 25 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 26 | set_training_mode=True): 27 | model.train(set_training_mode) 28 | 29 | metric_logger = utils.MetricLogger(delimiter=" ") 30 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 31 | header = 'Epoch: [{}]'.format(epoch) 32 | print_freq = 200 33 | 34 | for (samples, targets), domains in metric_logger.log_every(data_loader, print_freq, header): 35 | samples = samples.to(device, non_blocking=True) 36 | targets = targets.to(device, non_blocking=True) 37 | 38 | if mixup_fn is not None: 39 | samples, targets = mixup_fn(samples, targets) 40 | 41 | with torch.cuda.amp.autocast(): 42 | outputs = model(samples) 43 | loss = criterion(samples, outputs, targets.long()) 44 | 45 | loss_value = loss.item() 46 | 47 | if not math.isfinite(loss_value): 48 | print("Loss is {}, stopping training".format(loss_value)) 49 | sys.exit(1) 50 | 51 | optimizer.zero_grad() 52 | 53 | # this attribute is added by timm on one optimizer (adahessian) 54 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 55 | loss_scaler(loss, optimizer, clip_grad=max_norm, 56 | parameters=model.parameters(), create_graph=is_second_order) 57 | 58 | torch.cuda.synchronize() 59 | if model_ema is not None: 60 | model_ema.update(model) 61 | 62 | metric_logger.update(loss=loss_value) 63 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 64 | # gather the stats from all processes 65 | metric_logger.synchronize_between_processes() 66 | # print("Averaged stats:", metric_logger) 67 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 68 | 69 | 70 | @torch.no_grad() 71 | def evaluate(data_loader, model, device): 72 | criterion = torch.nn.CrossEntropyLoss() 73 | 74 | metric_logger = utils.MetricLogger(delimiter=" ") 75 | header = 'Test:' 76 | 77 | # switch to evaluation mode 78 | model.eval() 79 | 80 | for (images, target), _ in metric_logger.log_every(data_loader, 200, header): 81 | images = images.to(device, non_blocking=True) 82 | target = target.to(device, non_blocking=True) 83 | 84 | # compute output 85 | with torch.cuda.amp.autocast(): 86 | output = model(images) 87 | loss = criterion(output, target) 88 | 89 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 90 | 91 | batch_size = images.shape[0] 92 | metric_logger.update(loss=loss.item()) 93 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 94 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 95 | # gather the stats from all processes 96 | metric_logger.synchronize_between_processes() 97 | # print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 98 | # .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 99 | 100 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 101 | 102 | 103 | @torch.no_grad() 104 | def 
get_feature(data_loader, model, device, norm_flag=0): 105 | metric_logger = utils.MetricLogger(delimiter=" ") 106 | header = 'Test:' 107 | 108 | # switch to evaluation mode 109 | model.eval() 110 | 111 | features = [] 112 | targets = [] 113 | for (images, target), _ in metric_logger.log_every(data_loader, 200, header): 114 | images = images.to(device, non_blocking=True) 115 | target = target.to(device, non_blocking=True) 116 | 117 | # compute output 118 | with torch.cuda.amp.autocast(): 119 | output, _ = model.forward_features(images) 120 | if norm_flag == 1: 121 | output = F.normalize(output, p=2, dim=1) 122 | features.append(output) 123 | targets.append(target) 124 | features = torch.cat(features, dim=0) 125 | targets = torch.cat(targets, dim=0) 126 | 127 | return features, targets 128 | 129 | -------------------------------------------------------------------------------- /gfnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import Error, deepcopy 6 | from re import S 7 | from numpy.lib.arraypad import pad 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | import torch.fft 16 | from torch.nn.modules.container import Sequential 17 | import random 18 | from math import sqrt 19 | from functools import partial, reduce 20 | from operator import mul 21 | 22 | _logger = logging.getLogger(__name__) 23 | 24 | 25 | def _cfg(url='', **kwargs): 26 | return { 27 | 'url': url, 28 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 29 | 'crop_pct': .9, 'interpolation': 'bicubic', 30 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 31 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 32 | **kwargs 33 | } 34 | 35 | 36 | class Mlp(nn.Module): 37 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 38 | super().__init__() 39 | out_features = out_features or in_features 40 | hidden_features = hidden_features or in_features 41 | self.fc1 = nn.Linear(in_features, hidden_features) 42 | self.act = act_layer() 43 | self.fc2 = nn.Linear(hidden_features, out_features) 44 | self.drop = nn.Dropout(drop) 45 | 46 | def forward(self, x): 47 | x = self.fc1(x) 48 | x = self.act(x) 49 | x = self.drop(x) 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class GlobalFilter(nn.Module): 56 | def __init__(self, dim, h=14, w=8, 57 | mask_radio=0.1, mask_alpha=0.5, 58 | noise_mode=1, 59 | uncertainty_model=0, perturb_prob=0.5, 60 | uncertainty_factor=1.0, 61 | noise_layer_flag=0, gauss_or_uniform=0, ): 62 | super().__init__() 63 | self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02) 64 | self.w = w 65 | self.h = h 66 | 67 | self.mask_radio = mask_radio 68 | 69 | self.noise_mode = noise_mode 70 | self.noise_layer_flag = noise_layer_flag 71 | 72 | self.alpha = mask_alpha 73 | 74 | self.eps = 1e-6 75 | self.factor = uncertainty_factor 76 | self.uncertainty_model = uncertainty_model 77 | self.p = perturb_prob 78 | self.gauss_or_uniform = gauss_or_uniform 79 | 80 | def _reparameterize(self, mu, std, epsilon_norm): 81 | # epsilon = torch.randn_like(std) * self.factor 82 | epsilon = epsilon_norm * self.factor 83 | mu_t = mu + 
epsilon * std 84 | return mu_t 85 | 86 | def spectrum_noise(self, img_fft, ratio=1.0, noise_mode=1, 87 | uncertainty_model=0, gauss_or_uniform=0): 88 | """Input image size: ndarray of [H, W, C]""" 89 | """noise_mode: 1 amplitude; 2: phase 3:both""" 90 | """uncertainty_model: 1 batch-wise modeling 2: channel-wise modeling 3:token-wise modeling""" 91 | if random.random() > self.p: 92 | return img_fft 93 | batch_size, h, w, c = img_fft.shape 94 | 95 | img_abs, img_pha = torch.abs(img_fft), torch.angle(img_fft) 96 | 97 | img_abs = torch.fft.fftshift(img_abs, dim=(1)) 98 | 99 | h_crop = int(h * sqrt(ratio)) 100 | w_crop = int(w * sqrt(ratio)) 101 | h_start = h // 2 - h_crop // 2 102 | w_start = 0 103 | 104 | img_abs_ = img_abs.clone() 105 | if noise_mode != 0: 106 | if uncertainty_model != 0: 107 | if uncertainty_model == 1: 108 | # batch level modeling 109 | miu = torch.mean(img_abs_[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :], dim=(1, 2), 110 | keepdim=True) 111 | var = torch.var(img_abs_[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :], dim=(1, 2), 112 | keepdim=True) 113 | sig = (var + self.eps).sqrt() # Bx1x1xC 114 | 115 | var_of_miu = torch.var(miu, dim=0, keepdim=True) 116 | var_of_sig = torch.var(sig, dim=0, keepdim=True) 117 | sig_of_miu = (var_of_miu + self.eps).sqrt().repeat(miu.shape[0], 1, 1, 1) 118 | sig_of_sig = (var_of_sig + self.eps).sqrt().repeat(miu.shape[0], 1, 1, 1) # Bx1x1xC 119 | 120 | if gauss_or_uniform == 0: 121 | epsilon_norm_miu = torch.randn_like(sig_of_miu) # N(0,1) 122 | epsilon_norm_sig = torch.randn_like(sig_of_sig) 123 | 124 | miu_mean = miu 125 | sig_mean = sig 126 | 127 | beta = self._reparameterize(mu=miu_mean, std=sig_of_miu, epsilon_norm=epsilon_norm_miu) 128 | gamma = self._reparameterize(mu=sig_mean, std=sig_of_sig, epsilon_norm=epsilon_norm_sig) 129 | elif gauss_or_uniform == 1: 130 | epsilon_norm_miu = torch.rand_like(sig_of_miu) * 2 - 1. # U(-1,1) 131 | epsilon_norm_sig = torch.rand_like(sig_of_sig) * 2 - 1. 132 | beta = self._reparameterize(mu=miu, std=sig_of_miu, epsilon_norm=epsilon_norm_miu) 133 | gamma = self._reparameterize(mu=sig, std=sig_of_sig, epsilon_norm=epsilon_norm_sig) 134 | else: 135 | epsilon_norm_miu = torch.randn_like(sig_of_miu) # N(0,1) 136 | epsilon_norm_sig = torch.randn_like(sig_of_sig) 137 | beta = self._reparameterize(mu=miu, std=1., epsilon_norm=epsilon_norm_miu) 138 | gamma = self._reparameterize(mu=sig, std=1., epsilon_norm=epsilon_norm_sig) 139 | 140 | # adjust statistics for each sample 141 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :] = gamma * ( 142 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :] - miu) / sig + beta 143 | 144 | elif uncertainty_model == 2: 145 | # element level modeling 146 | miu_of_elem = torch.mean(img_abs_[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :], dim=0, 147 | keepdim=True) 148 | var_of_elem = torch.var(img_abs_[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :], dim=0, 149 | keepdim=True) 150 | sig_of_elem = (var_of_elem + self.eps).sqrt() # 1xHxWxC 151 | 152 | if gauss_or_uniform == 0: 153 | epsilon_sig = torch.randn_like( 154 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :]) # BxHxWxC N(0,1) 155 | gamma = epsilon_sig * sig_of_elem * self.factor 156 | elif gauss_or_uniform == 1: 157 | epsilon_sig = torch.rand_like( 158 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :]) * 2 - 1. 
# U(-1,1) 159 | gamma = epsilon_sig * sig_of_elem * self.factor 160 | else: 161 | epsilon_sig = torch.randn_like( 162 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :]) # BxHxWxC N(0,1) 163 | gamma = epsilon_sig * self.factor 164 | 165 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :] = \ 166 | img_abs[:, h_start:h_start + h_crop, w_start:w_start + w_crop, :] + gamma 167 | img_abs = torch.fft.ifftshift(img_abs, dim=(1)) # recover 168 | img_mix = img_abs * (np.e ** (1j * img_pha)) 169 | return img_mix 170 | 171 | def forward(self, x, spatial_size=None): 172 | B, N, C = x.shape 173 | if spatial_size is None: 174 | a = b = int(math.sqrt(N)) 175 | else: 176 | a, b = spatial_size 177 | 178 | x = x.view(B, a, b, C) 179 | x = x.to(torch.float32) 180 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 181 | 182 | if self.training: 183 | if self.noise_mode != 0 and self.noise_layer_flag == 1: 184 | x = self.spectrum_noise(x, ratio=self.mask_radio, noise_mode=self.noise_mode, 185 | uncertainty_model=self.uncertainty_model, 186 | gauss_or_uniform=self.gauss_or_uniform) 187 | weight = torch.view_as_complex(self.complex_weight) 188 | x = x * weight 189 | x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho') 190 | x = x.reshape(B, N, C) 191 | return x 192 | 193 | 194 | class Block(nn.Module): 195 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8, 196 | mask_radio=0.1, mask_alpha=0.5, 197 | noise_mode=1, 198 | uncertainty_model=0, perturb_prob=0.5, 199 | uncertainty_factor=1.0, 200 | gauss_or_uniform=0, ): 201 | super().__init__() 202 | self.norm1 = norm_layer(dim) 203 | self.filter = GlobalFilter(dim, h=h, w=w, 204 | mask_radio=mask_radio, 205 | mask_alpha=mask_alpha, 206 | noise_mode=noise_mode, 207 | uncertainty_model=uncertainty_model, perturb_prob=perturb_prob, 208 | uncertainty_factor=uncertainty_factor, noise_layer_flag=1, 209 | gauss_or_uniform=gauss_or_uniform, ) 210 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 211 | self.norm2 = norm_layer(dim) 212 | mlp_hidden_dim = int(dim * mlp_ratio) 213 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 214 | 215 | def forward(self, input): 216 | x = input 217 | x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x))))) 218 | # Drop_path: In residual architecture, drop the current block for randomly seleted samples 219 | return x 220 | 221 | 222 | class BlockLayerScale(nn.Module): 223 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, 224 | norm_layer=nn.LayerNorm, h=14, w=8, init_values=1e-5, 225 | mask_radio=0.1, mask_alpha=0.5, noise_mode=1, 226 | uncertainty_model=0, perturb_prob=0.5, uncertainty_factor=1.0, 227 | layer_index=0, noise_layers=[0, 1, 2, 3], gauss_or_uniform=0, ): 228 | super().__init__() 229 | self.norm1 = norm_layer(dim) 230 | 231 | if layer_index in noise_layers: 232 | noise_layer_flag = 1 233 | else: 234 | noise_layer_flag = 0 235 | self.filter = GlobalFilter(dim, h=h, w=w, 236 | mask_radio=mask_radio, 237 | mask_alpha=mask_alpha, 238 | noise_mode=noise_mode, 239 | uncertainty_model=uncertainty_model, perturb_prob=perturb_prob, 240 | uncertainty_factor=uncertainty_factor, 241 | noise_layer_flag=noise_layer_flag, gauss_or_uniform=gauss_or_uniform, ) 242 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 243 | self.norm2 = norm_layer(dim) 244 | mlp_hidden_dim = int(dim * mlp_ratio) 245 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 246 | self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 247 | 248 | self.layer_index = layer_index # where is the block in 249 | 250 | def forward(self, input): 251 | x = input 252 | x = x + self.drop_path(self.gamma * self.mlp(self.norm2(self.filter(self.norm1(x))))) 253 | return x 254 | 255 | 256 | class PatchEmbed(nn.Module): 257 | """ Image to Patch Embedding 258 | """ 259 | 260 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 261 | super().__init__() 262 | img_size = to_2tuple(img_size) 263 | patch_size = to_2tuple(patch_size) 264 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 265 | self.img_size = img_size 266 | self.patch_size = patch_size 267 | self.num_patches = num_patches 268 | 269 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 270 | 271 | def init_weights(self, stop_grad_conv1=0): 272 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_size, 1) + self.embed_dim)) 273 | nn.init.uniform_(self.proj.weight, -val, val) 274 | nn.init.zeros_(self.proj.bias) 275 | 276 | if stop_grad_conv1: 277 | self.proj.weight.requires_grad = False 278 | self.proj.bias.requires_grad = False 279 | 280 | def forward(self, x): 281 | B, C, H, W = x.shape 282 | # FIXME look at relaxing size constraints 283 | assert H == self.img_size[0] and W == self.img_size[1], \ 284 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 285 | x = self.proj(x).flatten(2).transpose(1, 2) # BxCxHxW -> BxNxC , N=(224/4)^2=3136, C=64 286 | 287 | return x 288 | 289 | 290 | class DownLayer(nn.Module): 291 | """ Image to Patch Embedding 292 | """ 293 | 294 | def __init__(self, img_size=56, dim_in=64, dim_out=128): 295 | super().__init__() 296 | self.img_size = img_size 297 | self.dim_in = dim_in 298 | self.dim_out = dim_out 299 | self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2) 300 | self.num_patches = img_size * img_size // 4 301 | 302 | def forward(self, x): 303 | B, N, C = x.size() 304 | x = x.view(B, self.img_size, self.img_size, C).permute(0, 3, 1, 2) 305 | x = self.proj(x).permute(0, 2, 3, 1) 306 | x = x.reshape(B, -1, self.dim_out) 307 | 308 | return x 309 | 310 | 311 | class GFNet(nn.Module): 312 | 313 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 314 | mlp_ratio=4., representation_size=None, uniform_drop=False, 315 | drop_rate=0., drop_path_rate=0., norm_layer=None, 316 | dropcls=0, ): 317 | """ 318 | Args: 319 | img_size (int, tuple): input image size 320 | patch_size (int, tuple): patch size 321 | in_chans (int): number of input channels 322 | num_classes (int): number of classes for classification head 323 | embed_dim (int): embedding dimension 324 | depth (int): depth of transformer 325 | num_heads (int): number of attention heads 326 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 327 | qkv_bias (bool): enable bias for qkv if True 328 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 329 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 330 | drop_rate (float): dropout rate 331 | attn_drop_rate (float): attention dropout rate 332 | drop_path_rate (float): stochastic depth rate 333 
| hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 334 | norm_layer: (nn.Module): normalization layer 335 | """ 336 | super().__init__() 337 | self.num_classes = num_classes 338 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 339 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 340 | 341 | self.patch_embed = PatchEmbed( 342 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 343 | num_patches = self.patch_embed.num_patches 344 | 345 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 346 | self.pos_drop = nn.Dropout(p=drop_rate) 347 | 348 | h = img_size // patch_size 349 | w = h // 2 + 1 350 | 351 | if uniform_drop: 352 | print('using uniform droppath with expect rate', drop_path_rate) 353 | dpr = [drop_path_rate for _ in range(depth)] # stochastic depth decay rule 354 | else: 355 | print('using linear droppath with expect rate', drop_path_rate * 0.5) 356 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 357 | 358 | self.blocks = nn.ModuleList([ 359 | Block( 360 | dim=embed_dim, mlp_ratio=mlp_ratio, 361 | drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w) 362 | for i in range(depth)]) 363 | 364 | self.norm = norm_layer(embed_dim) 365 | 366 | # Representation layer 367 | if representation_size: 368 | self.num_features = representation_size 369 | self.pre_logits = nn.Sequential(OrderedDict([ 370 | ('fc', nn.Linear(embed_dim, representation_size)), 371 | ('act', nn.Tanh()) 372 | ])) 373 | else: 374 | self.pre_logits = nn.Identity() 375 | 376 | # Classifier head 377 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 378 | 379 | if dropcls > 0: 380 | print('dropout %.2f before classifier' % dropcls) 381 | self.final_dropout = nn.Dropout(p=dropcls) 382 | else: 383 | self.final_dropout = nn.Identity() 384 | 385 | trunc_normal_(self.pos_embed, std=.02) 386 | self.apply(self._init_weights) 387 | 388 | def _init_weights(self, m): 389 | if isinstance(m, nn.Linear): 390 | trunc_normal_(m.weight, std=.02) 391 | if isinstance(m, nn.Linear) and m.bias is not None: 392 | nn.init.constant_(m.bias, 0) 393 | elif isinstance(m, nn.LayerNorm): 394 | nn.init.constant_(m.bias, 0) 395 | nn.init.constant_(m.weight, 1.0) 396 | 397 | @torch.jit.ignore 398 | def no_weight_decay(self): 399 | return {'pos_embed', 'cls_token'} 400 | 401 | def get_classifier(self): 402 | return self.head 403 | 404 | def reset_classifier(self, num_classes, global_pool=''): 405 | self.num_classes = num_classes 406 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 407 | 408 | def forward_features(self, x): 409 | B = x.shape[0] 410 | x = self.patch_embed(x) 411 | x = x + self.pos_embed 412 | x = self.pos_drop(x) 413 | 414 | for blk in self.blocks: 415 | x = blk(x) 416 | 417 | x = self.norm(x).mean(1) 418 | return x 419 | 420 | def forward(self, x): 421 | x = self.forward_features(x) 422 | x = self.final_dropout(x) 423 | x = self.head(x) 424 | return x 425 | 426 | 427 | class GFNetPyramid(nn.Module): 428 | 429 | def __init__(self, img_size=224, patch_size=4, num_classes=1000, embed_dim=[64, 128, 256, 512], depth=[2, 2, 10, 4], 430 | mlp_ratio=[4, 4, 4, 4], 431 | drop_rate=0., drop_path_rate=0., norm_layer=None, init_values=0.001, no_layerscale=False, dropcls=0, 432 | mask_radio=0.1, mask_alpha=0.5, noise_mode=1, 433 | uncertainty_model=0, 434 | 
perturb_prob=0.5, 435 | uncertainty_factor=1.0, 436 | noise_layers=[0, 1, 2, 3], gauss_or_uniform=0, ): 437 | """ 438 | Args: 439 | img_size (int, tuple): input image size 440 | patch_size (int, tuple): patch size 441 | in_chans (int): number of input channels 442 | num_classes (int): number of classes for classification head 443 | embed_dim (int): embedding dimension 444 | depth (int): depth of transformer 445 | num_heads (int): number of attention heads 446 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 447 | qkv_bias (bool): enable bias for qkv if True 448 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 449 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 450 | drop_rate (float): dropout rate 451 | attn_drop_rate (float): attention dropout rate 452 | drop_path_rate (float): stochastic depth rate 453 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 454 | norm_layer: (nn.Module): normalization layer 455 | """ 456 | super().__init__() 457 | self.num_classes = num_classes 458 | self.num_features = self.embed_dim = embed_dim[-1] # num_features for consistency with other models 459 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 460 | 461 | self.patch_embed = nn.ModuleList() 462 | 463 | patch_embed = PatchEmbed( 464 | img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim[0]) 465 | num_patches = patch_embed.num_patches 466 | 467 | # patch_embed.init_weights(stop_grad_conv1=False) 468 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) 469 | 470 | self.patch_embed.append(patch_embed) 471 | 472 | sizes = [56, 28, 14, 7] 473 | for i in range(4): 474 | sizes[i] = sizes[i] * img_size // 224 475 | 476 | for i in range(3): 477 | patch_embed = DownLayer(sizes[i], embed_dim[i], embed_dim[i + 1]) 478 | num_patches = patch_embed.num_patches 479 | self.patch_embed.append(patch_embed) 480 | 481 | self.pos_drop = nn.Dropout(p=drop_rate) 482 | self.blocks = nn.ModuleList() 483 | 484 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule 485 | cur = 0 486 | for i in range(4): 487 | h = sizes[i] 488 | w = h // 2 + 1 489 | 490 | if no_layerscale: 491 | print('using standard block') 492 | blk = nn.Sequential(*[ 493 | Block( 494 | dim=embed_dim[i], mlp_ratio=mlp_ratio[i], 495 | drop=drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, h=h, w=w, 496 | mask_radio=mask_radio, 497 | mask_alpha=mask_alpha, 498 | noise_mode=noise_mode, 499 | uncertainty_model=uncertainty_model, perturb_prob=perturb_prob, 500 | uncertainty_factor=uncertainty_factor, 501 | gauss_or_uniform=gauss_or_uniform, 502 | ) 503 | for j in range(depth[i]) 504 | ]) 505 | else: 506 | print('using layerscale block') 507 | blk = nn.Sequential(*[ 508 | BlockLayerScale( 509 | dim=embed_dim[i], mlp_ratio=mlp_ratio[i], 510 | drop=drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, h=h, w=w, 511 | init_values=init_values, 512 | mask_radio=mask_radio, mask_alpha=mask_alpha, noise_mode=noise_mode, 513 | uncertainty_model=uncertainty_model, perturb_prob=perturb_prob, 514 | uncertainty_factor=uncertainty_factor, 515 | layer_index=i, 516 | noise_layers=noise_layers, gauss_or_uniform=gauss_or_uniform, 517 | ) 518 | for j in range(depth[i]) 519 | ]) 520 | self.blocks.append(blk) 521 | cur += depth[i] 522 | 523 | # Classifier head 524 | self.norm = norm_layer(embed_dim[-1]) 525 | 526 | self.head = 
nn.Linear(self.num_features, num_classes) 527 | 528 | if dropcls > 0: 529 | print('dropout %.2f before classifier' % dropcls) 530 | self.final_dropout = nn.Dropout(p=dropcls) 531 | else: 532 | self.final_dropout = nn.Identity() 533 | 534 | trunc_normal_(self.pos_embed, std=.02) 535 | self.apply(self._init_weights) 536 | 537 | def _init_weights(self, m): 538 | if isinstance(m, nn.Linear): 539 | trunc_normal_(m.weight, std=.02) 540 | if isinstance(m, nn.Linear) and m.bias is not None: 541 | nn.init.constant_(m.bias, 0) 542 | elif isinstance(m, nn.LayerNorm): 543 | nn.init.constant_(m.bias, 0) 544 | nn.init.constant_(m.weight, 1.0) 545 | 546 | @torch.jit.ignore 547 | def no_weight_decay(self): 548 | return {'pos_embed', 'cls_token'} 549 | 550 | def get_classifier(self): 551 | return self.head 552 | 553 | def reset_classifier(self, num_classes, global_pool=''): 554 | self.num_classes = num_classes 555 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 556 | 557 | def forward_features(self, x): 558 | for i in range(4): 559 | x = self.patch_embed[i](x) 560 | if i == 0: 561 | x = x + self.pos_embed 562 | 563 | x = self.blocks[i]((x)) 564 | x = self.norm(x).mean(1) 565 | return x 566 | 567 | def forward(self, x): 568 | x = self.forward_features(x) 569 | x = self.final_dropout(x) 570 | x = self.head(x) 571 | return x 572 | 573 | 574 | def resize_pos_embed(posemb, posemb_new): 575 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 576 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 577 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 578 | ntok_new = posemb_new.shape[1] 579 | if True: 580 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 581 | ntok_new -= 1 582 | else: 583 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 584 | gs_old = int(math.sqrt(len(posemb_grid))) 585 | gs_new = int(math.sqrt(ntok_new)) 586 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 587 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 588 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') 589 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 590 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 591 | return posemb 592 | 593 | 594 | def checkpoint_filter_fn(state_dict, model): 595 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 596 | out_dict = {} 597 | if 'model' in state_dict: 598 | # For deit models 599 | state_dict = state_dict['model'] 600 | for k, v in state_dict.items(): 601 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 602 | # For old models that I trained prior to conv based patchification 603 | O, I, H, W = model.patch_embed.proj.weight.shape 604 | v = v.reshape(O, -1, H, W) 605 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 606 | # To resize pos embedding when using model at different size from pretrained weights 607 | v = resize_pos_embed(v, model.pos_embed) 608 | out_dict[k] = v 609 | return out_dict 610 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
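# Quick reference for the objective implemented by DistillationLoss below (symbols are shorthand,
# not identifiers from this file): with student logits s, student distillation logits s_kd,
# teacher logits t, temperature T, weight alpha, and N = s_kd.numel(),
#   'soft': loss = (1 - alpha) * base_criterion(s, y) + alpha * (T^2 / N) * KL(softmax(t / T) || softmax(s_kd / T))
#   'hard': loss = (1 - alpha) * base_criterion(s, y) + alpha * CE(s_kd, argmax(t))
#   'none': loss = base_criterion(s, y)   # the default in main_gfnet.py, so ALOFT training uses only the base loss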
3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | from abc import get_cache_token 7 | import torch 8 | from torch.nn import functional as F 9 | from torch.nn.modules.loss import MSELoss, BCEWithLogitsLoss, CrossEntropyLoss 10 | from utils import batch_index_select 11 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 12 | import math 13 | 14 | class DistillationLoss(torch.nn.Module): 15 | """ 16 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 17 | taking a teacher model prediction and using it as additional supervision. 18 | """ 19 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 20 | distillation_type: str, alpha: float, tau: float): 21 | super().__init__() 22 | self.base_criterion = base_criterion 23 | self.teacher_model = teacher_model 24 | assert distillation_type in ['none', 'soft', 'hard'] 25 | self.distillation_type = distillation_type 26 | self.alpha = alpha 27 | self.tau = tau 28 | 29 | def forward(self, inputs, outputs, labels): 30 | """ 31 | Args: 32 | inputs: The original inputs that are feed to the teacher model 33 | outputs: the outputs of the model to be trained. It is expected to be 34 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 35 | in the first position and the distillation predictions as the second output 36 | labels: the labels for the base criterion 37 | """ 38 | outputs_kd = None 39 | if not isinstance(outputs, torch.Tensor): 40 | # assume that the model outputs a tuple of [outputs, outputs_kd] 41 | outputs, outputs_kd = outputs 42 | base_loss = self.base_criterion(outputs, labels) 43 | if self.distillation_type == 'none': 44 | return base_loss 45 | 46 | if outputs_kd is None: 47 | raise ValueError("When knowledge distillation is enabled, the model is " 48 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 49 | "class_token and the dist_token") 50 | # don't backprop throught the teacher 51 | with torch.no_grad(): 52 | teacher_outputs = self.teacher_model(inputs) 53 | 54 | if self.distillation_type == 'soft': 55 | T = self.tau 56 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 57 | # with slight modifications 58 | distillation_loss = F.kl_div( 59 | F.log_softmax(outputs_kd / T, dim=1), 60 | F.log_softmax(teacher_outputs / T, dim=1), 61 | reduction='sum', 62 | log_target=True 63 | ) * (T * T) / outputs_kd.numel() 64 | elif self.distillation_type == 'hard': 65 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 66 | 67 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 68 | return loss 69 | -------------------------------------------------------------------------------- /main_gfnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | 9 | from pathlib import Path 10 | 11 | from timm.data import Mixup 12 | from timm.models import create_model 13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 14 | from timm.scheduler import create_scheduler 15 | from timm.optim import create_optimizer_v2, create_optimizer 16 | from timm.utils import NativeScaler, get_state_dict, ModelEma 17 | from functools import partial 18 | import torch.nn as nn 19 | 20 | from engine import train_one_epoch, evaluate 21 | 
from losses import DistillationLoss 22 | import utils 23 | from gfnet import GFNet, GFNetPyramid 24 | 25 | from data import data_helper 26 | import os 27 | 28 | from torch import optim 29 | 30 | 31 | import warnings 32 | warnings.filterwarnings("ignore", message="Argument interpolation should be") 33 | 34 | def get_args_parser(): 35 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 36 | 37 | parser.add_argument("--target", default=2, type=int, help="Target") 38 | parser.add_argument("--device", type=int, default=3, help="GPU num") 39 | 40 | parser.add_argument('--batch-size', default=64, type=int) 41 | parser.add_argument('--epochs', default=50, type=int) 42 | 43 | parser.add_argument('--perturb_prob', default=0.5, type=float) 44 | parser.add_argument('--mask_alpha', default=0.2, type=float) 45 | 46 | parser.add_argument('--noise_mode', default=1, type=int, help="0: close; 1: add noise") 47 | parser.add_argument('--uncertainty_model', default=2, type=int, help="1:batch+mean 2:batch+element") 48 | parser.add_argument('--uncertainty_factor', default=1.0, type=float) 49 | parser.add_argument('--mask_radio', default=0.5, type=float) 50 | parser.add_argument('--gauss_or_uniform', default=0, type=int, help="0: gaussian; 1: uniform; 2: random") 51 | parser.add_argument('--noise_layers', default=[0, 1, 2, 3], nargs="+", type=int, help="where to use augmentation.") 52 | 53 | parser.add_argument('--set_training_mode', default=1, type=int, help="0:eval 1:train") 54 | parser.add_argument('--freq_analyse', default=0, type=int, help="whether do frequency analyse") 55 | 56 | parser.add_argument('--data_root', default='/data/DataSets/', type=str, help='dataset path') 57 | parser.add_argument('--output_dir', default='/ALOFT_results/', help='path where to save, empty for no saving') 58 | 59 | # * Finetuning params 60 | parser.add_argument('--finetune', default='/pretrained_model/', help='finetune from checkpoint') 61 | 62 | # Model parameters 63 | parser.add_argument('--arch', default='gfnet-h-ti', type=str, 64 | help='Name of model to train') 65 | parser.add_argument('--input_size', default=224, type=int, help='images input size') 66 | 67 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 68 | help='Dropout rate (default: 0.)') 69 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 70 | help='Drop path rate (default: 0.1)') 71 | 72 | # parser.add_argument('--model-ema', action='store_true') 73 | parser.add_argument('--model-ema', default=False) 74 | parser.set_defaults(model_ema=False) 75 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 76 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 77 | 78 | # Optimizer parameters 79 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 80 | help='Optimizer (default: "adamw"') 81 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 82 | help='Optimizer Epsilon (default: 1e-8)') 83 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 84 | help='Optimizer Betas (default: None, use opt default)') 85 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', 86 | help='Clip gradient norm (default: None, no clipping)') 87 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 88 | help='SGD momentum (default: 0.9)') 89 | parser.add_argument('--weight-decay', type=float, default=0.05, 90 | 
help='weight decay (default: 0.05)') 91 | 92 | # Learning rate schedule parameters 93 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 94 | help='LR scheduler (default: "cosine"') 95 | parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', 96 | help='learning rate (default: 5e-4)') 97 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 98 | help='learning rate noise on/off epoch percentages') 99 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 100 | help='learning rate noise limit percent (default: 0.67)') 101 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 102 | help='learning rate noise std-dev (default: 1.0)') 103 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 104 | help='warmup learning rate (default: 1e-6)') 105 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 106 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 107 | 108 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 109 | help='epoch interval to decay LR') 110 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 111 | help='epochs to warmup LR, if scheduler supports') 112 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 113 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 114 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 115 | help='patience epochs for Plateau LR scheduler (default: 10') 116 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 117 | help='LR decay rate (default: 0.1)') 118 | 119 | # Augmentation parameters 120 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 121 | help='Color jitter factor (default: 0.4)') 122 | parser.add_argument('--aa', type=str, default=None, metavar='NAME', 123 | help='Use AutoAugment policy. "v0" or "original". " + \ 124 | "(default: rand-m9-mstd0.5-inc1)'), 125 | parser.add_argument('--smoothing', type=float, default=0, help='Label smoothing (default: 0.1)') 126 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 127 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 128 | 129 | parser.add_argument('--repeated-aug', action='store_true') 130 | parser.set_defaults(repeated_aug=False) 131 | 132 | # * Random Erase params 133 | parser.add_argument('--reprob', type=float, default=0, metavar='PCT', 134 | help='Random erase prob (default: 0.25)') 135 | parser.add_argument('--remode', type=str, default='pixel', 136 | help='Random erase mode (default: "pixel")') 137 | parser.add_argument('--recount', type=int, default=1, 138 | help='Random erase count (default: 1)') 139 | parser.add_argument('--resplit', action='store_true', default=False, 140 | help='Do not random erase first (clean) augmentation split') 141 | 142 | # * Mixup params 143 | parser.add_argument('--mixup', type=float, default=0, 144 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 145 | parser.add_argument('--cutmix', type=float, default=0, 146 | help='cutmix alpha, cutmix enabled if > 0. 
(default: 1.0)') 147 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 148 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 149 | parser.add_argument('--mixup-prob', type=float, default=1.0, 150 | help='Probability of performing mixup or cutmix when either/both is enabled') 151 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 152 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 153 | parser.add_argument('--mixup-mode', type=str, default='batch', 154 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 155 | 156 | # Distillation parameters 157 | parser.add_argument('--teacher-model', default='', type=str, metavar='MODEL', 158 | help='Name of teacher model to train (default: "regnety_160"') 159 | parser.add_argument('--teacher-path', type=str, default='') 160 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 161 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 162 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 163 | 164 | parser.add_argument('--resume', default='', help='resume from checkpoint') 165 | parser.add_argument('--eval', default=0, type=int, help='Perform evaluation only') 166 | 167 | parser.add_argument('--data', default='PACS', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'PACS', 'OfficeHome', 'VLCS', 'digits_dg'], 168 | type=str, help='Image Net dataset path') 169 | parser.add_argument('--inat-category', default='name', 170 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 171 | type=str, help='semantic granularity') 172 | 173 | parser.add_argument("--image_size", type=int, default=224, help="Image size") 174 | parser.add_argument("--min_scale", default=0.8, type=float, help="Minimum scale percent") 175 | parser.add_argument("--max_scale", default=1.0, type=float, help="Maximum scale percent") 176 | parser.add_argument("--gray_flag", default=0, type=int, help="whether use random gray") 177 | parser.add_argument("--random_horiz_flip", default=0.5, type=float, help="Chance of random horizontal flip") 178 | parser.add_argument("--jitter", default=0.4, type=float, help="Color jitter amount") 179 | parser.add_argument("--tile_random_grayscale", default=0.1, type=float, 180 | help="Chance of randomly greyscaling a tile") 181 | 182 | parser.add_argument('--seed', default=0, type=int) 183 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 184 | help='start epoch') 185 | # parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 186 | 187 | parser.add_argument('--dist-eval', action='store_true', default=True, help='Enabling distributed evaluation') 188 | parser.add_argument('--num_workers', default=10, type=int) 189 | parser.add_argument('--pin-mem', action='store_true', 190 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 191 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 192 | help='') 193 | parser.set_defaults(pin_mem=True) 194 | 195 | # distributed training parameters 196 | parser.add_argument('--world_size', default=1, type=int, 197 | help='number of distributed processes') 198 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 199 | return parser 200 | 201 | 202 | domain_map = { 203 | 'PACS': ['photo', 
'art_painting', 'cartoon', 'sketch'], 204 | 'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'], 205 | 'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'], 206 | 'VLCS': ["CALTECH", "LABELME", "PASCAL", "SUN"], 207 | 'digits_dg': ['mnist', 'mnist_m', 'svhn', 'syn'], 208 | } 209 | classes_map = { 210 | 'PACS': 7, 211 | 'PACS_random_split': 7, 212 | 'OfficeHome': 65, 213 | 'VLCS': 5, 214 | 'digits_dg': 32, 215 | } 216 | val_size_map = { 217 | 'PACS': 0.1, 218 | 'PACS_random_split': 0.1, 219 | 'OfficeHome': 0.1, 220 | 'VLCS': 0.3, 221 | 'digits_dg': 0.2, 222 | } 223 | 224 | 225 | def get_domain(name): 226 | if name not in domain_map: 227 | raise ValueError('Name of dataset unknown %s' %name) 228 | return domain_map[name] 229 | 230 | 231 | def main(args): 232 | utils.init_distributed_mode(args) 233 | 234 | domain = get_domain(args.data) 235 | args.target = domain.pop(args.target) 236 | args.source = domain 237 | print("Target domain: {}".format(args.target)) 238 | args.data_root = os.path.join(args.data_root, "PACS") if "PACS" in args.data else os.path.join(args.data_root, 239 | args.data) 240 | args.n_classes = classes_map[args.data] 241 | args.n_domains = len(domain) 242 | args.val_size = val_size_map[args.data] 243 | 244 | if args.distillation_type != 'none' and args.finetune and not args.eval: 245 | raise NotImplementedError("Finetuning with distillation not yet supported") 246 | 247 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device) 248 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 249 | # device = torch.device(args.device) 250 | 251 | # fix the seed for reproducibility 252 | seed = args.seed + utils.get_rank() 253 | torch.manual_seed(seed) 254 | np.random.seed(seed) 255 | 256 | cudnn.benchmark = True 257 | 258 | args.nb_classes = args.n_classes 259 | 260 | data_loader_train, data_loader_val = data_helper.get_train_dataloader(args, patches=False) 261 | data_loader_test = data_helper.get_val_dataloader(args, patches=False) 262 | 263 | mixup_fn = None 264 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 265 | if mixup_active: 266 | print('standard mix up') 267 | mixup_fn = Mixup( 268 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 269 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 270 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 271 | else: 272 | print('mix up is not used') 273 | 274 | print(f"Creating model: {args.arch}") 275 | 276 | if args.arch == 'gfnet-h-ti': 277 | model = GFNetPyramid( 278 | img_size=args.input_size, 279 | patch_size=4, 280 | num_classes=args.n_classes, 281 | embed_dim=[64, 128, 256, 512], depth=[3, 3, 10, 3], 282 | mlp_ratio=[4, 4, 4, 4], 283 | norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.1, 284 | mask_radio=args.mask_radio, 285 | mask_alpha=args.mask_alpha, 286 | noise_mode=args.noise_mode, 287 | uncertainty_model=args.uncertainty_model, perturb_prob=args.perturb_prob, 288 | noise_layers=args.noise_layers, gauss_or_uniform=args.gauss_or_uniform, 289 | ) 290 | elif args.arch == 'gfnet-h-s': 291 | model = GFNetPyramid( 292 | img_size=args.input_size, 293 | patch_size=4, 294 | num_classes=args.n_classes, 295 | embed_dim=[96, 192, 384, 768], depth=[3, 3, 10, 3], 296 | mlp_ratio=[4, 4, 4, 4], 297 | norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.2, 298 | init_values=1e-5, 299 | mask_radio=args.mask_radio, 300 | mask_alpha=args.mask_alpha, 301 | noise_mode=args.noise_mode, 302 | uncertainty_model=args.uncertainty_model, perturb_prob=args.perturb_prob, 303 | noise_layers=args.noise_layers, gauss_or_uniform=args.gauss_or_uniform, 304 | ) 305 | else: 306 | raise NotImplementedError 307 | 308 | 309 | if args.finetune: 310 | args.finetune += "/" + args.arch + ".pth" 311 | 312 | if args.finetune.startswith('https'): 313 | checkpoint = torch.hub.load_state_dict_from_url( 314 | args.finetune, map_location='cpu', check_hash=True) 315 | else: 316 | checkpoint = torch.load(args.finetune, map_location='cpu') 317 | 318 | checkpoint_model = checkpoint['model'] 319 | state_dict = model.state_dict() 320 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 321 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 322 | print(f"Removing key {k} from pretrained checkpoint") 323 | del checkpoint_model[k] 324 | 325 | # interpolate position embedding 326 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 327 | embedding_size = pos_embed_checkpoint.shape[-1] 328 | 329 | if args.arch in ['gfnet-ti', 'gfnet-xs', 'gfnet-s', 'gfnet-b']: 330 | num_patches = (args.input_size // 16) ** 2 331 | elif args.arch in ['gfnet-h-ti', 'gfnet-h-s', 'gfnet-h-b']: 332 | num_patches = (args.input_size // 4) ** 2 333 | else: 334 | raise NotImplementedError 335 | 336 | num_extra_tokens = 0 337 | # height (== width) for the checkpoint position embedding 338 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 339 | # height (== width) for the new position embedding 340 | new_size = int(num_patches ** 0.5) 341 | 342 | scale_up_ratio = new_size / orig_size 343 | # class_token and dist_token are kept unchanged 344 | # only the position tokens are interpolated 345 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 346 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 347 | pos_tokens = torch.nn.functional.interpolate( 348 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 349 | pos_tokens = 
pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 350 | checkpoint_model['pos_embed'] = pos_tokens 351 | 352 | for name in checkpoint_model.keys(): 353 | if 'complex_weight' in name: 354 | h, w, num_heads = checkpoint_model[name].shape[0:3] # h, w, c, 2 355 | origin_weight = checkpoint_model[name] 356 | upsample_h = h * new_size // orig_size 357 | upsample_w = upsample_h // 2 + 1 358 | origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2) 359 | new_weight = torch.nn.functional.interpolate( 360 | origin_weight, size=(upsample_h, upsample_w), mode='bicubic', align_corners=True).permute(0, 2, 3, 1).reshape(upsample_h, upsample_w, num_heads, 2) 361 | checkpoint_model[name] = new_weight 362 | model.load_state_dict(checkpoint_model, strict=False) 363 | 364 | model.to(device) 365 | 366 | model_ema = None 367 | if args.model_ema: 368 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 369 | model_ema = ModelEma( 370 | model, 371 | decay=args.model_ema_decay, 372 | device='cpu' if args.model_ema_force_cpu else '', 373 | resume='') 374 | 375 | model_without_ddp = model 376 | if args.distributed: 377 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 378 | model_without_ddp = model.module 379 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 380 | print('number of params:', n_parameters) 381 | 382 | if args.opt == "adamw": 383 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 384 | args.lr = linear_scaled_lr 385 | optimizer = create_optimizer(args, model_without_ddp) 386 | lr_scheduler, _ = create_scheduler(args, optimizer) 387 | elif args.opt == "sgd": 388 | optimizer = create_optimizer_v2(model_without_ddp, opt='sgd', lr=args.lr, weight_decay=0.0005, momentum=0.9) 389 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs * 0.8, gamma=0.1) 390 | 391 | loss_scaler = NativeScaler() 392 | criterion = LabelSmoothingCrossEntropy() 393 | 394 | if args.mixup > 0. 
or args.cutmix > 0.: 395 | # smoothing is handled with mixup label transform 396 | criterion = SoftTargetCrossEntropy() 397 | elif args.smoothing: 398 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 399 | else: 400 | criterion = torch.nn.CrossEntropyLoss() 401 | 402 | teacher_model = None 403 | if args.distillation_type != 'none': 404 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 405 | print(f"Creating teacher model: {args.teacher_model}") 406 | teacher_model = create_model( 407 | args.teacher_model, 408 | pretrained=False, 409 | num_classes=args.nb_classes, 410 | global_pool='avg', 411 | ) 412 | if args.teacher_path.startswith('https'): 413 | checkpoint = torch.hub.load_state_dict_from_url( 414 | args.teacher_path, map_location='cpu', check_hash=True) 415 | else: 416 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 417 | teacher_model.load_state_dict(checkpoint['model']) 418 | teacher_model.to(device) 419 | teacher_model.eval() 420 | 421 | # wrap the criterion in our custom DistillationLoss, which 422 | # just dispatches to the original criterion if args.distillation_type is 'none' 423 | 424 | criterion = DistillationLoss( 425 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 426 | ) 427 | 428 | dir_name = args.opt + str(args.lr) + "E" + str(args.epochs) 429 | 430 | if args.arch == "gfnet-h-s": 431 | dir_name += "_gfnet-h-s" 432 | 433 | if args.noise_mode != 0: 434 | dir_name += "_noise" + "_M" + str(args.mask_radio) + "_p" + str(args.perturb_prob) 435 | if args.noise_mode == 1: 436 | dir_name += "_amp" 437 | if args.uncertainty_model != 0: 438 | if args.uncertainty_model == 1: 439 | dir_name += "_batch_mean" 440 | else: 441 | dir_name += "_batch_elem" 442 | if args.gauss_or_uniform == 1: 443 | dir_name += "_unif" 444 | elif args.gauss_or_uniform == 2: 445 | dir_name += "_randGauss" 446 | dir_name += "_f" + str(args.uncertainty_factor) 447 | 448 | if args.gray_flag == 0: 449 | dir_name += "_nogray" 450 | 451 | output_dir = os.path.join(args.output_dir, args.data, dir_name, args.target + str(args.seed)) 452 | if not os.path.exists(output_dir): 453 | os.makedirs(output_dir) 454 | output_dir = Path(output_dir) 455 | 456 | if args.resume: 457 | if args.resume.startswith('https'): 458 | checkpoint = torch.hub.load_state_dict_from_url( 459 | args.resume, map_location='cpu', check_hash=True) 460 | else: 461 | model_path = args.resume + "/" + args.target + str(args.seed) + "/checkpoint_last.pth" 462 | checkpoint = torch.load(model_path, map_location='cpu') 463 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 464 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 465 | optimizer.load_state_dict(checkpoint['optimizer']) 466 | print('lr scheduler will not be updated') 467 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 468 | args.start_epoch = checkpoint['epoch'] + 1 469 | if args.model_ema: 470 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 471 | if 'scaler' in checkpoint: 472 | loss_scaler.load_state_dict(checkpoint['scaler']) 473 | 474 | if args.eval: 475 | val_stats = evaluate(data_loader_val, model, device)['acc1'] 476 | test_stats = evaluate(data_loader_test, model, device)['acc1'] 477 | print(f"Accuracy of the network on the {len(data_loader_val.dataset)} val images: {val_stats:.2f}%") 478 | print(f"Accuracy of the network on the {len(data_loader_test.dataset)} 
test images: {test_stats:.2f}%") 479 | return 480 | 481 | print(f"Start training for {args.epochs} epochs") 482 | start_time = time.time() 483 | max_accuracy_test = 0.0 484 | max_accuracy_val = 0.0 485 | max_val_test = 0.0 486 | max_val_epoch = 0 487 | max_test_epoch = 0 488 | 489 | if args.set_training_mode == 1: 490 | args.set_training_mode = True 491 | else: 492 | args.set_training_mode = False 493 | 494 | for epoch in range(args.start_epoch, args.epochs): 495 | if args.distributed: 496 | data_loader_train.sampler.set_epoch(epoch) 497 | 498 | train_stats = train_one_epoch( 499 | model, criterion, data_loader_train, 500 | optimizer, device, epoch, loss_scaler, 501 | args.clip_grad, model_ema, mixup_fn, 502 | set_training_mode=args.set_training_mode, # keep in eval mode during finetuning 503 | ) 504 | 505 | lr_scheduler.step(epoch) 506 | 507 | if args.output_dir: 508 | checkpoint_paths = [output_dir / 'checkpoint_last.pth'] 509 | for checkpoint_path in checkpoint_paths: 510 | if model_ema is not None: 511 | utils.save_on_master({ 512 | 'model': model_without_ddp.state_dict(), 513 | 'model_ema': get_state_dict(model_ema), 514 | }, checkpoint_path) 515 | else: 516 | utils.save_on_master({ 517 | 'model': model_without_ddp.state_dict(), 518 | }, checkpoint_path) 519 | 520 | if (epoch + 1) % 100 == 0: 521 | file_name = 'checkpoint_epoch%d.pth' % epoch 522 | checkpoint_path = output_dir / file_name 523 | if model_ema is not None: 524 | utils.save_on_master({ 525 | 'model': model_without_ddp.state_dict(), 526 | 'model_ema': get_state_dict(model_ema), 527 | }, checkpoint_path) 528 | else: 529 | utils.save_on_master({ 530 | 'model': model_without_ddp.state_dict(), 531 | }, checkpoint_path) 532 | 533 | val_stats = evaluate(data_loader_val, model, device) 534 | print(f"Accuracy of the network on the {len(data_loader_val.dataset)} val images: {val_stats['acc1']:.2f}%") 535 | test_stats = evaluate(data_loader_test, model, device) 536 | print(f"Accuracy of the network on the {len(data_loader_test.dataset)} test images: {test_stats['acc1']:.2f}%") 537 | 538 | max_accuracy_val = max(max_accuracy_val, val_stats["acc1"]) 539 | print(f'Max accuracy val: {max_accuracy_val:.2f}%') 540 | print(f"Corresponding test accuracy: {test_stats['acc1']:.2f}%") 541 | max_accuracy_test = max(max_accuracy_test, test_stats["acc1"]) 542 | if max_accuracy_val == test_stats["acc1"]: 543 | max_test_epoch = epoch 544 | # print(f'Max accuracy test: {max_accuracy_test:.2f}%') 545 | 546 | if max_accuracy_val == val_stats["acc1"]: 547 | max_val_test = test_stats['acc1'] 548 | max_val_epoch = epoch 549 | 550 | checkpoint_path = output_dir / 'checkpoint_best.pth' 551 | if model_ema is not None: 552 | utils.save_on_master({ 553 | 'model': model_without_ddp.state_dict(), 554 | 'model_ema': get_state_dict(model_ema), 555 | }, checkpoint_path) 556 | else: 557 | utils.save_on_master({ 558 | 'model': model_without_ddp.state_dict(), 559 | }, checkpoint_path) 560 | 561 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 562 | **{f'val_{k}': v for k, v in val_stats.items()}, 563 | **{f'test_{k}': v for k, v in test_stats.items()}, 564 | 'epoch': epoch, 565 | 'n_parameters': n_parameters} 566 | 567 | if args.output_dir and utils.is_main_process(): 568 | with (output_dir / "log.txt").open("a") as f: 569 | f.write(json.dumps(log_stats) + "\n") 570 | 571 | log_stats = {**{f'Best val': max_accuracy_val}, 572 | **{f'Corresponding test': max_val_test}, 573 | **{f'At Epoch': max_val_epoch}, 574 | **{f'Best test': max_accuracy_test}, 
575 | **{f'At Epoch': max_test_epoch}, 576 | } 577 | with (output_dir / "log.txt").open("a") as f: 578 | f.write(json.dumps(log_stats) + "\n") 579 | 580 | total_time = time.time() - start_time 581 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 582 | print('Training time {}'.format(total_time_str)) 583 | 584 | 585 | if __name__ == '__main__': 586 | parser = argparse.ArgumentParser('GFNet training and evaluation script', parents=[get_args_parser()]) 587 | args = parser.parse_args() 588 | if args.output_dir: 589 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 590 | # torch.autograd.set_detect_anomaly(True) 591 | main(args) 592 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 
7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq 
== 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | 240 | 241 | def batch_index_select(x, idx): 242 | if len(x.size()) == 3: 243 | B, N, C = x.size() 244 | N_new = idx.size(1) 245 | offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N 246 | idx = idx + offset 247 | out = x.reshape(B*N, C)[idx.reshape(-1)].reshape(B, N_new, C) 248 | return out 249 | elif len(x.size()) == 2: 250 | B, N = x.size() 251 | N_new = idx.size(1) 252 | offset = torch.arange(B, dtype=torch.long, 
device=x.device).view(B, 1) * N 253 | idx = idx + offset 254 | out = x.reshape(B*N)[idx.reshape(-1)].reshape(B, N_new) 255 | return out 256 | else: 257 | raise NotImplementedError --------------------------------------------------------------------------------
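For reference, the core augmentation in gfnet.py is GlobalFilter.spectrum_noise, which perturbs only the low-frequency amplitude spectrum of intermediate token maps while leaving the phase untouched. The sketch below is a minimal, self-contained re-implementation of the element-wise variant (uncertainty_model=2, i.e. the ALOFT-E setting) under assumed shapes; the function name, default values, and the B x H x W x C layout are illustrative, and the repository version additionally gates the perturbation with perturb_prob, offers batch-statistic (uncertainty_model=1) and uniform-noise variants, and is applied inside each GlobalFilter during training only.

```
import math
import torch


def perturb_low_freq_amplitude(x, mask_radio=0.5, factor=1.0, eps=1e-6):
    """Element-wise low-frequency amplitude perturbation (ALOFT-E style) on a B x H x W x C map."""
    B, H, W, C = x.shape
    X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')          # complex spectrum, B x H x (W//2+1) x C
    amp, pha = torch.abs(X), torch.angle(X)

    amp = torch.fft.fftshift(amp, dim=1)                      # centre low frequencies along the height axis
    h_crop = int(H * math.sqrt(mask_radio))
    w_crop = int(amp.shape[2] * math.sqrt(mask_radio))
    h0 = H // 2 - h_crop // 2                                 # low-frequency band: centred in height,
    band = amp[:, h0:h0 + h_crop, :w_crop, :]                 # starting at index 0 along the rFFT width

    sigma = band.var(dim=0, keepdim=True).add(eps).sqrt()     # element-wise std across the batch (assumes B > 1)
    amp[:, h0:h0 + h_crop, :w_crop, :] = band + torch.randn_like(band) * sigma * factor

    amp = torch.fft.ifftshift(amp, dim=1)                     # undo the shift and recombine with the phase
    return torch.fft.irfft2(torch.polar(amp, pha), s=(H, W), dim=(1, 2), norm='ortho')
```

As a usage illustration, `perturb_low_freq_amplitude(torch.randn(8, 56, 56, 64))` perturbs a batch of stage-1 token maps (56 x 56 tokens with 64 channels, matching the first GFNetPyramid stage).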