├── README.md
├── common
│   ├── __init__.py
│   ├── data_reader.py
│   └── utils.py
├── main_agg.py
├── model_resnet_all.py
├── requirements.txt
├── resnet_decoder.py
└── resnet_vanilla_updata.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FFDI
## Introduction
Code release for "Domain Generalization via Frequency-domain-based Feature Disentanglement and Interaction" (ACM MM 2022): [https://arxiv.org/abs/2201.08029](https://arxiv.org/abs/2201.08029).

Part of the code is inherited from [Episodic-DG](https://github.com/HAHA-DL/Episodic-DG).

## Environments
```bash
GPU: GeForce GTX 1080 Ti
pytorch==1.9.0
torchvision==0.10.0
cudatoolkit==10.2.89
```

## Prepare
### Datasets
Please download the [PACS](https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk?resourcekey=0-2fvpQY_QSyJf2uIECzqPuQ&usp=sharing) dataset and use the official train/val split.

### ImageNet pretrained model
We use the PyTorch pretrained ResNet-18 model from [https://download.pytorch.org/models/resnet18-5c106cde.pth](https://download.pytorch.org/models/resnet18-5c106cde.pth).

## Run
- Train from scratch with the command:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python main_agg.py
```
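All runtime options are defined in `main_agg.py`. For reference, a run with the main options set explicitly might look like the following (paths are examples only — point `--data_root` at the directory containing the pre-read PACS `.hdf5` files):
```bash
python main_agg.py \
    --batch_size 8 \
    --threshold 25 \
    --data_root ./datasets/PACS_DataSet/Train_val_splits_and_h5py_files_pre-read
```
Note that `main_agg.py` itself loops over three repetitions of all four held-out target domains and sets `--unseen_index`, `--logs`, and `--model_path` for each run.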

## Reference

Please cite this work in your publications if it helps your research:

    @inproceedings{wang2022domain,
      title={Domain Generalization via Frequency-domain-based Feature Disentanglement and Interaction},
      author={Wang, Jingye and Du, Ruoyi and Chang, Dongliang and Liang, KongMing and Ma, Zhanyu},
      booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
      year={2022}
    }

--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRIS-CV/FFDI/77fa09f323adb6dfd212677d6a574a17b31300cf/common/__init__.py

--------------------------------------------------------------------------------
/common/data_reader.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
import random
from PIL import Image
import os

dirpath = os.pardir
import sys

sys.path.append(dirpath)
from common.utils import unfold_label, shuffle_data, my_fft, my_fft_trans


class BatchImageGenerator:  # per-domain dataset reader
    def __init__(self, flags, stage, file_path, b_unfold_label):

        if stage not in ['train', 'val', 'test']:
            raise ValueError('invalid stage!')

        self.configuration(flags, stage, file_path)
        self.load_data(b_unfold_label)

    def configuration(self, flags, stage, file_path):
        self.batch_size = flags.batch_size
        self.file_path = file_path
        self.stage = stage
        self.flags = flags
        self.nums_d = len(self.file_path)
        self.current_indexes = [-1 for _ in range(self.nums_d)]

    def normalize(self, inputs):

        # the mean and std used for the normalization of
        # the inputs for the pytorch pretrained model
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # norm to [0, 1]
        inputs = inputs / 255.0

        inputs_norm = []
        for item in inputs:
            item = np.transpose(item, (2, 0, 1))  # from HWC to CHW
            item_norm = []
            for c, m, s in zip(item, mean, std):
                c = np.subtract(c, m)
                c = np.divide(c, s)
                item_norm.append(c)
            item_norm = np.stack(item_norm)
            inputs_norm.append(item_norm)

        return np.stack(inputs_norm)

    def load_data(self, b_unfold_label):
        # resize the image to 224 for the pretrained model
        def resize(x):
            x = x[:, :, [2, 1, 0]]  # the pre-read hdf5 files from the download page store BGR; change to RGB
            x = x.astype(np.uint8)
            img = np.array(Image.fromarray(x).resize((224, 224)))

            return img

        self.images = [[] for _ in range(self.nums_d)]
        self.labels = [[] for _ in range(self.nums_d)]

        for d_index in range(self.nums_d):
            f = h5py.File(self.file_path[d_index], "r")
            print(len(f['images']))
            images = np.array(list(map(resize, np.array(f['images']))))
            self.images[d_index] = self.normalize(images)
            self.labels[d_index] = np.array(f['labels'])
            f.close()

        self.file_num_trains = [[] for _ in range(self.nums_d)]
        for d_index in range(self.nums_d):
            assert np.max(self.images[d_index]) <= 5.0 and np.min(self.images[d_index]) >= -5.0
            assert len(self.images[d_index]) == len(self.labels[d_index])
            # shift the labels to start from 0
            self.labels[d_index] -= np.min(self.labels[d_index])
            self.file_num_trains[d_index] = len(self.labels[d_index])

        if self.stage == 'train':
            for d_index in range(self.nums_d):
                self.images[d_index], self.labels[d_index] = \
                    shuffle_data(samples=self.images[d_index], labels=self.labels[d_index])

    def get_images_labels_batch(self, domain_index):

        images = []
        labels = []
        H = []
        L = []
        for index in range(self.batch_size):
            self.current_indexes[domain_index] += 1
            # avoid overflow: wrap around and reshuffle at the end of an epoch
            if self.current_indexes[domain_index] > self.file_num_trains[domain_index] - 1:
                self.current_indexes[domain_index] %= self.file_num_trains[domain_index]
                self.images[domain_index], self.labels[domain_index] = \
                    shuffle_data(samples=self.images[domain_index], labels=self.labels[domain_index])

            image_src = self.images[domain_index][self.current_indexes[domain_index]]
            decoder_H, decoder_L, decoder_H_ag, decoder_L_ag, thresh = my_fft(image_src, self.flags.threshold)
            decoder_H = np.transpose(decoder_H, (2, 0, 1))
            decoder_L = np.transpose(decoder_L, (2, 0, 1))
            decoder_H_ag = np.transpose(decoder_H_ag, (2, 0, 1))
            decoder_L_ag = np.transpose(decoder_L_ag, (2, 0, 1))

            image_ag = my_fft_trans(image_src, thresh)
            image_ag = self.normalize(np.array([image_ag])).squeeze()

            images.append(image_ag)
            images.append(self.images[domain_index][self.current_indexes[domain_index]])

            labels.extend([self.labels[domain_index][self.current_indexes[domain_index]]] * 2)
            H.extend([decoder_H_ag / 255, decoder_H / 255])
            L.extend([decoder_L_ag / 255, decoder_L / 255])

        images = np.stack(images)
        labels = np.stack(labels)
        H = np.stack(H)
        L = np.stack(L)
        return images, labels, H, L
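# Usage sketch (not part of the original file): the `flags` object only needs
# `batch_size` and `threshold` here; the file path below is a placeholder.
#
#   from types import SimpleNamespace
#   flags = SimpleNamespace(batch_size=8, threshold=25)
#   gen = BatchImageGenerator(flags=flags, stage='train',
#                             file_path=['art_painting_train.hdf5'], b_unfold_label=False)
#   images, labels, H, L = gen.get_images_labels_batch(0)
#   # images: (2*batch_size, 3, 224, 224) — each sample paired with its augmented copy
#   # H, L:   (2*batch_size, 3, 112, 112) — high-/low-frequency reconstruction targets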

--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
import random
from PIL import Image
import os
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import accuracy_score
import cv2


def decoder_image(img, mean, std):
    # invert the ImageNet normalization back to the [0, 255] range
    inputs_decoder = []
    for ss, m, s in zip(img, mean, std):
        ss = np.array(ss * s)
        ss = np.array(ss + m)
        ss = ss * 255
        inputs_decoder.append(ss)
    return np.stack(inputs_decoder)

def my_fft(img, threshold):

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    img = decoder_image(img, mean, std)
    img = np.transpose(img, (1, 2, 0))

    H, W, C = img.shape

    if threshold is None:
        thresholds = [15, 25, 35, 45, 55, 65, 75, 85]
        index = np.random.randint(0, len(thresholds))
        threshold = thresholds[index]

    f = np.fft.fft2(img, axes=(0, 1))
    fshift = np.fft.fftshift(f)

    crows, ccols = int(H / 2), int(W / 2)
    mask = np.zeros((H, W, C), dtype=np.uint8)
    mask[crows - threshold:crows + threshold, ccols - threshold:ccols + threshold] = 1  # keep the low-frequency band
    fshift = fshift * mask
    ishift = np.fft.ifftshift(fshift)
    i_img = np.fft.ifft2(ishift, axes=(0, 1))
    i_img_L = np.abs(i_img)

    img_H_temp = (img - i_img_L)
    img_H_temp = img_H_temp * (255 / np.max(img_H_temp)) * 3

    img_H_temp[img_H_temp > 255] = 255
    img_H_temp[img_H_temp < 0] = 0
    i_img_L[i_img_L > 255] = 255
    i_img_L[i_img_L < 0] = 0

    img_H_temp = img_H_temp.astype(np.uint8)
    img_H_temp = np.array(Image.fromarray(img_H_temp).resize((112, 112)))
    i_img_L = i_img_L.astype(np.uint8)
    i_img_L = np.array(Image.fromarray(i_img_L).resize((112, 112)))

    img_H_temp_ag = np.array(Image.fromarray(img_H_temp).transpose(Image.FLIP_LEFT_RIGHT))
    i_img_L_ag = np.array(Image.fromarray(i_img_L).transpose(Image.FLIP_LEFT_RIGHT))

    return img_H_temp, i_img_L, img_H_temp_ag, i_img_L_ag, threshold

def gen_gaussian_noise(image, SNR):
    """
    :param image: source image
    :param SNR: signal-to-noise ratio in dB
    :return: noise
    """
    assert len(image.shape) == 3
    H, W, C = image.shape
    noise = np.random.randn(H, W, 1)
    noise = noise - np.mean(noise)
    image_power = 1 / (H * W * 3) * np.sum(np.power(image, 2))
    noise_variance = image_power / np.power(10, (SNR / 10))
    noise = (np.sqrt(noise_variance) / np.std(noise)) * noise
    return noise

def my_fft_trans(img1, threshold):
    # pre-process (note: `threshold` is accepted for interface consistency but not used here)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    img1 = decoder_image(img1, mean, std)

    img1 = np.transpose(img1, (1, 2, 0))

    H, W, C = img1.shape
    crows, ccols = H // 2, W // 2

    f1 = np.fft.fft2(img1, axes=(0, 1))

    fig_abs_temp = np.abs(f1)
    fig_pha_temp = np.angle(f1)

    # add noise to the phase and amplitude spectra
    noise_w_h_p = np.random.uniform(0.8, 1.2, (H, W, 1))
    noise_b_h_p = np.random.uniform(-np.pi / 6, np.pi / 6, (H, W, 1))
    fig_pha_ag = noise_w_h_p * fig_pha_temp + noise_b_h_p

    noise_w_h_a = np.random.uniform(0.5, 1.5, (H, W, 1))
    noise_b_h_a = gen_gaussian_noise(fig_abs_temp, 30)
    fig_abs_ag = noise_w_h_a * fig_abs_temp + noise_b_h_a

    f_ag = fig_abs_ag * np.cos(fig_pha_ag) + fig_abs_ag * np.sin(fig_pha_ag) * 1j

    # ifft
    img_ag = np.fft.ifft2(f_ag, axes=(0, 1))
    img_ag = np.abs(img_ag)
    img_ag = np.uint8(np.clip(img_ag, 0, 255))

    img_ag = np.array(Image.fromarray(img_ag).transpose(Image.FLIP_LEFT_RIGHT))

    return img_ag
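# Sketch (not in the original file): `my_fft` expects a normalized CHW image and
# returns 112x112 HWC high-/low-frequency crops plus their horizontally flipped copies.
#
#   x = np.random.randn(3, 224, 224).astype(np.float32)   # stand-in for a normalized image
#   img_H, img_L, img_H_ag, img_L_ag, t = my_fft(x, threshold=25)
#   # img_H.shape == img_L.shape == (112, 112, 3), dtype uint8; t == 25
#   # passing threshold=None samples t from {15, 25, ..., 85}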
def MMD_Loss_func(num_source, sigmas=None):
    if sigmas is None:
        sigmas = [1, 5, 10]
    def loss(e_pred, d_true):
        cost = 0.0
        for i in range(num_source):
            domain_i = e_pred[d_true == i]
            for j in range(i + 1, num_source):
                domain_j = e_pred[d_true == j]
                single_res = mmd_two_distribution(domain_i, domain_j, sigmas=sigmas)
                cost += single_res
        return cost
    return loss

def mmd_two_distribution(source, target, sigmas):
    sigmas = torch.tensor(sigmas).cuda()
    xy = rbf_kernel(source, target, sigmas)
    xx = rbf_kernel(source, source, sigmas)
    yy = rbf_kernel(target, target, sigmas)
    return xx + yy - 2 * xy

def rbf_kernel(x, y, sigmas):
    sigmas = sigmas.reshape(sigmas.shape + (1,))
    beta = 1. / (2. * sigmas)
    dist = compute_pairwise_distances(x, y)
    dot = -torch.matmul(beta, torch.reshape(dist, (1, -1)))
    exp = torch.mean(torch.exp(dot))
    return exp

def compute_pairwise_distances(x, y):
    dist = torch.zeros(x.size(0), y.size(0)).cuda()
    for i in range(x.size(0)):
        dist[i, :] = torch.sum(torch.square(x[i].expand(y.shape) - y), dim=1)
    return dist
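# Usage sketch (not in the original file): pairwise MMD between per-domain
# embeddings. `e_pred` are embeddings, `d_true` the integer domain id of each
# row. Note the helpers call .cuda(), so this requires a CUDA device.
#
#   mmd_loss = MMD_Loss_func(num_source=3)      # sums MMD over domain pairs
#   e_pred = torch.randn(12, 64).cuda()
#   d_true = torch.randint(0, 3, (12,)).cuda()
#   cost = mmd_loss(e_pred, d_true)             # scalar tensor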
def unfold_label(labels, classes):
    # cannot be used when the label set is incomplete
    new_labels = []

    assert len(np.unique(labels)) == classes
    # minimum value of labels
    mini = np.min(labels)

    for index in range(len(labels)):
        dump = np.full(shape=[classes], fill_value=0).astype(np.int8)
        _class = int(labels[index]) - mini
        dump[_class] = 1
        new_labels.append(dump)

    return np.array(new_labels)


def shuffle_data(samples, labels):
    num = len(labels)
    shuffle_index = np.random.permutation(np.arange(num))
    shuffled_samples = samples[shuffle_index]
    shuffled_labels = labels[shuffle_index]
    return shuffled_samples, shuffled_labels


def shuffle_list(li):
    np.random.shuffle(li)
    return li


def shuffle_list_with_ind(li):
    shuffle_index = np.random.permutation(np.arange(len(li)))
    li = li[shuffle_index]
    return li, shuffle_index


def num_flat_features(x):
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
        num_features *= s
    return num_features


def crossentropyloss():
    loss_fn = torch.nn.CrossEntropyLoss()
    return loss_fn


def mseloss():
    loss_fn = torch.nn.MSELoss()
    return loss_fn


def sgd(parameters, lr, weight_decay=0.00005, momentum=0.9):
    opt = optim.SGD(params=parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    return opt


def write_log(log, log_path):
    f = open(log_path, mode='a')
    f.write(str(log))
    f.write('\n')
    f.close()


def fix_python_seed(seed):
    print('seed-----------python', seed)
    random.seed(seed)
    np.random.seed(seed)


def fix_torch_seed(seed):
    print('seed-----------torch', seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def fix_all_seed(seed):
    print('seed-----------all device', seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def compute_accuracy(predictions, labels):
    if np.ndim(labels) == 2:
        y_true = np.argmax(labels, axis=-1)
    else:
        y_true = labels
    accuracy = accuracy_score(y_true=y_true, y_pred=np.argmax(predictions, axis=-1))
    return accuracy

--------------------------------------------------------------------------------
/main_agg.py:
--------------------------------------------------------------------------------
import argparse

from model_resnet_all import ModelAggregate


def main():
    train_arg_parser = argparse.ArgumentParser(description="parser")
    train_arg_parser.add_argument("--test_every", type=int, default=100,
                                  help="run validation every N iterations")
    train_arg_parser.add_argument("--batch_size", type=int, default=8,
                                  help="batch size per source domain")
    train_arg_parser.add_argument("--num_classes", type=int, default=7,
                                  help="number of classes")
    train_arg_parser.add_argument("--num_domains", type=int, default=3,
                                  help="number of source domains")
    train_arg_parser.add_argument("--step_size", type=int, default=1001,
                                  help="StepLR step size")
    train_arg_parser.add_argument("--loops_train", type=int, default=5001,
                                  help="number of training iterations")
    train_arg_parser.add_argument("--unseen_index", type=int, default=1,
                                  help="index of the held-out target domain")
    train_arg_parser.add_argument("--lr", type=float, default=[0.001, 0.01],
                                  help="learning rates: lr[0] for the backbone, lr[1] for the final fc layer")
    train_arg_parser.add_argument("--weight_decay", type=float, default=0.0001,
                                  help="weight decay")
    train_arg_parser.add_argument("--momentum", type=float, default=0.9,
                                  help="SGD momentum")
    train_arg_parser.add_argument("--logs", type=str, default='',
                                  help="log directory")
    train_arg_parser.add_argument("--model_path", type=str, default='',
                                  help="directory to save model checkpoints")
    train_arg_parser.add_argument("--state_dict", type=str, default='',
                                  help="path to the initial model weights")
    train_arg_parser.add_argument("--data_root", type=str, default='./datasets/PACS_DataSet/Train_val_splits_and_h5py_files_pre-read',
                                  help="dataset root directory")
    train_arg_parser.add_argument("--threshold", type=int, default=25,
                                  help="frequency threshold for the FFT low-/high-frequency split")

    args = train_arg_parser.parse_args()

    # leave-one-domain-out: three repetitions over the four PACS target domains
    index = [0, 1, 2, 3]
    styles = ['art', 'cartoon', 'photo', 'sketch']
    for x in range(3):
        for i in index:
            args.unseen_index = i
            args.logs = 'logs_PACS/{}/attack_{}'.format(str(x), styles[i])
            args.model_path = 'logs_PACS/{}/attack_{}_model'.format(str(x), styles[i])
            model_obj = ModelAggregate(flags=args)
            model_obj.train(flags=args)


if __name__ == "__main__":
    main()
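# Sketch (not in the original file): to train a single configuration instead of
# all 3x4 runs, the fields can be set directly, e.g.
#
#   args.unseen_index = 3                                 # hold out 'sketch'
#   args.logs = 'logs_PACS/0/attack_sketch'               # example paths
#   args.model_path = 'logs_PACS/0/attack_sketch_model'
#   ModelAggregate(flags=args).train(flags=args)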

--------------------------------------------------------------------------------
/model_resnet_all.py:
--------------------------------------------------------------------------------
import os

dirpath = os.pardir
import sys

sys.path.append(dirpath)
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
from torch.optim import lr_scheduler

import resnet_vanilla_updata
from common.data_reader import BatchImageGenerator
from common.utils import *
import cv2

class ModelAggregate:
    def __init__(self, flags):
        if torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
            print("Using CPU........")

        self.setup(flags)
        self.setup_path(flags)
        self.configure(flags)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def setup(self, flags):

        model = resnet_vanilla_updata.resnet18(pretrained=True, num_classes=flags.num_classes, num_domains=flags.num_domains, flags=flags)

        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model)
            self.network = model.cuda()
        else:
            self.network = model.cpu()

        print(self.network)
        print('flags:', flags)
        if not os.path.exists(flags.logs):
            os.makedirs(flags.logs)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)

    def setup_path(self, flags):

        root_folder = flags.data_root
        train_data = [
            'art_painting_train.hdf5',
            'cartoon_train.hdf5',
            'photo_train.hdf5',
            'sketch_train.hdf5'
        ]

        val_data = [
            'art_painting_val.hdf5',
            'cartoon_val.hdf5',
            'photo_val.hdf5',
            'sketch_val.hdf5'
        ]

        test_data = [
            'art_painting_test.hdf5',
            'cartoon_test.hdf5',
            'photo_test.hdf5',
            'sketch_test.hdf5'
        ]

        self.train_paths = []
        for data in train_data:
            path = os.path.join(root_folder, data)
            self.train_paths.append(path)

        self.val_paths = []
        for data in val_data:
            path = os.path.join(root_folder, data)
            self.val_paths.append(path)

        unseen_index = flags.unseen_index

        self.unseen_data_path = os.path.join(root_folder, test_data[unseen_index])
        self.train_paths.remove(self.train_paths[unseen_index])
        self.val_paths.remove(self.val_paths[unseen_index])

        if not os.path.exists(flags.logs):
            os.makedirs(flags.logs)

        flags_log = os.path.join(flags.logs, 'path_log.txt')
        write_log(str(self.train_paths), flags_log)
        write_log(str(self.val_paths), flags_log)
        write_log(str(self.unseen_data_path), flags_log)

        self.batImageGenTrains = BatchImageGenerator(flags=flags, file_path=self.train_paths, stage='train',
                                                     b_unfold_label=False)

        self.batImageGenVals = BatchImageGenerator(flags=flags, file_path=self.val_paths, stage='val',
                                                   b_unfold_label=False)

        self.batImageGenTest = BatchImageGenerator(flags=flags, file_path=[self.unseen_data_path], stage='test',
                                                   b_unfold_label=False)

    def load_state_dict(self, flags, nn):

        if flags.state_dict:

            try:
                tmp = torch.load(flags.state_dict)
                if 'state' in tmp.keys():
                    pretrained_dict = tmp['state']
                else:
                    pretrained_dict = tmp
            except Exception:
                # fall back to treating state_dict as a download URL
                pretrained_dict = model_zoo.load_url(flags.state_dict)

            model_dict = nn.state_dict()

            # 1. filter out unnecessary keys
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                               k in model_dict and v.size() == model_dict[k].size()}

            print('model dict keys:', len(model_dict.keys()), 'pretrained keys:', len(pretrained_dict.keys()))
            print('model dict keys:', model_dict.keys(), 'pretrained keys:', pretrained_dict.keys())
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 3. load the new state dict
            nn.load_state_dict(model_dict)
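    # Sketch (not in the original file): `load_state_dict` is not invoked
    # anywhere in this file; to resume from a saved checkpoint one could do
    #
    #   flags.state_dict = 'logs_PACS/0/attack_art_model/best_model.tar'  # example path
    #   self.load_state_dict(flags, self.network)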
    def configure(self, flags):

        if torch.cuda.device_count() > 1:
            base_params = list(map(id, self.network.module.fc.parameters()))
            logits_params = filter(lambda p: id(p) not in base_params, self.network.module.parameters())
            params = [
                {"params": logits_params, "lr": flags.lr[0]},
                {"params": self.network.module.fc.parameters(), "lr": flags.lr[1]},
            ]
            self.opt_network = torch.optim.SGD(params, weight_decay=flags.weight_decay, momentum=flags.momentum)
        else:
            base_params = list(map(id, self.network.fc.parameters()))
            logits_params = filter(lambda p: id(p) not in base_params, self.network.parameters())
            params = [
                {"params": logits_params, "lr": flags.lr[0]},
                {"params": self.network.fc.parameters(), "lr": flags.lr[1]},
            ]
            self.opt_network = torch.optim.SGD(params, weight_decay=flags.weight_decay, momentum=flags.momentum)

        self.scheduler = lr_scheduler.StepLR(optimizer=self.opt_network, step_size=int(flags.step_size), gamma=0.1)

        self.loss_fn_CE = torch.nn.CrossEntropyLoss()
        self.loss_fn_MSE = torch.nn.MSELoss()
        self.loss_fn_BCE = torch.nn.BCELoss()
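    # Per iteration (sketch of the objective, inferred from the code below):
    # for each source domain, a batch of (augmented, original) image pairs is
    # drawn; the disentangle pass is trained with CE on the low-/high-frequency
    # logits plus MSE between the decoded images and the FFT targets H and L,
    # and a second backward pass adds CE on the interaction branch over the
    # shuffled union of all domains before a single optimizer step.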
    def train(self, flags):
        self.network.train()

        self.best_accuracy_val = -1
        self.best_accuracy_test = -1

        for ite in range(flags.loops_train):

            # get the inputs and labels from the data reader
            total_loss = 0.0
            flag = 0
            for index in range(flags.num_domains):
                images_train, labels_train, H, L = self.batImageGenTrains.get_images_labels_batch(index)

                inputs, labels, H, L = torch.from_numpy(
                    np.array(images_train, dtype=np.float32)), torch.from_numpy(
                    np.array(labels_train, dtype=np.float32)), torch.from_numpy(
                    np.array(H, dtype=np.float32)), torch.from_numpy(
                    np.array(L, dtype=np.float32))

                # wrap the inputs and labels in Variable
                if torch.cuda.is_available():
                    inputs, labels, H, L = Variable(inputs, requires_grad=False).cuda(), \
                                           Variable(labels, requires_grad=False).long().cuda(), \
                                           Variable(H, requires_grad=False).cuda(), \
                                           Variable(L, requires_grad=False).cuda()
                else:
                    inputs, labels, H, L = Variable(inputs, requires_grad=False).cpu(), \
                                           Variable(labels, requires_grad=False).long().cpu(), \
                                           Variable(H, requires_grad=False).cpu(), \
                                           Variable(L, requires_grad=False).cpu()

                # forward with the adapted parameters
                outputs_lc, outputs_hc, outputs_L, outputs_H = self.network(x=inputs, types='disentangle')

                # extract high-frequency features
                loss_H_mse = self.loss_fn_MSE(outputs_H, H)

                # extract low-frequency features
                loss_L_mse = self.loss_fn_MSE(outputs_L, L)

                loss_C_A = self.loss_fn_CE(outputs_lc, labels) + self.loss_fn_CE(outputs_hc, labels)

                total_loss = total_loss + loss_C_A + loss_H_mse + loss_L_mse

                if flag == 0:
                    data, image_labels = inputs, labels
                    flag = 1
                else:
                    data, image_labels = \
                        torch.cat((data, inputs), 0), \
                        torch.cat((image_labels, labels), 0)

            # init the grad to zeros first
            self.opt_network.zero_grad()
            # backward your network
            total_loss.backward()

            shuffle_index = torch.randperm(len(image_labels))
            data, image_labels = data[shuffle_index], image_labels[shuffle_index]

            outputs_c_i, _ = self.network(x=data, types='interact')
            loss_C_I = self.loss_fn_CE(outputs_c_i, image_labels)

            # backward your network
            loss_C_I.backward()
            # optimize the parameters
            self.opt_network.step()

            flags_log = os.path.join(flags.logs, 'loss_log.txt')
            write_log(
                "total_loss:" + str(total_loss.item()) + " loss_C_I:" + str(loss_C_I.item()),
                flags_log)

            self.scheduler.step(epoch=ite)

            if ite < 500 or ite % 500 == 0:
                print(
                    'ite:', ite, 'total loss:', total_loss.cpu().item() + loss_C_I.cpu().item(),
                    'lr:', self.opt_network.param_groups[0]['lr'])

            if ite % flags.test_every == 0 and ite != 0:
                self.test_workflow(self.batImageGenVals, flags, ite)

    def test_workflow(self, batImageGenVals, flags, ite):

        accuracies = []
        for d_index in range(flags.num_domains):
            accuracy_val = self.test(batImageGenTest=batImageGenVals, flags=flags, d_index=d_index)

            accuracies.append(accuracy_val)

        mean_acc = np.mean(accuracies)

        if mean_acc > self.best_accuracy_val or ite > 1000:
            if mean_acc > self.best_accuracy_val:
                self.best_accuracy_val = mean_acc

            acc_test = self.test(batImageGenTest=self.batImageGenTest, flags=flags)

            if not os.path.exists(flags.model_path):
                os.makedirs(flags.model_path)

            if acc_test > self.best_accuracy_test:
                self.best_accuracy_test = acc_test
                outfile = os.path.join(flags.model_path, 'best_model.tar')
                torch.save({'ite': ite, 'state': self.network.state_dict()}, outfile)
                f = open(os.path.join(flags.logs, 'Best_val.txt'), mode='a')
                f.write(
                    'ite:{}, best val accuracy:{}, test accuracy:{}\n'.format(ite, self.best_accuracy_val,
                                                                              acc_test))
                f.close()
    def test(self, flags, batImageGenTest=None, d_index=0):

        # switch on the network test mode
        self.network.eval()

        if batImageGenTest is None:
            batImageGenTest = BatchImageGenerator(flags=flags, file_path='', stage='test', b_unfold_label=True)

        images_test = batImageGenTest.images[d_index]
        labels_test = batImageGenTest.labels[d_index]

        # evaluate in chunks of roughly `threshold` images to bound memory use
        threshold = 50
        if len(images_test) > threshold:

            n_slices_test = int(len(images_test) / threshold)
            indices_test = []
            for per_slice in range(n_slices_test - 1):
                indices_test.append(int(len(images_test) * (per_slice + 1) / n_slices_test))
            test_image_splits = np.split(images_test, indices_or_sections=indices_test)

            # verify the splits are correct
            test_image_splits_2_whole = np.concatenate(test_image_splits)
            assert np.all(images_test == test_image_splits_2_whole)

            # split the test data into splits and test them one by one
            test_image_preds = []
            index = 0
            for test_image_split in test_image_splits:
                index = index + 1
                if torch.cuda.is_available():
                    images_test = Variable(torch.from_numpy(np.array(test_image_split, dtype=np.float32))).cuda()
                else:
                    images_test = Variable(torch.from_numpy(np.array(test_image_split, dtype=np.float32))).cpu()
                tuples = self.network(images_test, 'interact')

                predictions = tuples[1]['Predictions']
                predictions = predictions.cpu().data.numpy()
                test_image_preds.append(predictions)

            # concatenate the test predictions first
            predictions = np.concatenate(test_image_preds)
        else:
            if torch.cuda.is_available():
                images_test = Variable(torch.from_numpy(np.array(images_test, dtype=np.float32))).cuda()
            else:
                images_test = Variable(torch.from_numpy(np.array(images_test, dtype=np.float32))).cpu()

            tuples = self.network(images_test, 'interact')

            predictions = tuples[1]['Predictions']
            predictions = predictions.cpu().data.numpy()

        accuracy = compute_accuracy(predictions=predictions, labels=labels_test)

        # switch on the network train mode
        self.network.train()

        return accuracy

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv-python==4.5.2.54
h5py==2.10.0
scikit-learn==0.24.2
pillow==7.1.2
scipy==1.7.0
numpy==1.19.2
# torch==1.9.0 and torchvision==0.10.0 (see README) must be installed as well

--------------------------------------------------------------------------------
/resnet_decoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class UnFlatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), 512, 4, 4)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def Upsample2d(in_planes, out_planes, stride=2):
    # note: the arguments are unused; this always upsamples by a factor of 2
    return nn.Upsample(scale_factor=2, mode="nearest")


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, upsample=None, flag=True):
        super(BasicBlock, self).__init__()
        self.conv0_0 = nn.Conv2d(inplanes, planes, 3, padding=1)
        self.bn0_0 = nn.BatchNorm2d(planes)

        self.conv0_1 = nn.Conv2d(inplanes, inplanes, 3, padding=1)
        self.bn0_1 = nn.BatchNorm2d(inplanes)

        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = Upsample2d(inplanes, inplanes)
        self.bn2 = nn.BatchNorm2d(inplanes)
        self.upsample = upsample
        self.stride = stride
        self.flag = flag

    def forward(self, x):
        residual = x
        if self.upsample is not None:
            residual_temp = self.upsample(x)
            residual = self.conv0_0(residual_temp)
            residual = self.bn0_0(residual)

        if self.flag:
            x = self.conv2(x)
            x = self.conv0_1(x)
            x = self.bn0_1(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        out = x + residual

        return out


class ResNetDecoder(nn.Module):

    def __init__(self, block, layers):
        self.inplanes = 512
        super(ResNetDecoder, self).__init__()

        self.conv_end = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

        self.layer0 = self._make_layer(block, 512, layers[0])
        self.layer1 = self._make_layer(block, 256, layers[1])
        self.layer2 = self._make_layer(block, 128, layers[2])
        self.layer3 = self._make_layer(block, 64, layers[3])

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def bn_eval(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def _make_layer(self, block, planes, blocks, stride=1, flag=False):
        upsample = nn.Sequential(
            Upsample2d(self.inplanes, planes * block.expansion),
        )

        layers = []
        layers.append(block(self.inplanes, planes, stride, upsample=upsample, flag=True))
        for i in range(1, blocks):
            layers.append(block(planes, planes, flag=False))
        self.inplanes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.layer0(x)  # B×512×14×14
        x = self.layer1(x)  # B×256×28×28
        x = self.layer2(x)  # B×128×56×56
        x = self.layer3(x)  # B×64×112×112

        x = self.conv_end(x)
        out = self.sigmoid(x)

        return out

def resnet18_decoder(pretrained=False):
    """
    Constructs a ResNet-18 decoder model.
    """
    model = ResNetDecoder(BasicBlock, [2, 2, 2, 2])
    return model
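# Sketch (not in the original file): the decoder maps 512x7x7 backbone features
# back to a 112x112 image in [0, 1] via the final sigmoid.
#
#   import torch
#   dec = resnet18_decoder()
#   y = dec(torch.randn(2, 512, 7, 7))   # -> (2, 3, 112, 112)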

--------------------------------------------------------------------------------
/resnet_vanilla_updata.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
# from resnet_decoder_all import *
from resnet_decoder import *
import torch
from torch.autograd import Variable

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}

class Flatten(nn.Module):
    def forward(self, inputs):
        return inputs.view(inputs.size(0), -1)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
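# Sketch (not in the original file): SpatialAttention pools channels by mean and
# max, convolves the 2-channel map, and returns a (B, 1, H, W) mask in (0, 1).
#
#   att = SpatialAttention(3)
#   m = att(torch.randn(4, 512, 7, 7))   # -> (4, 1, 7, 7)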
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes, num_domains, flags):
        self.inplanes = 64
        self.flags = flags
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.fc = nn.Linear(512, num_classes)
        self.fc_l = nn.Linear(512, num_classes)
        self.fc_h = nn.Linear(512, num_classes)
        self.block6 = nn.Sequential(
            nn.AvgPool2d(7),
            Flatten(),
        )

        self.distangler_H = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                                          nn.BatchNorm2d(512),
                                          nn.ReLU(),
                                          )

        self.distangler_L = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                                          nn.BatchNorm2d(512),
                                          nn.ReLU(),
                                          )

        self.spatial_attention = SpatialAttention(3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.decoder_H = resnet18_decoder()
        self.decoder_L = resnet18_decoder()


    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, types):

        end_points = {}

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        out = self.layer4(x)
        if types == 'disentangle':
            L_info = self.distangler_L(out)
            x_L_image = self.decoder_L(L_info)
            H_info = self.distangler_H(out)
            x_H_image = self.decoder_H(H_info)
            l = self.block6(L_info)
            h = self.block6(H_info)
            x_l = self.fc_l(l)
            x_h = self.fc_h(h)

            return x_l, x_h, x_L_image, x_H_image

        elif types == 'interact':
            L_info = self.distangler_L(out)
            L_att = self.spatial_attention(L_info)
            H_info = self.distangler_H(out)

            inter_info = H_info * L_att
            inter_info = self.block6(inter_info)
            cls = self.fc(inter_info)

            end_points['Predictions'] = F.softmax(input=cls, dim=-1)
            return cls, end_points
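# Sketch (not in the original file) of the two forward modes for 224x224 inputs:
#
#   net = resnet18(pretrained=False, num_classes=7, num_domains=3, flags=None)
#   x = torch.randn(2, 3, 224, 224)
#   x_l, x_h, img_L, img_H = net(x, 'disentangle')  # logits (2, 7), images (2, 3, 112, 112)
#   cls, ends = net(x, 'interact')                  # ends['Predictions'] holds softmax scores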
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = model_zoo.load_url(model_urls['resnet18'])

        # 1. filter out keys that are missing or mismatched in size
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                           k in model_dict and v.size() == model_dict[k].size()}

        print('model dict keys:', len(model_dict.keys()), 'pretrained keys:', len(pretrained_dict.keys()))
        print('model dict keys:', model_dict.keys(), 'pretrained keys:', pretrained_dict.keys())
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
    return model

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # NOTE: `Bottleneck` is not defined in this file; it must be imported or
    # defined (e.g., a standard ResNet bottleneck block) before resnet50 is used.
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = model_zoo.load_url(model_urls['resnet50'])

        # 1. filter out keys that are missing or mismatched in size
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                           k in model_dict and v.size() == model_dict[k].size()}

        print('model dict keys:', len(model_dict.keys()), 'pretrained keys:', len(pretrained_dict.keys()))
        print('model dict keys:', model_dict.keys(), 'pretrained keys:', pretrained_dict.keys())
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
    return model
--------------------------------------------------------------------------------