.
├── Initial_Synthetic_Dataset
│   ├── CIFAR10_IPC10_images.pt
│   ├── CIFAR10_IPC10_labels.pt
│   ├── CIFAR10_IPC1_images.pt
│   ├── CIFAR10_IPC1_labels.pt
│   ├── CIFAR10_IPC50_images.pt
│   └── CIFAR10_IPC50_labels.pt
├── README.md
├── distill_test_model.py
├── google8905e38a0c973ed3.html
├── img
│   ├── DataDAM_pipeline.png
│   └── HPTable.png
├── main_DataDAM.py
├── networks.py
├── requirements.txt
└── utils.py
/Initial_Synthetic_Dataset/CIFAR10_IPC10_images.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC10_images.pt
--------------------------------------------------------------------------------
/Initial_Synthetic_Dataset/CIFAR10_IPC10_labels.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC10_labels.pt
--------------------------------------------------------------------------------
/Initial_Synthetic_Dataset/CIFAR10_IPC1_images.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC1_images.pt
--------------------------------------------------------------------------------
/Initial_Synthetic_Dataset/CIFAR10_IPC1_labels.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC1_labels.pt
--------------------------------------------------------------------------------
/Initial_Synthetic_Dataset/CIFAR10_IPC50_images.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC50_images.pt
--------------------------------------------------------------------------------
/Initial_Synthetic_Dataset/CIFAR10_IPC50_labels.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC50_labels.pt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DataDAM: Efficient Dataset Distillation with Attention Matching
2 | Official implementation of "DataDAM: Efficient Dataset Distillation with Attention Matching", published as a conference paper at ICCV 2023.
3 | - Project Page: https://datadistillation.github.io/DataDAM/
4 | ## Abstract
5 | Researchers have long tried to minimize training costs in deep learning while maintaining strong generalization across diverse datasets. Emerging research on dataset distillation aims to reduce training costs by creating a small synthetic set that contains the information of a larger real dataset and ultimately achieves test accuracy equivalent to a model trained on the whole dataset. Unfortunately, the synthetic data generated by previous methods are not guaranteed to distribute and discriminate as well as the original training data, and they incur significant computational costs. Despite promising results, there still exists a significant performance gap between models trained on condensed synthetic sets and those trained on the whole dataset. In this paper, we address these challenges using efficient Dataset Distillation with Attention Matching (DataDAM), achieving state-of-the-art performance while reducing training costs. Specifically, we learn synthetic images by matching the spatial attention maps of real and synthetic data generated by different layers within a family of randomly initialized neural networks. Our method outperforms the prior methods on several datasets, including MNIST, CIFAR10/100, TinyImageNet, and ImageNet-1K, across most of the settings, and achieves improvements of up to 6.5% and 4.1% on CIFAR100 and ImageNet-1K, respectively. We also show that our high-quality distilled images have practical benefits for downstream applications, such as continual learning and neural architecture search.
6 |
7 | ![DataDAM pipeline](img/DataDAM_pipeline.png)
8 |
9 |
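At its core, DataDAM pools each layer's ReLU activation magnitudes over channels into a spatial attention map and matches the maps produced by real and synthetic batches. The sketch below mirrors `get_attention` in `utils.py` (with `param=1, exp=1, norm='l2'`, the setting used in `main_DataDAM.py`) together with a simplified mean-map MSE in the spirit of its `error` helper; the random tensors are stand-ins for real activations:

```python
import torch
import torch.nn.functional as F

def spatial_attention(feature_map: torch.Tensor) -> torch.Tensor:
    # feature_map: [B, C, H, W] activations hooked from one ReLU layer
    attention = torch.sum(torch.abs(feature_map), dim=1)  # [B, H, W]: pool |a| over channels
    attention = attention.view(feature_map.size(0), -1)   # [B, H*W]: vectorize per sample
    return F.normalize(attention, p=2.0, dim=1)           # L2-normalize each attention map

real_maps = spatial_attention(torch.randn(64, 128, 16, 16))  # stand-in for real-batch activations
syn_maps = spatial_attention(torch.randn(10, 128, 16, 16))   # stand-in for synthetic activations
loss = torch.sum((real_maps.mean(dim=0) - syn_maps.mean(dim=0)) ** 2)  # MSE between mean maps
```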
10 | ## File Tree
11 | This folder contains all necessary code files and supplemental material for the main paper.
12 | ```
13 | .
14 | ├── main_DataDAM.py # Source Code for reproducing DataDAM results on benchmark datasets and IPCs (see the example run below)
15 | ├── networks.py # Defines all relevant network architectures, including cross-arch models
16 | ├── utils.py # Defines all utility functions required for any task or ablation in the main paper, including our attention module
17 | ├── distill_test_model.py # Script to test the frozen models
18 | ├── requirements.txt # Lists all related Python packages necessary for reproducing our model results
19 | ├── Supplementary.pdf # Supplementary pdf for our main paper -- DataDAM
20 | └── README.md
21 | ```
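To reproduce a distillation run, invoke `main_DataDAM.py` directly. A typical invocation is sketched below with flags defined in its argument parser; the dataset and output paths are placeholders to adapt:

```
python main_DataDAM.py --dataset CIFAR10 --model ConvNet --ipc 50 --init real \
    --data_path ./data --save_path ./results
```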
22 |
23 |
24 |
25 | ## HyperParameter Table
26 | For reproducibility, we outline our associated hyperparameters below:
27 |
28 | ![Hyperparameter table](img/HPTable.png)
29 |
30 |
31 | ## Distilled Datasets & Frozen Evaluation Models
32 |
33 | We provide saved tensors of the distilled datasets, along with frozen evaluation models trained on them, on our HuggingFace page: https://huggingface.co/datasets/uoft-dsp-lab/DataDAM
34 |
35 | Additionally, these frozen models can be tested with `distill_test_model.py`; a typical invocation is shown below.
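```
python distill_test_model.py --dataset CIFAR10 --ipc 50 --data_path ./data --save_path ./results
```

Here `./results/cifar10/` is assumed to contain the `syn_data_cifar10_ipc_50.pt` and `model_params_cifar10_ipc_50.pt` tensors from the HuggingFace page (the script appends the lowercased dataset name to `--save_path`); it loads the distilled data, restores the frozen network, and reports its test accuracy.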
36 |
--------------------------------------------------------------------------------
/distill_test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import copy
4 | import argparse
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, get_attention
10 | import matplotlib.pyplot as plt
11 | from torchvision import transforms
12 | from torch.utils.data.distributed import DistributedSampler
13 | import kornia as K
14 | import torch.distributed as dist
15 | import torch.cuda.comm
16 | from torchvision.utils import save_image
17 |
18 | def main():
19 |
20 | parser = argparse.ArgumentParser(description='Parameter Processing')
21 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
22 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
23 | parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
24 | parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode')
25 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
26 | parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models')
27 | parser.add_argument('--epoch_eval_train', type=int, default=1800, help='epochs to train a model with synthetic data')
28 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations')
29 | parser.add_argument('--lr_img', type=float, default=1, help='learning rate for updating synthetic images, 1 for low IPCs 10 for >= 100')
30 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
31 | parser.add_argument('--batch_real', type=int, default=64, help='batch size for real data')
32 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks')
33 | parser.add_argument('--init', type=str, default='real', help='noise/real/smart: initialize synthetic images from random noise or randomly sampled real images.')
34 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
35 | parser.add_argument('--data_path', type=str, default='', help='dataset path')
36 |     parser.add_argument('--zca', action='store_true', help='apply ZCA whitening')  # type=bool would parse any non-empty string as True
37 | parser.add_argument('--save_path', type=str, default='', help='path to save results')
38 | parser.add_argument('--task_balance', type=float, default=0.01, help='balance attention with output')
39 |
40 | args = parser.parse_args()
41 | args.method = 'DataDAM'
42 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
43 | args.dsa_param = ParamDiffAug()
44 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
45 |
46 | if not os.path.exists(args.data_path):
47 | os.mkdir(args.data_path)
48 |
49 | if not os.path.exists(args.save_path):
50 | os.mkdir(args.save_path)
51 |
52 | args.save_path += "/{}".format(args.dataset.lower())
53 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca = get_dataset(args.dataset, args.data_path, args)
54 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
55 |
56 |
57 | model_eval = model_eval_pool[0]
58 |
59 | data_save = torch.load(os.path.join(args.save_path, 'syn_data_%s_ipc_%d.pt'%(args.dataset.lower(), args.ipc)))["data"]
60 |
61 | image_syn_eval = torch.tensor(data_save[0])
62 | label_syn_eval = torch.tensor(data_save[1])
63 | net_model_dict = torch.load(os.path.join(args.save_path, 'model_params_%s_ipc_%d.pt'%(args.dataset.lower(), args.ipc)))["net_parameters"]
64 |
65 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
66 |
67 | net_eval.load_state_dict(net_model_dict) # load the state dict
68 | _, _, acc_test = evaluate_synset(-1, net_eval, image_syn_eval, label_syn_eval, testloader, args, skip=True) # evaluate the model
69 | print("Trained Model Best", acc_test)
70 |
71 | if __name__ == '__main__':
72 |     main()
73 |
74 |
--------------------------------------------------------------------------------
/google8905e38a0c973ed3.html:
--------------------------------------------------------------------------------
1 | google-site-verification: google8905e38a0c973ed3.html
2 |
--------------------------------------------------------------------------------
/img/DataDAM_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/img/DataDAM_pipeline.png
--------------------------------------------------------------------------------
/img/HPTable.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/img/HPTable.png
--------------------------------------------------------------------------------
/main_DataDAM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import copy
4 | import argparse
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, get_attention
10 | import matplotlib.pyplot as plt
11 | from torchvision import transforms
12 | from torch.utils.data.distributed import DistributedSampler
13 | import kornia as K
14 | import torch.distributed as dist
15 | import torch.cuda.comm
16 |
17 | def main():
18 |
19 | parser = argparse.ArgumentParser(description='Parameter Processing')
20 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
21 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
22 | parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
23 | parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode')
24 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
25 | parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models')
26 | parser.add_argument('--epoch_eval_train', type=int, default=1800, help='epochs to train a model with synthetic data')
27 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations')
28 | parser.add_argument('--lr_img', type=float, default=1, help='learning rate for updating synthetic images, 1 for low IPCs 10 for >= 100')
29 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
30 | parser.add_argument('--batch_real', type=int, default=64, help='batch size for real data')
31 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks')
32 | parser.add_argument('--init', type=str, default='real', help='noise/real/smart: initialize synthetic images from random noise or randomly sampled real images.')
33 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
34 | parser.add_argument('--data_path', type=str, default='', help='dataset path')
35 |     parser.add_argument('--zca', action='store_true', help='apply ZCA whitening')  # type=bool would parse any non-empty string as True
36 | parser.add_argument('--save_path', type=str, default='', help='path to save results')
37 | parser.add_argument('--task_balance', type=float, default=0.01, help='balance attention with output')
38 |
39 | args = parser.parse_args()
40 | args.method = 'DataDAM'
41 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
42 | args.dsa_param = ParamDiffAug()
43 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
44 |
45 | if not os.path.exists(args.data_path):
46 | os.mkdir(args.data_path)
47 |
48 | if not os.path.exists(args.save_path):
49 | os.mkdir(args.save_path)
50 |
51 | eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist()[:] if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
52 | print('eval_it_pool: ', eval_it_pool)
53 |
54 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca = get_dataset(args.dataset, args.data_path, args)
55 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
56 |
57 |
58 | accs_all_exps = dict() # record performances of all experiments
59 | for key in model_eval_pool:
60 | accs_all_exps[key] = []
61 |
62 | data_save = []
63 |
64 | total_mean = {}
65 | best_5 = []
66 | accuracy_logging = {"mean":[], "std":[], "max_mean":[]}
67 | for exp in range(args.num_exp):
68 | total_mean[exp] = {'mean':[], 'std':[]}
69 | best_5.append(0)
70 | print('\n================== Exp %d ==================\n '%exp)
71 | print('Hyper-parameters: \n', args.__dict__)
72 | print('Evaluation model pool: ', model_eval_pool)
73 |
74 | ''' organize the real dataset '''
75 | images_all = []
76 | labels_all = []
77 | indices_class = [[] for c in range(num_classes)]
78 |
79 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
80 | labels_all = [dst_train[i][1] for i in range(len(dst_train))]
81 | for i, lab in enumerate(labels_all):
82 | indices_class[lab].append(i)
83 | images_all = torch.cat(images_all, dim=0).to(args.device)
84 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
85 |
86 |
87 |
88 | for c in range(num_classes):
89 | print('class c = %d: %d real images'%(c, len(indices_class[c])))
90 |
91 | def get_images(c, n): # get random n images from class c
92 | idx_shuffle = np.random.permutation(indices_class[c])[:n]
93 | return images_all[idx_shuffle]
94 |
95 | for ch in range(channel):
96 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))
97 |
98 |
99 | ''' initialize the synthetic data '''
100 | image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
101 | label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
102 | if args.init == 'real':
103 | print('initialize synthetic data from random real images')
104 | for c in range(num_classes):
105 | image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
106 | elif args.init =='noise' :
107 | print('initialize synthetic data from random noise')
108 |
109 | elif args.init =='smart' :
110 | print('initialize synthetic data from SMART selection')
111 |             Path = './Initial_Synthetic_Dataset/'  # the shipped initialization tensors live in this folder
112 | if args.dataset == "CIFAR10":
113 | Path+='CIFAR10_'
114 |
115 | elif args.dataset == "CIFAR100":
116 | Path+='CIFAR100_'
117 |
118 | if args.ipc == 1:
119 | Path += 'IPC1_'
120 |
121 | elif args.ipc == 10:
122 | Path += 'IPC10_'
123 |
124 | elif args.ipc == 50:
125 | Path += 'IPC50_'
126 |
127 | elif args.ipc == 100:
128 | Path += 'IPC100_'
129 |
130 | elif args.ipc == 200:
131 | Path += 'IPC200_'
132 |             image_syn.data[:] = torch.load(Path+'images.pt')
133 | label_syn.data[:] = torch.load(Path+'labels.pt')
134 |
135 |         if args.zca:
136 |             image_syn.data[:] = zca(image_syn.data, include_fit=True)
137 |             print("ZCA whitening complete")
138 |         else:
139 |             print("No ZCA whitening")
140 |
141 |
142 |
143 |
144 |
145 | ''' training '''
146 | optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
147 | optimizer_img.zero_grad()
148 | print('%s training begins'%get_time())
149 | ''' Defining the Hook Function to collect Activations '''
150 | activations = {}
151 | def getActivation(name):
152 | def hook_func(m, inp, op):
153 | activations[name] = op.clone()
154 | return hook_func
155 |
156 | ''' Defining the Refresh Function to store Activations and reset Collection '''
157 | def refreshActivations(activations):
158 | model_set_activations = [] # Jagged Tensor Creation
159 | for i in activations.keys():
160 | model_set_activations.append(activations[i])
161 | activations = {}
162 | return activations, model_set_activations
163 |
164 | ''' Defining the Delete Hook Function to collect Remove Hooks '''
165 | def delete_hooks(hooks):
166 | for i in hooks:
167 | i.remove()
168 | return
169 |
170 | def attach_hooks(net):
171 | hooks = []
172 | base = net.module if torch.cuda.device_count() > 1 else net
173 | for module in (base.features.named_modules()):
174 | if isinstance(module[1], nn.ReLU):
175 |                     # Hook the outputs of a ReLU layer
176 | hooks.append(base.features[int(module[0])].register_forward_hook(getActivation('ReLU_'+str(len(hooks)))))
177 | return hooks
178 |
179 | max_mean = 0
180 | for it in range(args.Iteration+1):
181 |
182 | ''' Evaluate synthetic data '''
183 | if it in eval_it_pool:
184 | for model_eval in model_eval_pool:
185 | print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
186 |
187 | print('DSA augmentation strategy: \n', args.dsa_strategy)
188 | print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
189 |
190 | accs = []
191 | Start = time.time()
192 | for it_eval in range(args.num_eval):
193 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
194 | image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
195 | mini_net, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
196 | accs.append(acc_test)
197 | if acc_test > best_5[-1]:
198 | best_5[-1] = acc_test
199 |
200 |                 Finish = (time.time() - Start) / args.num_eval  # average wall-clock time per evaluation
201 |
202 |                 print("AVERAGE EVALUATION TIME: ", Finish)
203 |
204 |
205 | print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
206 |                 if np.mean(accs) > max_mean:
207 |                     max_mean = np.mean(accs)  # update the running best; the original never updated it, so every evaluation was saved
208 |                     data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
209 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc_.pt'%(args.method, args.dataset, args.model, args.ipc)))
210 | # Track All of them!
211 | total_mean[exp]['mean'].append(np.mean(accs))
212 | total_mean[exp]['std'].append(np.std(accs))
213 |
214 | accuracy_logging["mean"].append(np.mean(accs))
215 | accuracy_logging["std"].append(np.std(accs))
216 | accuracy_logging["max_mean"].append(np.max(accs))
217 |
218 |
219 | if it == args.Iteration: # record the final results
220 | accs_all_exps[model_eval] += accs
221 |
222 | ''' visualize and save '''
223 | # save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
224 | # image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
225 | # for ch in range(channel):
226 | # image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
227 | # image_syn_vis[image_syn_vis<0] = 0.0
228 | # image_syn_vis[image_syn_vis>1] = 1.0
229 | # save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.
230 |
231 | ''' Train synthetic data '''
232 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
233 | net.train()
234 | for param in list(net.parameters()):
235 | param.requires_grad = False
236 |
237 | loss_avg = 0
238 | def error(real, syn, err_type="MSE"):
239 |
240 | if(err_type == "MSE"):
241 | err = torch.sum((torch.mean(real, dim=0) - torch.mean(syn, dim=0))**2)
242 |
243 | elif (err_type == "MAE"):
244 | err = torch.sum(torch.abs(torch.mean(real, dim=0) - torch.mean(syn, dim=0)))
245 |
246 | elif (err_type == "ANG"):
247 | rl = torch.mean(real, dim=0)
248 | sy = torch.mean(syn, dim=0)
249 | num = torch.matmul(rl, sy)
250 | denom = (torch.sum(rl**2)**0.5) * (torch.sum(sy**2)**0.5)
251 | err = torch.acos(num/denom)
252 |
253 | elif(err_type == "MSE_B"):
254 | err = torch.sum((torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu() - torch.mean(syn.cpu().reshape(num_classes, args.ipc, -1), dim=1))**2)
255 | elif(err_type == "MAE_B"):
256 | err = torch.sum(torch.abs(torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu() - torch.mean(syn.reshape(num_classes, args.ipc, -1).cpu(), dim=1)))
257 | elif (err_type == "ANG_B"):
258 | rl = torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu()
259 | sy = torch.mean(syn.reshape(num_classes, args.ipc, -1), dim=1)
260 |
261 | denom = (torch.sum(rl**2)**0.5).cpu() * (torch.sum(sy**2)**0.5).cpu()
262 | num = rl.cpu() * sy.cpu()
263 | err = torch.sum(torch.acos(num/denom))
264 | return err
265 |
266 | ''' update synthetic data '''
267 | loss = torch.tensor(0.0)
268 | mid_loss = 0
269 | out_loss = 0
270 |
271 | images_real_all = []
272 | images_syn_all = []
273 | for c in range(num_classes):
274 | img_real = get_images(c, args.batch_real)
275 | img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
276 |
277 | if args.dsa:
278 | seed = int(time.time() * 1000) % 100000
279 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
280 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
281 |
282 | images_real_all.append(img_real)
283 | images_syn_all.append(img_syn)
284 |
285 | images_real_all = torch.cat(images_real_all, dim=0)
286 |
287 | images_syn_all = torch.cat(images_syn_all, dim=0)
288 |
289 |
290 | hooks = attach_hooks(net)
291 |
292 | output_real = net(images_real_all)[0].detach()
293 | activations, original_model_set_activations = refreshActivations(activations)
294 |
295 | output_syn = net(images_syn_all)[0]
296 | activations, syn_model_set_activations = refreshActivations(activations)
297 | delete_hooks(hooks)
298 |
299 |             length_of_network = len(original_model_set_activations)  # number of feature-map sets
300 |
301 | for layer in range(length_of_network-1):
302 |
303 | real_attention = get_attention(original_model_set_activations[layer].detach(), param=1, exp=1, norm='l2')
304 | syn_attention = get_attention(syn_model_set_activations[layer], param=1, exp=1, norm='l2')
305 |
306 | tl = 100*error(real_attention, syn_attention, err_type="MSE_B")
307 | loss+=tl
308 | mid_loss += tl
309 |
310 | output_loss = 100*args.task_balance * error(output_real, output_syn, err_type="MSE_B")
311 |
312 | loss += output_loss
313 | out_loss += output_loss
314 |
315 | optimizer_img.zero_grad()
316 | loss.backward()
317 | optimizer_img.step()
318 | loss_avg += loss.item()
319 | torch.cuda.empty_cache()
320 |
321 | loss_avg /= (num_classes)
322 | out_loss /= (num_classes)
323 | mid_loss /= (num_classes)
324 | if it%10 == 0:
325 | print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))
326 | print('\n==================== Final Results ====================\n')
327 | for key in model_eval_pool:
328 | accs = accs_all_exps[key]
329 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))
330 |
331 | print('\n==================== Maximum Results ====================\n')
332 |
333 | best_means = []
334 | best_std = []
335 | for exp in total_mean.keys():
336 | best_idx = np.argmax(total_mean[exp]['mean'])
337 | best_means.append(total_mean[exp]['mean'][best_idx])
338 | best_std.append(total_mean[exp]['std'][best_idx])
339 |
340 | mean = np.mean(best_means)
341 | std = np.mean(best_std)
342 |
343 | num_eval = args.num_exp*args.num_eval
344 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model,num_eval, key, mean*100, std*100))
345 |
346 |
347 | print('\n==================== Top 5 Results ====================\n')
348 |
349 |
350 | mean = np.mean(best_5)
351 | std = np.std(best_5)
352 |
353 | num_eval = args.num_exp*args.num_eval
354 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model,num_eval, key, mean*100, std*100))
355 |
356 |
357 | if __name__ == '__main__':
358 | main()
359 |
360 |
361 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # Acknowledgement to
5 | # https://github.com/kuangliu/pytorch-cifar,
6 | # https://github.com/BIGBALLON/CIFAR-ZOO,
7 |
8 | from einops import rearrange, repeat
9 | from einops.layers.torch import Rearrange
10 |
11 |
12 |
13 | ''' Swish activation '''
14 | class Swish(nn.Module): # Swish(x) = x∗σ(x)
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.sigmoid(input)
20 |
21 |
22 | ''' MLP '''
23 | class MLP(nn.Module):
24 | def __init__(self, channel, num_classes):
25 | super(MLP, self).__init__()
26 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128)
27 | self.fc_2 = nn.Linear(128, 128)
28 | self.fc_3 = nn.Linear(128, num_classes)
29 |
30 | def forward(self, x):
31 | out = x.view(x.size(0), -1)
32 | out = F.relu(self.fc_1(out))
33 | out = F.relu(self.fc_2(out))
34 | out = self.fc_3(out)
35 | return out
36 |
37 |
38 |
39 | ''' ConvNet '''
40 | class ConvNet(nn.Module):
41 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)):
42 | super(ConvNet, self).__init__()
43 |
44 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size)
45 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2]
46 | self.classifier = nn.Linear(num_feat, num_classes)
47 |
48 | def forward(self, x):
49 | # print(x.shape)
50 | out = self.features(x)
51 | emb = out.reshape(out.size(0), -1)
52 | # emb = self.embed(x)
53 | out = self.classifier(emb)
54 | return emb, out
55 |
56 | def embed(self, x):
57 | out = self.features(x)
58 | out = out.view(out.size(0), -1)
59 | return out
60 |
61 | def _get_activation(self, net_act):
62 | if net_act == 'sigmoid':
63 | return nn.Sigmoid()
64 | elif net_act == 'relu':
65 | return nn.ReLU(inplace=True)
66 | elif net_act == 'leakyrelu':
67 | return nn.LeakyReLU(negative_slope=0.01)
68 | elif net_act == 'swish':
69 | return Swish()
70 | else:
71 | exit('unknown activation function: %s'%net_act)
72 |
73 | def _get_pooling(self, net_pooling):
74 | if net_pooling == 'maxpooling':
75 | return nn.MaxPool2d(kernel_size=2, stride=2)
76 | elif net_pooling == 'avgpooling':
77 | return nn.AvgPool2d(kernel_size=2, stride=2)
78 | elif net_pooling == 'none':
79 | return None
80 | else:
81 | exit('unknown net_pooling: %s'%net_pooling)
82 |
83 | def _get_normlayer(self, net_norm, shape_feat):
84 | # shape_feat = (c*h*w)
85 | if net_norm == 'batchnorm':
86 | return nn.BatchNorm2d(shape_feat[0], affine=True)
87 | elif net_norm == 'layernorm':
88 | return nn.LayerNorm(shape_feat, elementwise_affine=True)
89 | elif net_norm == 'instancenorm':
90 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
91 | elif net_norm == 'groupnorm':
92 | return nn.GroupNorm(4, shape_feat[0], affine=True)
93 | elif net_norm == 'none':
94 | return None
95 | else:
96 | exit('unknown net_norm: %s'%net_norm)
97 |
98 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
99 | layers = []
100 | in_channels = channel
101 | if im_size[0] == 28:
102 | im_size = (32, 32)
103 | shape_feat = [in_channels, im_size[0], im_size[1]]
104 | for d in range(net_depth):
105 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
106 | shape_feat[0] = net_width
107 | if net_norm != 'none':
108 | layers += [self._get_normlayer(net_norm, shape_feat)]
109 | layers += [self._get_activation(net_act)]
110 | in_channels = net_width
111 | if net_pooling != 'none':
112 | layers += [self._get_pooling(net_pooling)]
113 | shape_feat[1] //= 2
114 | shape_feat[2] //= 2
115 |
116 | return nn.Sequential(*layers), shape_feat
117 |
118 |
119 |
120 | ''' LeNet '''
121 | class LeNet(nn.Module):
122 | def __init__(self, channel, num_classes):
123 | super(LeNet, self).__init__()
124 | self.features = nn.Sequential(
125 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0),
126 | nn.ReLU(inplace=True),
127 | nn.MaxPool2d(kernel_size=2, stride=2),
128 | nn.Conv2d(6, 16, kernel_size=5),
129 | nn.ReLU(inplace=True),
130 | nn.MaxPool2d(kernel_size=2, stride=2),
131 | )
132 | self.fc_1 = nn.Linear(16 * 5 * 5, 120)
133 | self.fc_2 = nn.Linear(120, 84)
134 | self.fc_3 = nn.Linear(84, num_classes)
135 |
136 | def forward(self, x):
137 | x = self.features(x)
138 | x = x.view(x.size(0), -1)
139 | x = F.relu(self.fc_1(x))
140 | x = F.relu(self.fc_2(x))
141 | x = self.fc_3(x)
142 | return x
143 |
144 |
145 |
146 | ''' AlexNet '''
147 | class AlexNet(nn.Module):
148 | def __init__(self, channel, num_classes):
149 | super(AlexNet, self).__init__()
150 | self.features = nn.Sequential(
151 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
152 | nn.ReLU(inplace=True),
153 | nn.MaxPool2d(kernel_size=2, stride=2),
154 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
155 | nn.ReLU(inplace=True),
156 | nn.MaxPool2d(kernel_size=2, stride=2),
157 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
158 | nn.ReLU(inplace=True),
159 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
160 | nn.ReLU(inplace=True),
161 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
162 | nn.ReLU(inplace=True),
163 | nn.MaxPool2d(kernel_size=2, stride=2),
164 | )
165 | self.fc = nn.Linear(192 * 4 * 4, num_classes)
166 |
167 | def forward(self, x):
168 | x = self.features(x)
169 | emb = x.view(x.size(0), -1)
170 | out = self.fc(emb)
171 | return emb, out
172 |
173 | def embed(self, x):
174 | x = self.features(x)
175 | x = x.view(x.size(0), -1)
176 | return x
177 |
178 |
179 | ''' AlexNetBN '''
180 | class AlexNetBN(nn.Module):
181 | def __init__(self, channel, num_classes):
182 | super(AlexNetBN, self).__init__()
183 | self.features = nn.Sequential(
184 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2),
185 | nn.BatchNorm2d(128),
186 | nn.ReLU(inplace=True),
187 | nn.MaxPool2d(kernel_size=2, stride=2),
188 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
189 | nn.BatchNorm2d(192),
190 | nn.ReLU(inplace=True),
191 | nn.MaxPool2d(kernel_size=2, stride=2),
192 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
193 | nn.BatchNorm2d(256),
194 | nn.ReLU(inplace=True),
195 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
196 | nn.BatchNorm2d(192),
197 | nn.ReLU(inplace=True),
198 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
199 | nn.BatchNorm2d(192),
200 | nn.ReLU(inplace=True),
201 | nn.MaxPool2d(kernel_size=2, stride=2),
202 | )
203 | self.fc = nn.Linear(192 * 4 * 4, num_classes)
204 |
205 | def forward(self, x):
206 | x = self.features(x)
207 | emb = x.view(x.size(0), -1)
208 | out = self.fc(emb)
209 | return emb, out
210 |
211 | def embed(self, x):
212 | x = self.features(x)
213 | x = x.view(x.size(0), -1)
214 | return x
215 |
216 |
217 | ''' VGG '''
218 | cfg_vgg = {
219 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
220 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
221 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
222 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
223 | }
224 | class VGG(nn.Module):
225 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'):
226 | super(VGG, self).__init__()
227 | self.channel = channel
228 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
229 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)
230 |
231 | def forward(self, x):
232 | x = self.features(x)
233 | emb = x.view(x.size(0), -1)
234 | out = self.classifier(emb)
235 | return emb, out
236 |
237 | def embed(self, x):
238 | x = self.features(x)
239 | x = x.view(x.size(0), -1)
240 | return x
241 |
242 | def _make_layers(self, cfg, norm):
243 | layers = []
244 | in_channels = self.channel
245 | for ic, x in enumerate(cfg):
246 | if x == 'M':
247 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
248 | else:
249 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1),
250 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
251 | nn.ReLU(inplace=True)]
252 | in_channels = x
253 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
254 | return nn.Sequential(*layers)
255 |
256 |
257 | def VGG11(channel, num_classes):
258 | return VGG('VGG11', channel, num_classes)
259 | def VGG11BN(channel, num_classes):
260 | return VGG('VGG11', channel, num_classes, norm='batchnorm')
261 | def VGG13(channel, num_classes):
262 | return VGG('VGG13', channel, num_classes)
263 | def VGG16(channel, num_classes):
264 | return VGG('VGG16', channel, num_classes)
265 | def VGG19(channel, num_classes):
266 | return VGG('VGG19', channel, num_classes)
267 |
268 |
269 | ''' ResNet_AP '''
270 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2)
271 |
272 | class BasicBlock_AP(nn.Module):
273 | expansion = 1
274 |
275 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
276 | super(BasicBlock_AP, self).__init__()
277 | self.norm = norm
278 | self.stride = stride
279 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
280 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
281 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
282 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
283 |
284 | self.shortcut = nn.Sequential()
285 | if stride != 1 or in_planes != self.expansion * planes:
286 | self.shortcut = nn.Sequential(
287 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
288 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
289 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
290 | )
291 |
292 | def forward(self, x):
293 | out = F.relu(self.bn1(self.conv1(x)))
294 | if self.stride != 1: # modification
295 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
296 | out = self.bn2(self.conv2(out))
297 | out += self.shortcut(x)
298 | out = F.relu(out)
299 | return out
300 |
301 |
302 | class Bottleneck_AP(nn.Module):
303 | expansion = 4
304 |
305 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
306 | super(Bottleneck_AP, self).__init__()
307 | self.norm = norm
308 | self.stride = stride
309 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
310 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
311 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification
312 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
313 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
314 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
315 |
316 | self.shortcut = nn.Sequential()
317 | if stride != 1 or in_planes != self.expansion * planes:
318 | self.shortcut = nn.Sequential(
319 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False),
320 | nn.AvgPool2d(kernel_size=2, stride=2), # modification
321 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes)
322 | )
323 |
324 | def forward(self, x):
325 | out = F.relu(self.bn1(self.conv1(x)))
326 | out = F.relu(self.bn2(self.conv2(out)))
327 | if self.stride != 1: # modification
328 | out = F.avg_pool2d(out, kernel_size=2, stride=2)
329 | out = self.bn3(self.conv3(out))
330 | out += self.shortcut(x)
331 | out = F.relu(out)
332 | return out
333 |
334 |
335 | class ResNet_AP(nn.Module):
336 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
337 | super(ResNet_AP, self).__init__()
338 | self.in_planes = 64
339 | self.norm = norm
340 |
341 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
342 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
343 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
344 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
345 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
346 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
347 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification
348 |
349 | def _make_layer(self, block, planes, num_blocks, stride):
350 | strides = [stride] + [1] * (num_blocks - 1)
351 | layers = []
352 | for stride in strides:
353 | layers.append(block(self.in_planes, planes, stride, self.norm))
354 | self.in_planes = planes * block.expansion
355 | return nn.Sequential(*layers)
356 |
357 | def forward(self, x):
358 | out = F.relu(self.bn1(self.conv1(x)))
359 | out = self.layer1(out)
360 | out = self.layer2(out)
361 | out = self.layer3(out)
362 | out = self.layer4(out)
363 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification
364 | out = out.view(out.size(0), -1)
365 | out = self.classifier(out)
366 | return out
367 |
368 | def embed(self, x):
369 | out = F.relu(self.bn1(self.conv1(x)))
370 | out = self.layer1(out)
371 | out = self.layer2(out)
372 | out = self.layer3(out)
373 | out = self.layer4(out)
374 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification
375 | out = out.view(out.size(0), -1)
376 | return out
377 |
378 | def ResNet18BN_AP(channel, num_classes):
379 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
380 |
381 | def ResNet18_AP(channel, num_classes):
382 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes)
383 |
384 |
385 | ''' ResNet '''
386 |
387 | class BasicBlock(nn.Module):
388 | expansion = 1
389 |
390 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
391 | super(BasicBlock, self).__init__()
392 | self.norm = norm
393 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
394 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
395 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
396 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
397 |
398 | self.shortcut = nn.Sequential()
399 | if stride != 1 or in_planes != self.expansion*planes:
400 | self.shortcut = nn.Sequential(
401 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
402 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
403 | )
404 |
405 | def forward(self, x):
406 | out = F.relu(self.bn1(self.conv1(x)))
407 | out = self.bn2(self.conv2(out))
408 | out += self.shortcut(x)
409 | out = F.relu(out)
410 | return out
411 |
412 |
413 | class Bottleneck(nn.Module):
414 | expansion = 4
415 |
416 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
417 | super(Bottleneck, self).__init__()
418 | self.norm = norm
419 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
420 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
421 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
422 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes)
423 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
424 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
425 |
426 | self.shortcut = nn.Sequential()
427 | if stride != 1 or in_planes != self.expansion*planes:
428 | self.shortcut = nn.Sequential(
429 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
430 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes)
431 | )
432 |
433 | def forward(self, x):
434 | out = F.relu(self.bn1(self.conv1(x)))
435 | out = F.relu(self.bn2(self.conv2(out)))
436 | out = self.bn3(self.conv3(out))
437 | out += self.shortcut(x)
438 | out = F.relu(out)
439 | return out
440 |
441 |
442 | class ResNet(nn.Module):
443 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
444 | super(ResNet, self).__init__()
445 | self.in_planes = 64
446 | self.norm = norm
447 |
448 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
449 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64)
450 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
451 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
452 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
453 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
454 | self.classifier = nn.Linear(512*block.expansion, num_classes)
455 |
456 | def _make_layer(self, block, planes, num_blocks, stride):
457 | strides = [stride] + [1]*(num_blocks-1)
458 | layers = []
459 | for stride in strides:
460 | layers.append(block(self.in_planes, planes, stride, self.norm))
461 | self.in_planes = planes * block.expansion
462 | return nn.Sequential(*layers)
463 |
464 | def forward(self, x):
465 | out = F.relu(self.bn1(self.conv1(x)))
466 | out = self.layer1(out)
467 | out = self.layer2(out)
468 | out = self.layer3(out)
469 | out = self.layer4(out)
470 | out = F.avg_pool2d(out, 4)
471 | emb = out.view(out.size(0), -1)
472 | out = self.classifier(emb)
473 | return emb, out
474 |
475 | def embed(self, x):
476 | out = F.relu(self.bn1(self.conv1(x)))
477 | out = self.layer1(out)
478 | out = self.layer2(out)
479 | out = self.layer3(out)
480 | out = self.layer4(out)
481 | out = F.avg_pool2d(out, 4)
482 | out = out.view(out.size(0), -1)
483 | return out
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 | def ResNet18BN(channel, num_classes):
492 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm')
493 |
494 | def ResNet18(channel, num_classes):
495 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes)
496 |
497 | def ResNet34(channel, num_classes):
498 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes)
499 |
500 | def ResNet50(channel, num_classes):
501 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes)
502 |
503 | def ResNet101(channel, num_classes):
504 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes)
505 |
506 | def ResNet152(channel, num_classes):
507 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes)
508 |
509 |
510 |
511 |
512 | '''ViT Model '''
513 | def pair(t):
514 | return t if isinstance(t, tuple) else (t, t)
515 |
516 | # classes
517 | class PreNorm(nn.Module):
518 | def __init__(self, dim, fn):
519 | super().__init__()
520 | self.norm = nn.LayerNorm(dim)
521 | self.fn = fn
522 | def forward(self, x, **kwargs):
523 | return self.fn(self.norm(x), **kwargs)
524 |
525 | class FeedForward(nn.Module):
526 | def __init__(self, dim, hidden_dim, dropout = 0.):
527 | super().__init__()
528 | self.net = nn.Sequential(
529 | nn.Linear(dim, hidden_dim),
530 | nn.GELU(),
531 | nn.Dropout(dropout),
532 | nn.Linear(hidden_dim, dim),
533 | nn.Dropout(dropout)
534 | )
535 | def forward(self, x):
536 | return self.net(x)
537 |
538 | class Attention(nn.Module):
539 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
540 | super().__init__()
541 | inner_dim = dim_head * heads
542 | project_out = not (heads == 1 and dim_head == dim)
543 |
544 | self.heads = heads
545 | self.scale = dim_head ** -0.5
546 |
547 | self.attend = nn.Softmax(dim = -1)
548 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
549 |
550 | self.to_out = nn.Sequential(
551 | nn.Linear(inner_dim, dim),
552 | nn.Dropout(dropout)
553 | ) if project_out else nn.Identity()
554 |
555 | def forward(self, x):
556 | qkv = self.to_qkv(x).chunk(3, dim = -1)
557 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
558 |
559 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
560 |
561 | attn = self.attend(dots)
562 |
563 | out = torch.matmul(attn, v)
564 | out = rearrange(out, 'b h n d -> b n (h d)')
565 | return self.to_out(out)
566 |
567 | class Transformer(nn.Module):
568 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
569 | super().__init__()
570 | self.layers = nn.ModuleList([])
571 | for _ in range(depth):
572 | self.layers.append(nn.ModuleList([
573 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
574 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
575 | ]))
576 | def forward(self, x):
577 | for attn, ff in self.layers:
578 | x = attn(x) + x
579 | x = ff(x) + x
580 | return x
581 |
582 | class ViT(nn.Module):
583 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
584 | super().__init__()
585 | image_height, image_width = pair(image_size)
586 | patch_height, patch_width = pair(patch_size)
587 |
588 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
589 |
590 | num_patches = (image_height // patch_height) * (image_width // patch_width)
591 | patch_dim = channels * patch_height * patch_width
592 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
593 |
594 | self.to_patch_embedding = nn.Sequential(
595 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
596 | nn.Linear(patch_dim, dim),
597 | )
598 |
599 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
600 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
601 | self.dropout = nn.Dropout(emb_dropout)
602 |
603 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
604 |
605 | self.pool = pool
606 | self.to_latent = nn.Identity()
607 |
608 | self.mlp_head = nn.Sequential(
609 | nn.LayerNorm(dim),
610 | nn.Linear(dim, num_classes)
611 | )
612 |
613 | def forward(self, img):
614 | x = self.to_patch_embedding(img)
615 | b, n, _ = x.shape
616 |
617 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
618 | x = torch.cat((cls_tokens, x), dim=1)
619 | x += self.pos_embedding[:, :(n + 1)]
620 | x = self.dropout(x)
621 |
622 | x = self.transformer(x)
623 |
624 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
625 |
626 | x = self.to_latent(x)
627 | return x, self.mlp_head(x)
628 |
629 | def ViTModel(im_size, num_classes):
630 | return ViT(
631 | image_size = im_size,
632 | patch_size = 4,
633 | num_classes = num_classes,
634 | dim = 512,
635 | depth = 6,
636 | heads = 8,
637 | mlp_dim = 512,
638 | dropout = 0.1,
639 | emb_dropout = 0.1)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | aiohttp==3.8.3
3 | aiosignal==1.3.1
4 | astunparse==1.6.3
5 | async-timeout==4.0.2
6 | asynctest==0.13.0
7 | attrs==22.2.0
8 | blessed==1.20.0
9 | cachetools==5.3.0
10 | certifi==2022.12.7
11 | charset-normalizer==2.1.1
12 | cycler==0.11.0
13 | einops==0.6.0
14 | flatbuffers==23.1.21
15 | fonttools==4.38.0
16 | frozenlist==1.3.3
17 | fsspec==2023.1.0
18 | gast==0.4.0
19 | google-auth==2.16.1
20 | google-auth-oauthlib==0.4.6
21 | google-pasta==0.2.0
22 | gpustat==1.0.0
23 | grpcio==1.51.3
24 | h5py==3.8.0
25 | idna==3.4
26 | importlib-metadata==6.0.0
27 | joblib==1.2.0
28 | keras==2.11.0
29 | kiwisolver==1.4.4
30 | kornia==0.6.9
31 | libclang==15.0.6.1
32 | lightning-utilities==0.6.0.post0
33 | Markdown==3.4.1
34 | MarkupSafe==2.1.2
35 | matplotlib==3.5.3
36 | multidict==6.0.4
37 | nas-bench-201==2.1
38 | numpy==1.21.6
39 | nvidia-cublas-cu11==11.10.3.66
40 | nvidia-cuda-nvrtc-cu11==11.7.99
41 | nvidia-cuda-runtime-cu11==11.7.99
42 | nvidia-cudnn-cu11==8.5.0.96
43 | nvidia-ml-py==11.495.46
44 | oauthlib==3.2.2
45 | opt-einsum==3.3.0
46 | packaging==23.0
47 | pandas==1.3.5
48 | Pillow==9.4.0
50 | protobuf==3.19.6
51 | psutil==5.9.4
52 | pyasn1==0.4.8
53 | pyasn1-modules==0.2.8
54 | pyparsing==3.0.9
55 | python-dateutil==2.8.2
56 | pytorch-lightning==1.9.0
57 | pytz==2022.7.1
58 | PyYAML==6.0
59 | requests==2.28.2
60 | requests-oauthlib==1.3.1
61 | rsa==4.9
62 | scikit-learn==1.0.2
63 | scipy==1.1.0
64 | seaborn==0.12.2
65 | six==1.16.0
66 | tensorboard==2.11.2
67 | tensorboard-data-server==0.6.1
68 | tensorboard-plugin-wit==1.8.1
69 | tensorflow==2.11.0
70 | tensorflow-estimator==2.11.0
71 | tensorflow-io-gcs-filesystem==0.30.0
72 | termcolor==2.2.0
73 | threadpoolctl==3.1.0
74 | torch==1.13.1
75 | torchmetrics==0.11.1
76 | torchvision==0.14.1
77 | tqdm==4.64.1
78 | typing_extensions==4.4.0
79 | urllib3==1.26.14
80 | wcwidth==0.2.6
81 | Werkzeug==2.2.3
82 | wrapt==1.14.1
83 | yarl==1.8.2
84 | zipp==3.11.0
85 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 | from torchvision import datasets, transforms
9 | from scipy.ndimage.interpolation import rotate as scipyrotate
10 | from networks import MLP, ConvNet, LeNet, AlexNet, AlexNetBN, VGG11, VGG11BN, ResNet18, ResNet18BN_AP, ResNet18BN, ViTModel
11 | import tqdm
12 | import kornia as K
13 | from copy import deepcopy
14 | from torchvision.utils import save_image
15 |
16 | # Attention Module
17 | def get_attention(feature_set, param=0, exp=4, norm='l2'):
18 | if param==0:
19 | attention_map = torch.sum(torch.abs(feature_set), dim=1)
20 |
21 | elif param ==1:
22 | attention_map = torch.sum(torch.abs(feature_set)**exp, dim=1)
23 |
24 | elif param == 2:
25 |         attention_map = torch.max(torch.abs(feature_set)**exp, dim=1)[0]  # torch.max over a dim returns (values, indices); keep the values
26 |
27 | if norm == 'l2':
28 | # Dimension: [B x (H*W)] -- Vectorized
29 | vectorized_attention_map = attention_map.view(feature_set.size(0), -1)
30 | normalized_attention_maps = F.normalize(vectorized_attention_map, p=2.0)
31 |
32 | elif norm == 'fro':
33 | # Dimension: [B x H x W] -- Un-Vectorized
34 | un_vectorized_attention_map = attention_map
35 | # Dimension: [B]
36 |         fro_norm = torch.sum(torch.sum(torch.abs(attention_map)**2, dim=1), dim=1)**0.5  # sqrt for the true Frobenius norm per sample
37 | # Dimension: [B x H x W] -- Un-Vectorized)
38 | normalized_attention_maps = un_vectorized_attention_map / fro_norm.unsqueeze(dim=-1).unsqueeze(dim=-1)
39 | elif norm == 'l1':
40 | # Dimension: [B x (H*W)] -- Vectorized
41 | vectorized_attention_map = attention_map.view(feature_set.size(0), -1)
42 | normalized_attention_maps = F.normalize(vectorized_attention_map, p=1.0)
43 |
44 | elif norm =='none':
45 | normalized_attention_maps = attention_map
46 |
47 | elif norm == 'none-vectorized':
48 | normalized_attention_maps = attention_map.view(feature_set.size(0), -1)
49 |
50 | return normalized_attention_maps
51 |
52 |
53 |
54 |
55 | def get_dataset(dataset, data_path, args):
56 | if dataset == 'MNIST':
57 | channel = 1
58 | im_size = (28, 28)
59 | num_classes = 10
60 | mean = [0.1307]
61 | std = [0.3081]
62 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
63 | dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
64 | dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
65 | class_names = [str(c) for c in range(num_classes)]
66 |
67 | elif dataset == 'CIFAR10':
68 | channel = 3
69 | im_size = (32, 32)
70 | num_classes = 10
71 | mean = [0.4914, 0.4822, 0.4465]
72 | std = [0.2023, 0.1994, 0.2010]
73 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
74 | dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
75 | dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
76 | class_names = dst_train.classes
77 |
78 | elif dataset == 'CIFAR100':
79 | channel = 3
80 | im_size = (32, 32)
81 | num_classes = 100
82 | mean = [0.5071, 0.4866, 0.4409]
83 | std = [0.2673, 0.2564, 0.2762]
84 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
85 | dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation
86 | dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform)
87 | class_names = dst_train.classes
88 |
89 | elif dataset == 'TinyImageNet':
90 | channel = 3
91 | im_size = (64, 64)
92 | num_classes = 200
93 | mean = [0.485, 0.456, 0.406]
94 | std = [0.229, 0.224, 0.225]
95 | data = torch.load(os.path.join(data_path, 'tinyimagenet.pt'), map_location='cpu')
96 |
97 | class_names = data['classes']
98 |
99 | images_train = data['images_train']
100 | labels_train = data['labels_train']
101 | images_train = images_train.detach().float() / 255.0
102 | labels_train = labels_train.detach()
103 | for c in range(channel):
104 | images_train[:,c] = (images_train[:,c].clone() - mean[c])/std[c]
105 | dst_train = TensorDataset(images_train, labels_train) # no augmentation
106 |
107 | images_val = data['images_val']
108 | labels_val = data['labels_val']
109 | images_val = images_val.detach().float() / 255.0
110 | labels_val = labels_val.detach()
111 |
112 | for c in range(channel):
113 | images_val[:, c] = (images_val[:, c].clone() - mean[c]) / std[c]
114 |
115 | dst_test = TensorDataset(images_val, labels_val) # no augmentation
116 |
117 | elif dataset == 'ImageNette':
118 | channel = 3
119 | im_size = (128, 128)
120 | num_classes = 10
121 |
122 | class_names = ["Tench", "English Springer", "Cassette Player", "Chainsaw", "Church", "French Horn", "Garbage Truck", "Gas Pump","Golf Ball", "Parachute"]
123 |
124 | mean = [0.485, 0.456, 0.406]
125 | std = [0.229, 0.224, 0.225]
126 |
127 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
128 | transforms.Resize(im_size),
129 | transforms.CenterCrop(im_size)])
130 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation
131 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform)
132 |
133 | elif dataset == 'ImageWoof':
134 | channel = 3
135 | im_size = (128, 128)
136 | num_classes = 10
137 |
138 | class_names = ["Australian Terrier", "Border Terrier", "Samoyed", "Beagle", "Shih-Tzu" ,"English Foxhound", "Rhodesian Ridgeback", "Dingo", "Golden Retriever", "English Sheepdog"]
139 |
140 | mean = [0.485, 0.456, 0.406]
141 | std = [0.229, 0.224, 0.225]
142 |
143 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
144 | transforms.Resize(im_size),
145 | transforms.CenterCrop(im_size)])
146 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation
147 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform)
148 |
149 |
150 | elif dataset == 'ImageSquack':
151 | channel = 3
152 | im_size = (128, 128)
153 | num_classes = 10
154 |
155 | class_names = ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"]
156 |
157 | mean = [0.485, 0.456, 0.406]
158 | std = [0.229, 0.224, 0.225]
159 |
160 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
161 | transforms.Resize(im_size),
162 | transforms.CenterCrop(im_size)])
163 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation
164 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform)
165 |
166 | elif dataset == 'ImageFruit':
167 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948]
168 | imagefruitLabels = {j:i for i,j in enumerate(imagefruit)}
169 | class_names = ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"]
170 | channel = 3
171 | im_size = (128, 128)
172 | num_classes = 10
173 |
174 | mean = [0.485, 0.456, 0.406]
175 | std = [0.229, 0.224, 0.225]
176 | if args.zca:
177 | transform = transforms.Compose([transforms.ToTensor(),
178 | transforms.Resize(im_size),
179 | transforms.CenterCrop(im_size)])
180 | else:
181 | transform = transforms.Compose([transforms.ToTensor(),
182 | transforms.Normalize(mean=mean, std=std),
183 | transforms.Resize(im_size),
184 | transforms.CenterCrop(im_size)])
185 |
186 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform)
187 |         selected_dataset = []
188 |         for idx, (image, label) in enumerate(dst_train):
189 |             if label in imagefruit:
190 |                 selected_dataset.append((image, imagefruitLabels[label]))
191 |         # Create a new dataset using the selected classes (labels already remapped to 0..9)
192 |         dst_train = torch.utils.data.Subset(selected_dataset, torch.arange(len(selected_dataset)))
193 |
194 |         # Remap the ten ImageNet fruit classes to labels 0..9. ImageFolder only
195 |         # consults target_transform in __getitem__, so remap there rather than
196 |         # rewriting the .targets list (which __getitem__ never reads).
197 |         dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform, target_transform=lambda t: imagefruitLabels.get(t, t))
198 |         dst_test = torch.utils.data.Subset(dst_test, np.squeeze(np.argwhere(np.isin(dst_test.targets, imagefruit))))
199 |         # dst_train labels were already remapped when selected_dataset was built,
200 |         # and class_names keeps the fruit names assigned above.
201 |
202 |         print('ImageFruit: %d train / %d test images'%(len(dst_train), len(dst_test)))
203 |
204 |         class_map = {x: i for i, x in enumerate(imagefruit)}
205 |         class_map_inv = {i: x for i, x in enumerate(imagefruit)}
206 |
207 |
208 | else:
209 | exit('unknown dataset: %s'%dataset)
210 |     zca = None
211 | if args.zca:
212 | images = []
213 | labels = []
214 | print("Train ZCA")
215 | for i in tqdm.tqdm(range(len(dst_train))):
216 | im, lab = dst_train[i]
217 | images.append(im)
218 | labels.append(lab)
219 | images = torch.stack(images, dim=0).to(args.device)
220 | labels = torch.tensor(labels, dtype=torch.long, device="cpu")
221 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
222 | zca.fit(images)
223 | zca_images = zca(images).to("cpu")
224 | dst_train = TensorDataset(zca_images, labels)
225 |
226 | images = []
227 | labels = []
228 | print("Test ZCA")
229 | for i in tqdm.tqdm(range(len(dst_test))):
230 | im, lab = dst_test[i]
231 | images.append(im)
232 | labels.append(lab)
233 | images = torch.stack(images, dim=0).to(args.device)
234 | labels = torch.tensor(labels, dtype=torch.long, device="cpu")
235 |
236 | zca_images = zca(images).to("cpu")
237 | dst_test = TensorDataset(zca_images, labels)
238 |
239 | args.zca_trans = zca
240 |
241 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)
242 | return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca
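# Example call (illustrative; assumes an args namespace carrying .zca and
# .device as used above):
#   channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca = \
#       get_dataset('CIFAR10', 'data', args)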
243 |
244 |
245 |
246 | class TensorDataset(Dataset):
247 | def __init__(self, images, labels): # images: n x c x h x w tensor
248 | self.images = images.detach().float()
249 | self.labels = labels.detach()
250 |
251 | def __getitem__(self, index):
252 | return self.images[index], self.labels[index]
253 |
254 | def __len__(self):
255 | return self.images.shape[0]
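# Usage sketch (illustrative): wrap pre-loaded tensors so a DataLoader can
# batch them directly.
#   ds = TensorDataset(torch.rand(100, 3, 32, 32), torch.randint(0, 10, (100,)))
#   loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)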
256 |
257 |
258 |
259 | def get_default_convnet_setting():
260 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
261 | return net_width, net_depth, net_act, net_norm, net_pooling
262 |
263 |
264 |
265 | def get_network(model, channel, num_classes, im_size=(32, 32)):
266 | torch.random.manual_seed(int(time.time() * 1000) % 100000)
267 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()
268 |
269 | if model == 'MLP':
270 | net = MLP(channel=channel, num_classes=num_classes)
271 | elif model == 'ConvNet':
272 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)
273 |
274 |     elif model == 'ConvNet128IN': # Higher-resolution variant: deeper, with instance norm as the 'IN' suffix indicates
275 |         net_depth = 6
276 |         net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling, im_size=im_size)
277 |
278 | elif model == 'LeNet':
279 | net = LeNet(channel=channel, num_classes=num_classes)
280 | elif model == 'AlexNet':
281 | net = AlexNet(channel=channel, num_classes=num_classes)
282 | elif model == 'AlexNetBN':
283 | net = AlexNetBN(channel=channel, num_classes=num_classes)
284 | elif model == 'VGG11':
285 | net = VGG11( channel=channel, num_classes=num_classes)
286 | elif model == 'VGG11BN':
287 | net = VGG11BN(channel=channel, num_classes=num_classes)
288 | elif model == 'ResNet18':
289 | net = ResNet18(channel=channel, num_classes=num_classes)
290 | elif model == 'ResNet18BN_AP':
291 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes)
292 | elif model == 'ResNet18BN':
293 | net = ResNet18BN(channel=channel, num_classes=num_classes)
294 | elif model == 'ViT':
295 | net = ViTModel(im_size, num_classes)
296 | print("ViT Model")
297 | elif model == 'ConvNetD1':
298 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
299 | elif model == 'ConvNetD2':
300 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
301 | elif model == 'ConvNetD3':
302 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
303 | elif model == 'ConvNetD4':
304 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
305 |
306 | elif model == 'ConvNetW32':
307 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
308 | elif model == 'ConvNetW64':
309 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
310 | elif model == 'ConvNetW128':
311 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
312 | elif model == 'ConvNetW256':
313 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
314 |
315 | elif model == 'ConvNetAS':
316 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
317 | elif model == 'ConvNetAR':
318 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
319 | elif model == 'ConvNetAL':
320 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
321 | elif model == 'ConvNetASwish':
322 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
323 | elif model == 'ConvNetASwishBN':
324 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)
325 |
326 | elif model == 'ConvNetNN':
327 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling, im_size=im_size)
328 | elif model == 'ConvNetBN':
329 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)
330 | elif model == 'ConvNetBNImageNet':
331 | net_depth=4
332 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size)
333 | elif model == 'ConvNetLN':
334 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling, im_size=im_size)
335 | elif model == 'ConvNetIN':
336 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling, im_size=im_size)
337 | elif model == 'ConvNetGN':
338 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling, im_size=im_size)
339 |
340 | elif model == 'ConvNetNP':
341 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none', im_size=im_size)
342 | elif model == 'ConvNetMP':
343 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling', im_size=im_size)
344 | elif model == 'ConvNetAP':
345 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling', im_size=im_size)
346 |
347 | else:
348 | net = None
349 | exit('unknown model: %s'%model)
350 |
351 | gpu_num = torch.cuda.device_count()
352 |
353 | if gpu_num > 1:
354 | net = nn.DataParallel(net)
355 |     net = net.to('cuda' if torch.cuda.is_available() else 'cpu')  # avoid a .cuda() crash on CPU-only machines
356 |
357 | return net
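# Example (illustrative): build the default 3-layer, 128-channel ConvNet used
# for CIFAR-scale distillation.
#   net = get_network('ConvNet', channel=3, num_classes=10, im_size=(32, 32))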
358 |
359 |
360 |
361 | def get_time():
362 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
363 |
364 | def get_loops(ipc):
365 | # Get the two hyper-parameters of outer-loop and inner-loop.
366 | # The following values are empirically good.
367 | if ipc == 1:
368 | outer_loop, inner_loop = 1, 1
369 | elif ipc == 10:
370 | outer_loop, inner_loop = 10, 50
371 | elif ipc == 20:
372 | outer_loop, inner_loop = 20, 25
373 | elif ipc == 30:
374 | outer_loop, inner_loop = 30, 20
375 | elif ipc == 40:
376 | outer_loop, inner_loop = 40, 15
377 | elif ipc == 50:
378 | outer_loop, inner_loop = 50, 10
379 | else:
380 | outer_loop, inner_loop = 0, 0
381 | exit('loop hyper-parameters are not defined for %d ipc'%ipc)
382 | return outer_loop, inner_loop
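# Example: get_loops(10) -> (10, 50), i.e. (roughly) ten synthetic-set updates
# per iteration, with fifty network-training steps interleaved between them.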
383 |
384 |
385 |
386 | def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
387 | loss_avg, acc_avg, num_exp = 0, 0, 0
388 | net = net.to(args.device)
389 | criterion = criterion.to(args.device)
390 |
391 | if mode == 'train':
392 | net.train()
393 | else:
394 | net.eval()
395 |
396 | for i_batch, datum in enumerate(dataloader):
397 | img = datum[0].float().to(args.device)
398 | if aug:
399 | if args.dsa:
400 | img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
401 | lab = datum[1].long().to(args.device)
402 | n_b = lab.shape[0]
403 |
404 |         output = net(img)[1]  # the network returns a tuple; element 1 holds the class logits
405 | loss = criterion(output, lab)
406 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy()))
407 |
408 | loss_avg += loss.item()*n_b
409 | acc_avg += acc
410 | num_exp += n_b
411 |
412 | if mode == 'train':
413 | optimizer.zero_grad()
414 | loss.backward()
415 | optimizer.step()
416 | loss_avg /= num_exp
417 | acc_avg /= num_exp
418 |
419 | return loss_avg, acc_avg
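# One call runs a single pass over the loader, e.g. (illustrative):
#   train_loss, train_acc = epoch('train', trainloader, net, optimizer, criterion, args, aug=True)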
420 |
421 |
422 |
423 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, skip=False):
424 | net = net.to(args.device)
425 |
426 | images_train = images_train.to(args.device)
427 | labels_train = labels_train.to(args.device)
428 | lr = float(args.lr_net)
429 | Epoch = int(args.epoch_eval_train)
430 | lr_schedule = [Epoch//2+1]
431 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
432 | criterion = nn.CrossEntropyLoss().to(args.device)
433 |
434 | dst_train = TensorDataset(images_train, labels_train)
435 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
436 |
437 | start = time.time()
438 | acc_test = 0
439 | loss_train = 0
440 | time_train = 0
441 | acc_train = 0
442 | if not skip:
443 | for ep in tqdm.tqdm(range(Epoch+1)):
444 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = True)
445 | if ep in lr_schedule:
446 | lr *= 0.1
447 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
448 | time_train = time.time() - start
449 |
450 | loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False)
451 | print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))
452 |
453 | return net, acc_train, acc_test
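# Typical evaluation loop (illustrative; image_syn / label_syn stand in for the
# learned synthetic set): train a freshly initialized network on the distilled
# images and report its accuracy on the real test set.
#   net_eval = get_network(args.model, channel, num_classes, im_size)
#   _, acc_train, acc_test = evaluate_synset(0, net_eval, image_syn, label_syn, testloader, args)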
454 |
455 |
456 |
457 | def augment(images, dc_aug_param, device):
458 | # This can be sped up in the future.
459 | print("In here, no dsa lol", dc_aug_param)
460 |
461 |
462 |
463 |     if dc_aug_param is not None and dc_aug_param['strategy'] != 'none':
464 | scale = dc_aug_param['scale']
465 | crop = dc_aug_param['crop']
466 | rotate = dc_aug_param['rotate']
467 | noise = dc_aug_param['noise']
468 | strategy = dc_aug_param['strategy']
469 |
470 | shape = images.shape
471 | mean = []
472 | for c in range(shape[1]):
473 | mean.append(float(torch.mean(images[:,c])))
474 |
475 | def cropfun(i):
476 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device)
477 | for c in range(shape[1]):
478 | im_[c] = mean[c]
479 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i]
480 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0]
481 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]]
482 |
483 | def scalefun(i):
484 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
485 |             w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])  # width scales with shape[3]
486 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0]
487 | mhw = max(h, w, shape[2], shape[3])
488 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
489 | r = int((mhw - h) / 2)
490 | c = int((mhw - w) / 2)
491 | im_[:, r:r + h, c:c + w] = tmp
492 | r = int((mhw - shape[2]) / 2)
493 | c = int((mhw - shape[3]) / 2)
494 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]]
495 |
496 | def rotatefun(i):
497 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean))
498 | r = int((im_.shape[-2] - shape[-2]) / 2)
499 | c = int((im_.shape[-1] - shape[-1]) / 2)
500 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device)
501 |
502 | def noisefun(i):
503 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)
504 |
505 |
506 | augs = strategy.split('_')
507 |
508 | for i in range(shape[0]):
509 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation
510 | if choice == 'crop':
511 | cropfun(i)
512 | elif choice == 'scale':
513 | scalefun(i)
514 | elif choice == 'rotate':
515 | rotatefun(i)
516 | elif choice == 'noise':
517 | noisefun(i)
518 |
519 | return images
520 |
521 |
522 |
523 | def get_daparam(dataset, model, model_eval, ipc):
524 |     # We find that augmentation doesn't always benefit performance,
525 |     # so we enable it only for some of the settings.
526 |
527 | dc_aug_param = dict()
528 | dc_aug_param['crop'] = 4
529 | dc_aug_param['scale'] = 0.2
530 | dc_aug_param['rotate'] = 45
531 | dc_aug_param['noise'] = 0.001
532 | dc_aug_param['strategy'] = 'none'
533 |
534 | if dataset == 'MNIST':
535 | dc_aug_param['strategy'] = 'crop_scale_rotate'
536 |
537 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier.
538 | dc_aug_param['strategy'] = 'crop_noise'
539 |
540 | return dc_aug_param
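# For example (illustrative), get_daparam('MNIST', 'ConvNet', 'ConvNet', 10)
# returns the dict above with strategy 'crop_scale_rotate'.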
541 |
542 |
543 | def get_eval_pool(eval_mode, model, model_eval):
544 | if eval_mode == 'M': # multiple architectures
545 | model_eval_pool = ['MLP', 'ConvNet', 'LeNet', 'AlexNet', 'VGG11', 'ResNet18']
546 | elif eval_mode == 'B': # multiple architectures with BatchNorm for DM experiments
547 | model_eval_pool = ['ConvNetBN', 'ConvNetASwishBN', 'AlexNetBN', 'VGG11BN', 'ResNet18BN']
548 | elif eval_mode == 'W': # ablation study on network width
549 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256']
550 | elif eval_mode == 'D': # ablation study on network depth
551 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4']
552 | elif eval_mode == 'A': # ablation study on network activation function
553 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL', 'ConvNetASwish']
554 | elif eval_mode == 'P': # ablation study on network pooling layer
555 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP']
556 | elif eval_mode == 'N': # ablation study on network normalization layer
557 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN']
558 | elif eval_mode == 'S': # itself
559 | if 'BN' in model:
560 |             print('Attention: replacing BN with IN for evaluation, as the synthetic set is too small to estimate reliable BN statistics.')
561 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model]
562 | elif eval_mode == 'SS': # itself
563 | model_eval_pool = [model]
564 | else:
565 | model_eval_pool = [model_eval]
566 | return model_eval_pool
567 |
568 |
569 | class ParamDiffAug():
570 | def __init__(self):
571 |         self.aug_mode = 'S'  # 'S': apply one randomly chosen strategy per batch; 'M': apply all listed strategies
572 | self.prob_flip = 0.5
573 | self.ratio_scale = 1.2
574 | self.ratio_rotate = 15.0
575 | self.ratio_crop_pad = 0.125
576 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5
577 | self.brightness = 1.0
578 | self.saturation = 2.0
579 | self.contrast = 0.5
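# Note: DiffAugment() below attaches two more attributes at call time,
# param.Siamese and param.latestseed, so a ParamDiffAug instance should be
# routed through DiffAugment rather than passed to the rand_* functions directly.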
580 |
581 |
582 | def set_seed_DiffAug(param):
583 | if param.latestseed == -1:
584 | return
585 | else:
586 | torch.random.manual_seed(param.latestseed)
587 | param.latestseed += 1
588 |
589 |
590 | def DiffAugment(x, strategy='', seed=-1, param=None):
591 | if strategy == 'None' or strategy == 'none' or strategy == '':
592 | return x
593 |
594 | if seed == -1:
595 | param.Siamese = False
596 | else:
597 | param.Siamese = True
598 |
599 | param.latestseed = seed
600 |
601 | if strategy:
602 | if param.aug_mode == 'M': # original
603 | for p in strategy.split('_'):
604 | for f in AUGMENT_FNS[p]:
605 | x = f(x, param)
606 | elif param.aug_mode == 'S':
607 | pbties = strategy.split('_')
608 | set_seed_DiffAug(param)
609 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
610 | for f in AUGMENT_FNS[p]:
611 | x = f(x, param)
612 | else:
613 | exit('unknown augmentation mode: %s'%param.aug_mode)
614 | x = x.contiguous()
615 | return x
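# Illustrative call: with a shared seed, the real and synthetic batches receive
# identical transforms (the 'Siamese' mode used during matching).
#   dsa_param = ParamDiffAug()
#   seed = int(time.time() * 1000) % 100000
#   img_real_aug = DiffAugment(img_real, 'color_crop_cutout', seed=seed, param=dsa_param)
#   img_syn_aug  = DiffAugment(img_syn,  'color_crop_cutout', seed=seed, param=dsa_param)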
616 |
617 |
618 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
619 | def rand_scale(x, param):
620 |     # ratio (> 1) bounds the maximum scale factor
621 |     # sx, sy: (0, +oo); 1: original size, 0.5: enlarge 2x
622 | ratio = param.ratio_scale
623 | set_seed_DiffAug(param)
624 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
625 | set_seed_DiffAug(param)
626 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
627 | theta = [[[sx[i], 0, 0],
628 | [0, sy[i], 0],] for i in range(x.shape[0])]
629 | theta = torch.tensor(theta, dtype=torch.float)
630 | if param.Siamese: # Siamese augmentation:
631 | theta[:] = theta[0].clone()
632 | grid = F.affine_grid(theta, x.shape).to(x.device)
633 | x = F.grid_sample(x, grid)
634 | return x
635 |
636 |
637 | def rand_rotate(x, param): # ratio in [-180, 180] degrees; positive angles rotate anticlockwise
638 | ratio = param.ratio_rotate
639 | set_seed_DiffAug(param)
640 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
641 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
642 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])]
643 | theta = torch.tensor(theta, dtype=torch.float)
644 | if param.Siamese: # Siamese augmentation:
645 | theta[:] = theta[0].clone()
646 | grid = F.affine_grid(theta, x.shape).to(x.device)
647 | x = F.grid_sample(x, grid)
648 | return x
649 |
650 |
651 | def rand_flip(x, param):
652 | prob = param.prob_flip
653 | set_seed_DiffAug(param)
654 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
655 | if param.Siamese: # Siamese augmentation:
656 | randf[:] = randf[0].clone()
657 | return torch.where(randf < prob, x.flip(3), x)
658 |
659 |
660 | def rand_brightness(x, param):
661 | ratio = param.brightness
662 | set_seed_DiffAug(param)
663 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
664 | if param.Siamese: # Siamese augmentation:
665 | randb[:] = randb[0].clone()
666 | x = x + (randb - 0.5)*ratio
667 | return x
668 |
669 |
670 | def rand_saturation(x, param):
671 | ratio = param.saturation
672 | x_mean = x.mean(dim=1, keepdim=True)
673 | set_seed_DiffAug(param)
674 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
675 | if param.Siamese: # Siamese augmentation:
676 | rands[:] = rands[0].clone()
677 | x = (x - x_mean) * (rands * ratio) + x_mean
678 | return x
679 |
680 |
681 | def rand_contrast(x, param):
682 | ratio = param.contrast
683 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
684 | set_seed_DiffAug(param)
685 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
686 | if param.Siamese: # Siamese augmentation:
687 | randc[:] = randc[0].clone()
688 | x = (x - x_mean) * (randc + ratio) + x_mean
689 | return x
690 |
691 |
692 | def rand_crop(x, param):
693 |     # The image is padded around its border and then randomly shifted, i.e. a pad-and-crop translation.
694 | ratio = param.ratio_crop_pad
695 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
696 | set_seed_DiffAug(param)
697 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
698 | set_seed_DiffAug(param)
699 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
700 | if param.Siamese: # Siamese augmentation:
701 | translation_x[:] = translation_x[0].clone()
702 | translation_y[:] = translation_y[0].clone()
703 | grid_batch, grid_x, grid_y = torch.meshgrid(
704 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
705 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
706 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
707 | )
708 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
709 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
710 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
711 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
712 | return x
713 |
714 |
715 | def rand_cutout(x, param):
716 | ratio = param.ratio_cutout
717 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
718 | set_seed_DiffAug(param)
719 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
720 | set_seed_DiffAug(param)
721 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
722 | if param.Siamese: # Siamese augmentation:
723 | offset_x[:] = offset_x[0].clone()
724 | offset_y[:] = offset_y[0].clone()
725 | grid_batch, grid_x, grid_y = torch.meshgrid(
726 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
727 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
728 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
729 | )
730 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
731 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
732 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
733 | mask[grid_batch, grid_x, grid_y] = 0
734 | x = x * mask.unsqueeze(1)
735 | return x
736 |
737 |
738 | AUGMENT_FNS = {
739 | 'color': [rand_brightness, rand_saturation, rand_contrast],
740 | 'crop': [rand_crop],
741 | 'cutout': [rand_cutout],
742 | 'flip': [rand_flip],
743 | 'scale': [rand_scale],
744 | 'rotate': [rand_rotate],
745 | }
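# A strategy string such as 'color_crop_cutout' is split on '_' and each token
# indexes this table; e.g. 'color' applies rand_brightness, rand_saturation and
# rand_contrast in sequence.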
746 |
747 |
--------------------------------------------------------------------------------