├── Initial_Synthetic_Dataset ├── CIFAR10_IPC10_images.pt ├── CIFAR10_IPC10_labels.pt ├── CIFAR10_IPC1_images.pt ├── CIFAR10_IPC1_labels.pt ├── CIFAR10_IPC50_images.pt └── CIFAR10_IPC50_labels.pt ├── README.md ├── distill_test_model.py ├── google8905e38a0c973ed3.html ├── img ├── DataDAM_pipeline.png └── HPTable.png ├── main_DataDAM.py ├── networks.py ├── requirements.txt └── utils.py /Initial_Synthetic_Dataset/CIFAR10_IPC10_images.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC10_images.pt -------------------------------------------------------------------------------- /Initial_Synthetic_Dataset/CIFAR10_IPC10_labels.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC10_labels.pt -------------------------------------------------------------------------------- /Initial_Synthetic_Dataset/CIFAR10_IPC1_images.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC1_images.pt -------------------------------------------------------------------------------- /Initial_Synthetic_Dataset/CIFAR10_IPC1_labels.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC1_labels.pt -------------------------------------------------------------------------------- /Initial_Synthetic_Dataset/CIFAR10_IPC50_images.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC50_images.pt -------------------------------------------------------------------------------- /Initial_Synthetic_Dataset/CIFAR10_IPC50_labels.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/Initial_Synthetic_Dataset/CIFAR10_IPC50_labels.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DataDAM: Efficient Dataset Distillation with Attention Matching 2 | Official implementation of "DataDAM: Efficient Dataset Distillation with Attention Matching", published as a conference paper at ICCV 2023. 3 | - Project Page: https://datadistillation.github.io/DataDAM/ 4 | ## Abstract 5 | Researchers have long tried to minimize training costs in deep learning while maintaining strong generalization across diverse datasets. Emerging research on dataset distillation aims to reduce training costs by creating a small synthetic set that contains the information of a larger real dataset and ultimately achieves test accuracy equivalent to a model trained on the whole dataset. Unfortunately, the synthetic data generated by previous methods are not guaranteed to distribute and discriminate as well as the original training data, and they incur significant computational costs. 
Despite promising results, there still exists a significant performance gap between models trained on condensed synthetic sets and those trained on the whole dataset. In this paper, we address these challenges using efficient Dataset Distillation with Attention Matching (DataDAM), achieving state-of-the-art performance while reducing training costs. Specifically, we learn synthetic images by matching the spatial attention maps of real and synthetic data generated by different layers within a family of randomly initialized neural networks. Our method outperforms the prior methods on several datasets, including MNIST, CIFAR10/100, TinyImageNet, and ImageNet-1K, across most of the settings, and achieves improvements of up to 6.5% and 4.1% on CIFAR100 and ImageNet-1K, respectively. We also show that our high-quality distilled images have practical benefits for downstream applications, such as continual learning and neural architecture search. 6 |

7 | ![DataDAM pipeline](img/DataDAM_pipeline.png) 8 |
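At the heart of DataDAM is the spatial attention matching (SAM) loss shown in the pipeline above: activations from every ReLU layer of a randomly initialized network are pooled over channels into per-sample attention maps, normalized, and the class-wise mean maps of the real and synthetic batches are matched with an MSE. The snippet below is a minimal, self-contained sketch of that computation; it mirrors `get_attention` in `utils.py` and the `MSE_B` error in `main_DataDAM.py`, but the tensor shapes and the exponent `p` are illustrative rather than the exact training configuration.

```python
import torch
import torch.nn.functional as F

def attention_map(feats: torch.Tensor, p: int = 4) -> torch.Tensor:
    """Pool [B, C, H, W] activations into L2-normalized [B, H*W] attention maps."""
    amap = torch.sum(feats.abs() ** p, dim=1)                  # sum |A|^p over channels -> [B, H, W]
    return F.normalize(amap.view(feats.size(0), -1), p=2.0)    # flatten and L2-normalize per sample

def sam_loss(real_feats, syn_feats, num_classes, batch_real, ipc):
    """MSE between class-wise mean attention maps of real and synthetic batches."""
    a_real = attention_map(real_feats).view(num_classes, batch_real, -1).mean(dim=1)
    a_syn = attention_map(syn_feats).view(num_classes, ipc, -1).mean(dim=1)
    return torch.sum((a_real - a_syn) ** 2)

# Illustrative shapes: 10 classes, 64 real / 10 synthetic images per class, one 128-channel 16x16 feature map.
real_feats = torch.randn(10 * 64, 128, 16, 16)
syn_feats = torch.randn(10 * 10, 128, 16, 16, requires_grad=True)
print(sam_loss(real_feats, syn_feats, num_classes=10, batch_real=64, ipc=10))
```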

9 | 10 | ## File Tree 11 | This folder contains all necessary code files and supplemental material for the main paper. 12 | ``` 13 | . 14 | ├── main_DataDAM.py # Source code for reproducing DataDAM results on benchmark datasets and IPCs 15 | ├── networks.py # Defines all relevant network architectures, including cross-arch models 16 | ├── utils.py # Defines all utility functions required for any task or ablation in the main paper, including our attention module 17 | ├── distill_test_model.py # Script to test the frozen models 18 | ├── requirements.txt # Lists all related Python packages necessary for reproducing our model results 19 | ├── Supplementary.pdf # Supplementary PDF for our main paper -- DataDAM 20 | └── README.md 21 | ``` 22 | 23 | 24 | 25 | ## HyperParameter Table 26 | For reproducibility, we outline our associated hyperparameters below: 27 |

28 | ![Hyperparameter table](img/HPTable.png) 29 |
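To reproduce a setting from the table, the hyperparameters map directly onto the command-line flags of `main_DataDAM.py` (see its argparse block). The invocation below is illustrative: the flag names exist in the script, but the values are only one example configuration and should be adjusted per the table.

```
python main_DataDAM.py \
    --dataset CIFAR10 --model ConvNet --ipc 10 \
    --Iteration 20000 --lr_img 1 --lr_net 0.01 \
    --batch_real 64 --init real \
    --dsa_strategy color_crop_cutout_flip_scale_rotate \
    --task_balance 0.01 \
    --data_path ./data --save_path ./results
```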

30 | 31 | ## Distilled Datasets & Frozen Evaluation Models 32 | 33 | We provide saved tensors of the dataset and frozen evaluation models trained on the respective distilled dataset on our HuggingFace Page: https://huggingface.co/datasets/uoft-dsp-lab/DataDAM 34 | 35 | Additionally these frozen models can be tested with "distill_test_model.py" 36 | -------------------------------------------------------------------------------- /distill_test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, get_attention 10 | import matplotlib.pyplot as plt 11 | from torchvision import transforms 12 | from torch.utils.data.distributed import DistributedSampler 13 | import kornia as K 14 | import torch.distributed as dist 15 | import torch.cuda.comm 16 | from torchvision.utils import save_image 17 | 18 | def main(): 19 | 20 | parser = argparse.ArgumentParser(description='Parameter Processing') 21 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 22 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 23 | parser.add_argument('--ipc', type=int, default=50, help='image(s) per class') 24 | parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') 25 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments') 26 | parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models') 27 | parser.add_argument('--epoch_eval_train', type=int, default=1800, help='epochs to train a model with synthetic data') 28 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations') 29 | parser.add_argument('--lr_img', type=float, default=1, help='learning rate for updating synthetic images, 1 for low IPCs 10 for >= 100') 30 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 31 | parser.add_argument('--batch_real', type=int, default=64, help='batch size for real data') 32 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks') 33 | parser.add_argument('--init', type=str, default='real', help='noise/real/smart: initialize synthetic images from random noise or randomly sampled real images.') 34 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 35 | parser.add_argument('--data_path', type=str, default='', help='dataset path') 36 | parser.add_argument('--zca', type=bool, default=False, help='Zca Whitening') 37 | parser.add_argument('--save_path', type=str, default='', help='path to save results') 38 | parser.add_argument('--task_balance', type=float, default=0.01, help='balance attention with output') 39 | 40 | args = parser.parse_args() 41 | args.method = 'DataDAM' 42 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | args.dsa_param = ParamDiffAug() 44 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True 45 | 46 | if not os.path.exists(args.data_path): 47 | os.mkdir(args.data_path) 48 | 49 | if not os.path.exists(args.save_path): 50 | os.mkdir(args.save_path) 51 | 52 | args.save_path 
+= "/{}".format(args.dataset.lower()) 53 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca = get_dataset(args.dataset, args.data_path, args) 54 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 55 | 56 | 57 | model_eval = model_eval_pool[0] 58 | 59 | data_save = torch.load(os.path.join(args.save_path, 'syn_data_%s_ipc_%d.pt'%(args.dataset.lower(), args.ipc)))["data"] 60 | 61 | image_syn_eval = torch.tensor(data_save[0]) 62 | label_syn_eval = torch.tensor(data_save[1]) 63 | net_model_dict = torch.load(os.path.join(args.save_path, 'model_params_%s_ipc_%d.pt'%(args.dataset.lower(), args.ipc)))["net_parameters"] 64 | 65 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model 66 | 67 | net_eval.load_state_dict(net_model_dict) # load the state dict 68 | _, _, acc_test = evaluate_synset(-1, net_eval, image_syn_eval, label_syn_eval, testloader, args, skip=True) # evaluate the model 69 | print("Trained Model Best", acc_test) 70 | 71 | main() 72 | 73 | 74 | -------------------------------------------------------------------------------- /google8905e38a0c973ed3.html: -------------------------------------------------------------------------------- 1 | google-site-verification: google8905e38a0c973ed3.html 2 | -------------------------------------------------------------------------------- /img/DataDAM_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/img/DataDAM_pipeline.png -------------------------------------------------------------------------------- /img/HPTable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDistillation/DataDAM/0f6ea09ed019073933c9bd46382cd4aa3bec7fb8/img/HPTable.png -------------------------------------------------------------------------------- /main_DataDAM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, get_attention 10 | import matplotlib.pyplot as plt 11 | from torchvision import transforms 12 | from torch.utils.data.distributed import DistributedSampler 13 | import kornia as K 14 | import torch.distributed as dist 15 | import torch.cuda.comm 16 | 17 | def main(): 18 | 19 | parser = argparse.ArgumentParser(description='Parameter Processing') 20 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 21 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 22 | parser.add_argument('--ipc', type=int, default=50, help='image(s) per class') 23 | parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') 24 | parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments') 25 | parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models') 26 | parser.add_argument('--epoch_eval_train', type=int, default=1800, help='epochs to train a model with synthetic data') 27 | parser.add_argument('--Iteration', type=int, default=20000, help='training iterations') 28 | 
parser.add_argument('--lr_img', type=float, default=1, help='learning rate for updating synthetic images, 1 for low IPCs 10 for >= 100') 29 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 30 | parser.add_argument('--batch_real', type=int, default=64, help='batch size for real data') 31 | parser.add_argument('--batch_train', type=int, default=64, help='batch size for training networks') 32 | parser.add_argument('--init', type=str, default='real', help='noise/real/smart: initialize synthetic images from random noise or randomly sampled real images.') 33 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 34 | parser.add_argument('--data_path', type=str, default='', help='dataset path') 35 | parser.add_argument('--zca', type=bool, default=False, help='Zca Whitening') 36 | parser.add_argument('--save_path', type=str, default='', help='path to save results') 37 | parser.add_argument('--task_balance', type=float, default=0.01, help='balance attention with output') 38 | 39 | args = parser.parse_args() 40 | args.method = 'DataDAM' 41 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 42 | args.dsa_param = ParamDiffAug() 43 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True 44 | 45 | if not os.path.exists(args.data_path): 46 | os.mkdir(args.data_path) 47 | 48 | if not os.path.exists(args.save_path): 49 | os.mkdir(args.save_path) 50 | 51 | eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist()[:] if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results. 52 | print('eval_it_pool: ', eval_it_pool) 53 | 54 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca = get_dataset(args.dataset, args.data_path, args) 55 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 56 | 57 | 58 | accs_all_exps = dict() # record performances of all experiments 59 | for key in model_eval_pool: 60 | accs_all_exps[key] = [] 61 | 62 | data_save = [] 63 | 64 | total_mean = {} 65 | best_5 = [] 66 | accuracy_logging = {"mean":[], "std":[], "max_mean":[]} 67 | for exp in range(args.num_exp): 68 | total_mean[exp] = {'mean':[], 'std':[]} 69 | best_5.append(0) 70 | print('\n================== Exp %d ==================\n '%exp) 71 | print('Hyper-parameters: \n', args.__dict__) 72 | print('Evaluation model pool: ', model_eval_pool) 73 | 74 | ''' organize the real dataset ''' 75 | images_all = [] 76 | labels_all = [] 77 | indices_class = [[] for c in range(num_classes)] 78 | 79 | images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))] 80 | labels_all = [dst_train[i][1] for i in range(len(dst_train))] 81 | for i, lab in enumerate(labels_all): 82 | indices_class[lab].append(i) 83 | images_all = torch.cat(images_all, dim=0).to(args.device) 84 | labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device) 85 | 86 | 87 | 88 | for c in range(num_classes): 89 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 90 | 91 | def get_images(c, n): # get random n images from class c 92 | idx_shuffle = np.random.permutation(indices_class[c])[:n] 93 | return images_all[idx_shuffle] 94 | 95 | for ch in range(channel): 96 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 97 | 98 | 
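# Note on the initialization block below: image_syn is first allocated as Gaussian noise, and the
# --init flag then decides how it is filled: 'real' copies args.ipc randomly sampled real images per
# class, 'noise' simply keeps the random initialization, and 'smart' loads the pre-selected tensors
# (e.g. CIFAR10_IPC10_images.pt / CIFAR10_IPC10_labels.pt, shipped under Initial_Synthetic_Dataset/;
# the script loads them from the current working directory). If --zca is set, ZCA whitening is then
# applied to the synthetic images.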
99 | ''' initialize the synthetic data ''' 100 | image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device) 101 | label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 102 | if args.init == 'real': 103 | print('initialize synthetic data from random real images') 104 | for c in range(num_classes): 105 | image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data 106 | elif args.init =='noise' : 107 | print('initialize synthetic data from random noise') 108 | 109 | elif args.init =='smart' : 110 | print('initialize synthetic data from SMART selection') 111 | Path = './' 112 | if args.dataset == "CIFAR10": 113 | Path+='CIFAR10_' 114 | 115 | elif args.dataset == "CIFAR100": 116 | Path+='CIFAR100_' 117 | 118 | if args.ipc == 1: 119 | Path += 'IPC1_' 120 | 121 | elif args.ipc == 10: 122 | Path += 'IPC10_' 123 | 124 | elif args.ipc == 50: 125 | Path += 'IPC50_' 126 | 127 | elif args.ipc == 100: 128 | Path += 'IPC100_' 129 | 130 | elif args.ipc == 200: 131 | Path += 'IPC200_' 132 | image_syn.data[:][:][:][:] = torch.load(Path+'images.pt') 133 | label_syn.data[:] = torch.load(Path+'labels.pt') 134 | 135 | if(args.zca): 136 | print("ZCA Whitened Complete") 137 | image_syn.data[:][:][:][:] = zca(image_syn.data[:][:][:][:], include_fit=True) 138 | else: 139 | print("No ZCA Whiteinign") 140 | 141 | 142 | 143 | 144 | 145 | ''' training ''' 146 | optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data 147 | optimizer_img.zero_grad() 148 | print('%s training begins'%get_time()) 149 | ''' Defining the Hook Function to collect Activations ''' 150 | activations = {} 151 | def getActivation(name): 152 | def hook_func(m, inp, op): 153 | activations[name] = op.clone() 154 | return hook_func 155 | 156 | ''' Defining the Refresh Function to store Activations and reset Collection ''' 157 | def refreshActivations(activations): 158 | model_set_activations = [] # Jagged Tensor Creation 159 | for i in activations.keys(): 160 | model_set_activations.append(activations[i]) 161 | activations = {} 162 | return activations, model_set_activations 163 | 164 | ''' Defining the Delete Hook Function to collect Remove Hooks ''' 165 | def delete_hooks(hooks): 166 | for i in hooks: 167 | i.remove() 168 | return 169 | 170 | def attach_hooks(net): 171 | hooks = [] 172 | base = net.module if torch.cuda.device_count() > 1 else net 173 | for module in (base.features.named_modules()): 174 | if isinstance(module[1], nn.ReLU): 175 | # Hook the Ouptus of a ReLU Layer 176 | hooks.append(base.features[int(module[0])].register_forward_hook(getActivation('ReLU_'+str(len(hooks))))) 177 | return hooks 178 | 179 | max_mean = 0 180 | for it in range(args.Iteration+1): 181 | 182 | ''' Evaluate synthetic data ''' 183 | if it in eval_it_pool: 184 | for model_eval in model_eval_pool: 185 | print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it)) 186 | 187 | print('DSA augmentation strategy: \n', args.dsa_strategy) 188 | print('DSA augmentation parameters: \n', args.dsa_param.__dict__) 189 | 190 | accs = [] 191 | Start = time.time() 192 | for it_eval in range(args.num_eval): 193 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model 194 | image_syn_eval, 
label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification 195 | mini_net, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args) 196 | accs.append(acc_test) 197 | if acc_test > best_5[-1]: 198 | best_5[-1] = acc_test 199 | 200 | Finish = (time.time() - Start)/10 201 | 202 | print("TOTAL TIME WAS: ", Finish) 203 | 204 | 205 | print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs))) 206 | if np.mean(accs) > max_mean: 207 | data=[] 208 | data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())]) 209 | torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc_.pt'%(args.method, args.dataset, args.model, args.ipc))) 210 | # Track All of them! 211 | total_mean[exp]['mean'].append(np.mean(accs)) 212 | total_mean[exp]['std'].append(np.std(accs)) 213 | 214 | accuracy_logging["mean"].append(np.mean(accs)) 215 | accuracy_logging["std"].append(np.std(accs)) 216 | accuracy_logging["max_mean"].append(np.max(accs)) 217 | 218 | 219 | if it == args.Iteration: # record the final results 220 | accs_all_exps[model_eval] += accs 221 | 222 | ''' visualize and save ''' 223 | # save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it)) 224 | # image_syn_vis = copy.deepcopy(image_syn.detach().cpu()) 225 | # for ch in range(channel): 226 | # image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch] 227 | # image_syn_vis[image_syn_vis<0] = 0.0 228 | # image_syn_vis[image_syn_vis>1] = 1.0 229 | # save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects. 
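# Summary of the synthetic-data update performed below: a freshly initialized network is drawn and
# frozen (only the synthetic images receive gradients); for every class, a real batch and the
# synthetic images are DSA-augmented with a shared seed (when enabled) and forwarded through the
# network with hooks on each ReLU layer; the class-wise mean spatial attention maps of real and
# synthetic activations are matched layer by layer (MSE_B), a task_balance-weighted MSE between the
# mean output logits is added, and optimizer_img takes one SGD step on the synthetic images.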
230 | 231 | ''' Train synthetic data ''' 232 | net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model 233 | net.train() 234 | for param in list(net.parameters()): 235 | param.requires_grad = False 236 | 237 | loss_avg = 0 238 | def error(real, syn, err_type="MSE"): 239 | 240 | if(err_type == "MSE"): 241 | err = torch.sum((torch.mean(real, dim=0) - torch.mean(syn, dim=0))**2) 242 | 243 | elif (err_type == "MAE"): 244 | err = torch.sum(torch.abs(torch.mean(real, dim=0) - torch.mean(syn, dim=0))) 245 | 246 | elif (err_type == "ANG"): 247 | rl = torch.mean(real, dim=0) 248 | sy = torch.mean(syn, dim=0) 249 | num = torch.matmul(rl, sy) 250 | denom = (torch.sum(rl**2)**0.5) * (torch.sum(sy**2)**0.5) 251 | err = torch.acos(num/denom) 252 | 253 | elif(err_type == "MSE_B"): 254 | err = torch.sum((torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu() - torch.mean(syn.cpu().reshape(num_classes, args.ipc, -1), dim=1))**2) 255 | elif(err_type == "MAE_B"): 256 | err = torch.sum(torch.abs(torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu() - torch.mean(syn.reshape(num_classes, args.ipc, -1).cpu(), dim=1))) 257 | elif (err_type == "ANG_B"): 258 | rl = torch.mean(real.reshape(num_classes, args.batch_real, -1), dim=1).cpu() 259 | sy = torch.mean(syn.reshape(num_classes, args.ipc, -1), dim=1) 260 | 261 | denom = (torch.sum(rl**2)**0.5).cpu() * (torch.sum(sy**2)**0.5).cpu() 262 | num = rl.cpu() * sy.cpu() 263 | err = torch.sum(torch.acos(num/denom)) 264 | return err 265 | 266 | ''' update synthetic data ''' 267 | loss = torch.tensor(0.0) 268 | mid_loss = 0 269 | out_loss = 0 270 | 271 | images_real_all = [] 272 | images_syn_all = [] 273 | for c in range(num_classes): 274 | img_real = get_images(c, args.batch_real) 275 | img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1])) 276 | 277 | if args.dsa: 278 | seed = int(time.time() * 1000) % 100000 279 | img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param) 280 | img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param) 281 | 282 | images_real_all.append(img_real) 283 | images_syn_all.append(img_syn) 284 | 285 | images_real_all = torch.cat(images_real_all, dim=0) 286 | 287 | images_syn_all = torch.cat(images_syn_all, dim=0) 288 | 289 | 290 | hooks = attach_hooks(net) 291 | 292 | output_real = net(images_real_all)[0].detach() 293 | activations, original_model_set_activations = refreshActivations(activations) 294 | 295 | output_syn = net(images_syn_all)[0] 296 | activations, syn_model_set_activations = refreshActivations(activations) 297 | delete_hooks(hooks) 298 | 299 | length_of_network = len(original_model_set_activations)# of Feature Map Sets 300 | 301 | for layer in range(length_of_network-1): 302 | 303 | real_attention = get_attention(original_model_set_activations[layer].detach(), param=1, exp=1, norm='l2') 304 | syn_attention = get_attention(syn_model_set_activations[layer], param=1, exp=1, norm='l2') 305 | 306 | tl = 100*error(real_attention, syn_attention, err_type="MSE_B") 307 | loss+=tl 308 | mid_loss += tl 309 | 310 | output_loss = 100*args.task_balance * error(output_real, output_syn, err_type="MSE_B") 311 | 312 | loss += output_loss 313 | out_loss += output_loss 314 | 315 | optimizer_img.zero_grad() 316 | loss.backward() 317 | optimizer_img.step() 318 | loss_avg += loss.item() 319 | torch.cuda.empty_cache() 320 | 321 | loss_avg /= (num_classes) 322 | out_loss /= 
(num_classes) 323 | mid_loss /= (num_classes) 324 | if it%10 == 0: 325 | print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg)) 326 | print('\n==================== Final Results ====================\n') 327 | for key in model_eval_pool: 328 | accs = accs_all_exps[key] 329 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100)) 330 | 331 | print('\n==================== Maximum Results ====================\n') 332 | 333 | best_means = [] 334 | best_std = [] 335 | for exp in total_mean.keys(): 336 | best_idx = np.argmax(total_mean[exp]['mean']) 337 | best_means.append(total_mean[exp]['mean'][best_idx]) 338 | best_std.append(total_mean[exp]['std'][best_idx]) 339 | 340 | mean = np.mean(best_means) 341 | std = np.mean(best_std) 342 | 343 | num_eval = args.num_exp*args.num_eval 344 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model,num_eval, key, mean*100, std*100)) 345 | 346 | 347 | print('\n==================== Top 5 Results ====================\n') 348 | 349 | 350 | mean = np.mean(best_5) 351 | std = np.std(best_5) 352 | 353 | num_eval = args.num_exp*args.num_eval 354 | print('Run %d experiments, train on %s, evaluate %d random %s, mean = %.2f%% std = %.2f%%'%(args.num_exp, args.model,num_eval, key, mean*100, std*100)) 355 | 356 | 357 | if __name__ == '__main__': 358 | main() 359 | 360 | 361 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # Acknowledgement to 5 | # https://github.com/kuangliu/pytorch-cifar, 6 | # https://github.com/BIGBALLON/CIFAR-ZOO, 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | 13 | ''' Swish activation ''' 14 | class Swish(nn.Module): # Swish(x) = x∗σ(x) 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.sigmoid(input) 20 | 21 | 22 | ''' MLP ''' 23 | class MLP(nn.Module): 24 | def __init__(self, channel, num_classes): 25 | super(MLP, self).__init__() 26 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else 32*32*3, 128) 27 | self.fc_2 = nn.Linear(128, 128) 28 | self.fc_3 = nn.Linear(128, num_classes) 29 | 30 | def forward(self, x): 31 | out = x.view(x.size(0), -1) 32 | out = F.relu(self.fc_1(out)) 33 | out = F.relu(self.fc_2(out)) 34 | out = self.fc_3(out) 35 | return out 36 | 37 | 38 | 39 | ''' ConvNet ''' 40 | class ConvNet(nn.Module): 41 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)): 42 | super(ConvNet, self).__init__() 43 | 44 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 45 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 46 | self.classifier = nn.Linear(num_feat, num_classes) 47 | 48 | def forward(self, x): 49 | # print(x.shape) 50 | out = self.features(x) 51 | emb = out.reshape(out.size(0), -1) 52 | # emb = self.embed(x) 53 | out = self.classifier(emb) 54 | return emb, out 55 | 56 | def embed(self, x): 57 | out = self.features(x) 58 | out = out.view(out.size(0), -1) 59 | return out 60 | 61 | def _get_activation(self, net_act): 62 | if net_act == 'sigmoid': 63 | return nn.Sigmoid() 64 | elif net_act == 
'relu': 65 | return nn.ReLU(inplace=True) 66 | elif net_act == 'leakyrelu': 67 | return nn.LeakyReLU(negative_slope=0.01) 68 | elif net_act == 'swish': 69 | return Swish() 70 | else: 71 | exit('unknown activation function: %s'%net_act) 72 | 73 | def _get_pooling(self, net_pooling): 74 | if net_pooling == 'maxpooling': 75 | return nn.MaxPool2d(kernel_size=2, stride=2) 76 | elif net_pooling == 'avgpooling': 77 | return nn.AvgPool2d(kernel_size=2, stride=2) 78 | elif net_pooling == 'none': 79 | return None 80 | else: 81 | exit('unknown net_pooling: %s'%net_pooling) 82 | 83 | def _get_normlayer(self, net_norm, shape_feat): 84 | # shape_feat = (c*h*w) 85 | if net_norm == 'batchnorm': 86 | return nn.BatchNorm2d(shape_feat[0], affine=True) 87 | elif net_norm == 'layernorm': 88 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 89 | elif net_norm == 'instancenorm': 90 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 91 | elif net_norm == 'groupnorm': 92 | return nn.GroupNorm(4, shape_feat[0], affine=True) 93 | elif net_norm == 'none': 94 | return None 95 | else: 96 | exit('unknown net_norm: %s'%net_norm) 97 | 98 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 99 | layers = [] 100 | in_channels = channel 101 | if im_size[0] == 28: 102 | im_size = (32, 32) 103 | shape_feat = [in_channels, im_size[0], im_size[1]] 104 | for d in range(net_depth): 105 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 106 | shape_feat[0] = net_width 107 | if net_norm != 'none': 108 | layers += [self._get_normlayer(net_norm, shape_feat)] 109 | layers += [self._get_activation(net_act)] 110 | in_channels = net_width 111 | if net_pooling != 'none': 112 | layers += [self._get_pooling(net_pooling)] 113 | shape_feat[1] //= 2 114 | shape_feat[2] //= 2 115 | 116 | return nn.Sequential(*layers), shape_feat 117 | 118 | 119 | 120 | ''' LeNet ''' 121 | class LeNet(nn.Module): 122 | def __init__(self, channel, num_classes): 123 | super(LeNet, self).__init__() 124 | self.features = nn.Sequential( 125 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0), 126 | nn.ReLU(inplace=True), 127 | nn.MaxPool2d(kernel_size=2, stride=2), 128 | nn.Conv2d(6, 16, kernel_size=5), 129 | nn.ReLU(inplace=True), 130 | nn.MaxPool2d(kernel_size=2, stride=2), 131 | ) 132 | self.fc_1 = nn.Linear(16 * 5 * 5, 120) 133 | self.fc_2 = nn.Linear(120, 84) 134 | self.fc_3 = nn.Linear(84, num_classes) 135 | 136 | def forward(self, x): 137 | x = self.features(x) 138 | x = x.view(x.size(0), -1) 139 | x = F.relu(self.fc_1(x)) 140 | x = F.relu(self.fc_2(x)) 141 | x = self.fc_3(x) 142 | return x 143 | 144 | 145 | 146 | ''' AlexNet ''' 147 | class AlexNet(nn.Module): 148 | def __init__(self, channel, num_classes): 149 | super(AlexNet, self).__init__() 150 | self.features = nn.Sequential( 151 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 152 | nn.ReLU(inplace=True), 153 | nn.MaxPool2d(kernel_size=2, stride=2), 154 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 155 | nn.ReLU(inplace=True), 156 | nn.MaxPool2d(kernel_size=2, stride=2), 157 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 158 | nn.ReLU(inplace=True), 159 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 160 | nn.ReLU(inplace=True), 161 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 162 | nn.ReLU(inplace=True), 163 | nn.MaxPool2d(kernel_size=2, stride=2), 164 | ) 165 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 166 | 167 
| def forward(self, x): 168 | x = self.features(x) 169 | emb = x.view(x.size(0), -1) 170 | out = self.fc(emb) 171 | return emb, out 172 | 173 | def embed(self, x): 174 | x = self.features(x) 175 | x = x.view(x.size(0), -1) 176 | return x 177 | 178 | 179 | ''' AlexNetBN ''' 180 | class AlexNetBN(nn.Module): 181 | def __init__(self, channel, num_classes): 182 | super(AlexNetBN, self).__init__() 183 | self.features = nn.Sequential( 184 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 185 | nn.BatchNorm2d(128), 186 | nn.ReLU(inplace=True), 187 | nn.MaxPool2d(kernel_size=2, stride=2), 188 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 189 | nn.BatchNorm2d(192), 190 | nn.ReLU(inplace=True), 191 | nn.MaxPool2d(kernel_size=2, stride=2), 192 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 193 | nn.BatchNorm2d(256), 194 | nn.ReLU(inplace=True), 195 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 196 | nn.BatchNorm2d(192), 197 | nn.ReLU(inplace=True), 198 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 199 | nn.BatchNorm2d(192), 200 | nn.ReLU(inplace=True), 201 | nn.MaxPool2d(kernel_size=2, stride=2), 202 | ) 203 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 204 | 205 | def forward(self, x): 206 | x = self.features(x) 207 | emb = x.view(x.size(0), -1) 208 | out = self.fc(emb) 209 | return emb, out 210 | 211 | def embed(self, x): 212 | x = self.features(x) 213 | x = x.view(x.size(0), -1) 214 | return x 215 | 216 | 217 | ''' VGG ''' 218 | cfg_vgg = { 219 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 220 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 221 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 222 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 223 | } 224 | class VGG(nn.Module): 225 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 226 | super(VGG, self).__init__() 227 | self.channel = channel 228 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 229 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes) 230 | 231 | def forward(self, x): 232 | x = self.features(x) 233 | emb = x.view(x.size(0), -1) 234 | out = self.classifier(emb) 235 | return emb, out 236 | 237 | def embed(self, x): 238 | x = self.features(x) 239 | x = x.view(x.size(0), -1) 240 | return x 241 | 242 | def _make_layers(self, cfg, norm): 243 | layers = [] 244 | in_channels = self.channel 245 | for ic, x in enumerate(cfg): 246 | if x == 'M': 247 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 248 | else: 249 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 250 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 251 | nn.ReLU(inplace=True)] 252 | in_channels = x 253 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 254 | return nn.Sequential(*layers) 255 | 256 | 257 | def VGG11(channel, num_classes): 258 | return VGG('VGG11', channel, num_classes) 259 | def VGG11BN(channel, num_classes): 260 | return VGG('VGG11', channel, num_classes, norm='batchnorm') 261 | def VGG13(channel, num_classes): 262 | return VGG('VGG13', channel, num_classes) 263 | def VGG16(channel, num_classes): 264 | return VGG('VGG16', channel, num_classes) 265 | def VGG19(channel, num_classes): 266 | return VGG('VGG19', channel, num_classes) 267 | 268 | 269 | ''' ResNet_AP ''' 270 | # The 
conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2) 271 | 272 | class BasicBlock_AP(nn.Module): 273 | expansion = 1 274 | 275 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 276 | super(BasicBlock_AP, self).__init__() 277 | self.norm = norm 278 | self.stride = stride 279 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 280 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 281 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 282 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 283 | 284 | self.shortcut = nn.Sequential() 285 | if stride != 1 or in_planes != self.expansion * planes: 286 | self.shortcut = nn.Sequential( 287 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 288 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 289 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 290 | ) 291 | 292 | def forward(self, x): 293 | out = F.relu(self.bn1(self.conv1(x))) 294 | if self.stride != 1: # modification 295 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 296 | out = self.bn2(self.conv2(out)) 297 | out += self.shortcut(x) 298 | out = F.relu(out) 299 | return out 300 | 301 | 302 | class Bottleneck_AP(nn.Module): 303 | expansion = 4 304 | 305 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 306 | super(Bottleneck_AP, self).__init__() 307 | self.norm = norm 308 | self.stride = stride 309 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 310 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 311 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 312 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 313 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 314 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 315 | 316 | self.shortcut = nn.Sequential() 317 | if stride != 1 or in_planes != self.expansion * planes: 318 | self.shortcut = nn.Sequential( 319 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 320 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 321 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 322 | ) 323 | 324 | def forward(self, x): 325 | out = F.relu(self.bn1(self.conv1(x))) 326 | out = F.relu(self.bn2(self.conv2(out))) 327 | if self.stride != 1: # modification 328 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 329 | out = self.bn3(self.conv3(out)) 330 | out += self.shortcut(x) 331 | out = F.relu(out) 332 | return out 333 | 334 | 335 | class ResNet_AP(nn.Module): 336 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 337 | super(ResNet_AP, self).__init__() 338 | self.in_planes = 64 339 | self.norm = norm 340 | 341 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 342 | self.bn1 
= nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 343 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 344 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 345 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 346 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 347 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification 348 | 349 | def _make_layer(self, block, planes, num_blocks, stride): 350 | strides = [stride] + [1] * (num_blocks - 1) 351 | layers = [] 352 | for stride in strides: 353 | layers.append(block(self.in_planes, planes, stride, self.norm)) 354 | self.in_planes = planes * block.expansion 355 | return nn.Sequential(*layers) 356 | 357 | def forward(self, x): 358 | out = F.relu(self.bn1(self.conv1(x))) 359 | out = self.layer1(out) 360 | out = self.layer2(out) 361 | out = self.layer3(out) 362 | out = self.layer4(out) 363 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification 364 | out = out.view(out.size(0), -1) 365 | out = self.classifier(out) 366 | return out 367 | 368 | def embed(self, x): 369 | out = F.relu(self.bn1(self.conv1(x))) 370 | out = self.layer1(out) 371 | out = self.layer2(out) 372 | out = self.layer3(out) 373 | out = self.layer4(out) 374 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification 375 | out = out.view(out.size(0), -1) 376 | return out 377 | 378 | def ResNet18BN_AP(channel, num_classes): 379 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 380 | 381 | def ResNet18_AP(channel, num_classes): 382 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes) 383 | 384 | 385 | ''' ResNet ''' 386 | 387 | class BasicBlock(nn.Module): 388 | expansion = 1 389 | 390 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 391 | super(BasicBlock, self).__init__() 392 | self.norm = norm 393 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 394 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 395 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 396 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 397 | 398 | self.shortcut = nn.Sequential() 399 | if stride != 1 or in_planes != self.expansion*planes: 400 | self.shortcut = nn.Sequential( 401 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 402 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 403 | ) 404 | 405 | def forward(self, x): 406 | out = F.relu(self.bn1(self.conv1(x))) 407 | out = self.bn2(self.conv2(out)) 408 | out += self.shortcut(x) 409 | out = F.relu(out) 410 | return out 411 | 412 | 413 | class Bottleneck(nn.Module): 414 | expansion = 4 415 | 416 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 417 | super(Bottleneck, self).__init__() 418 | self.norm = norm 419 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 420 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 421 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 
padding=1, bias=False) 422 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 423 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 424 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 425 | 426 | self.shortcut = nn.Sequential() 427 | if stride != 1 or in_planes != self.expansion*planes: 428 | self.shortcut = nn.Sequential( 429 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 430 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 431 | ) 432 | 433 | def forward(self, x): 434 | out = F.relu(self.bn1(self.conv1(x))) 435 | out = F.relu(self.bn2(self.conv2(out))) 436 | out = self.bn3(self.conv3(out)) 437 | out += self.shortcut(x) 438 | out = F.relu(out) 439 | return out 440 | 441 | 442 | class ResNet(nn.Module): 443 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 444 | super(ResNet, self).__init__() 445 | self.in_planes = 64 446 | self.norm = norm 447 | 448 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 449 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 450 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 451 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 452 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 453 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 454 | self.classifier = nn.Linear(512*block.expansion, num_classes) 455 | 456 | def _make_layer(self, block, planes, num_blocks, stride): 457 | strides = [stride] + [1]*(num_blocks-1) 458 | layers = [] 459 | for stride in strides: 460 | layers.append(block(self.in_planes, planes, stride, self.norm)) 461 | self.in_planes = planes * block.expansion 462 | return nn.Sequential(*layers) 463 | 464 | def forward(self, x): 465 | out = F.relu(self.bn1(self.conv1(x))) 466 | out = self.layer1(out) 467 | out = self.layer2(out) 468 | out = self.layer3(out) 469 | out = self.layer4(out) 470 | out = F.avg_pool2d(out, 4) 471 | emb = out.view(out.size(0), -1) 472 | out = self.classifier(emb) 473 | return emb, out 474 | 475 | def embed(self, x): 476 | out = F.relu(self.bn1(self.conv1(x))) 477 | out = self.layer1(out) 478 | out = self.layer2(out) 479 | out = self.layer3(out) 480 | out = self.layer4(out) 481 | out = F.avg_pool2d(out, 4) 482 | out = out.view(out.size(0), -1) 483 | return out 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | def ResNet18BN(channel, num_classes): 492 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 493 | 494 | def ResNet18(channel, num_classes): 495 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 496 | 497 | def ResNet34(channel, num_classes): 498 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 499 | 500 | def ResNet50(channel, num_classes): 501 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 502 | 503 | def ResNet101(channel, num_classes): 504 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 505 | 506 | def ResNet152(channel, num_classes): 507 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, 
num_classes=num_classes) 508 | 509 | 510 | 511 | 512 | '''ViT Model ''' 513 | def pair(t): 514 | return t if isinstance(t, tuple) else (t, t) 515 | 516 | # classes 517 | class PreNorm(nn.Module): 518 | def __init__(self, dim, fn): 519 | super().__init__() 520 | self.norm = nn.LayerNorm(dim) 521 | self.fn = fn 522 | def forward(self, x, **kwargs): 523 | return self.fn(self.norm(x), **kwargs) 524 | 525 | class FeedForward(nn.Module): 526 | def __init__(self, dim, hidden_dim, dropout = 0.): 527 | super().__init__() 528 | self.net = nn.Sequential( 529 | nn.Linear(dim, hidden_dim), 530 | nn.GELU(), 531 | nn.Dropout(dropout), 532 | nn.Linear(hidden_dim, dim), 533 | nn.Dropout(dropout) 534 | ) 535 | def forward(self, x): 536 | return self.net(x) 537 | 538 | class Attention(nn.Module): 539 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 540 | super().__init__() 541 | inner_dim = dim_head * heads 542 | project_out = not (heads == 1 and dim_head == dim) 543 | 544 | self.heads = heads 545 | self.scale = dim_head ** -0.5 546 | 547 | self.attend = nn.Softmax(dim = -1) 548 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 549 | 550 | self.to_out = nn.Sequential( 551 | nn.Linear(inner_dim, dim), 552 | nn.Dropout(dropout) 553 | ) if project_out else nn.Identity() 554 | 555 | def forward(self, x): 556 | qkv = self.to_qkv(x).chunk(3, dim = -1) 557 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 558 | 559 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 560 | 561 | attn = self.attend(dots) 562 | 563 | out = torch.matmul(attn, v) 564 | out = rearrange(out, 'b h n d -> b n (h d)') 565 | return self.to_out(out) 566 | 567 | class Transformer(nn.Module): 568 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 569 | super().__init__() 570 | self.layers = nn.ModuleList([]) 571 | for _ in range(depth): 572 | self.layers.append(nn.ModuleList([ 573 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 574 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 575 | ])) 576 | def forward(self, x): 577 | for attn, ff in self.layers: 578 | x = attn(x) + x 579 | x = ff(x) + x 580 | return x 581 | 582 | class ViT(nn.Module): 583 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 584 | super().__init__() 585 | image_height, image_width = pair(image_size) 586 | patch_height, patch_width = pair(patch_size) 587 | 588 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
589 | 590 | num_patches = (image_height // patch_height) * (image_width // patch_width) 591 | patch_dim = channels * patch_height * patch_width 592 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 593 | 594 | self.to_patch_embedding = nn.Sequential( 595 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 596 | nn.Linear(patch_dim, dim), 597 | ) 598 | 599 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 600 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 601 | self.dropout = nn.Dropout(emb_dropout) 602 | 603 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 604 | 605 | self.pool = pool 606 | self.to_latent = nn.Identity() 607 | 608 | self.mlp_head = nn.Sequential( 609 | nn.LayerNorm(dim), 610 | nn.Linear(dim, num_classes) 611 | ) 612 | 613 | def forward(self, img): 614 | x = self.to_patch_embedding(img) 615 | b, n, _ = x.shape 616 | 617 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 618 | x = torch.cat((cls_tokens, x), dim=1) 619 | x += self.pos_embedding[:, :(n + 1)] 620 | x = self.dropout(x) 621 | 622 | x = self.transformer(x) 623 | 624 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 625 | 626 | x = self.to_latent(x) 627 | return x, self.mlp_head(x) 628 | 629 | def ViTModel(im_size, num_classes): 630 | return ViT( 631 | image_size = im_size, 632 | patch_size = 4, 633 | num_classes = num_classes, 634 | dim = 512, 635 | depth = 6, 636 | heads = 8, 637 | mlp_dim = 512, 638 | dropout = 0.1, 639 | emb_dropout = 0.1) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiohttp==3.8.3 3 | aiosignal==1.3.1 4 | astunparse==1.6.3 5 | async-timeout==4.0.2 6 | asynctest==0.13.0 7 | attrs==22.2.0 8 | blessed==1.20.0 9 | cachetools==5.3.0 10 | certifi==2022.12.7 11 | charset-normalizer==2.1.1 12 | cycler==0.11.0 13 | einops==0.6.0 14 | flatbuffers==23.1.21 15 | fonttools==4.38.0 16 | frozenlist==1.3.3 17 | fsspec==2023.1.0 18 | gast==0.4.0 19 | google-auth==2.16.1 20 | google-auth-oauthlib==0.4.6 21 | google-pasta==0.2.0 22 | gpustat==1.0.0 23 | grpcio==1.51.3 24 | h5py==3.8.0 25 | idna==3.4 26 | importlib-metadata==6.0.0 27 | joblib==1.2.0 28 | keras==2.11.0 29 | kiwisolver==1.4.4 30 | kornia==0.6.9 31 | libclang==15.0.6.1 32 | lightning-utilities==0.6.0.post0 33 | Markdown==3.4.1 34 | MarkupSafe==2.1.2 35 | matplotlib==3.5.3 36 | multidict==6.0.4 37 | nas-bench-201==2.1 38 | numpy==1.21.6 39 | nvidia-cublas-cu11==11.10.3.66 40 | nvidia-cuda-nvrtc-cu11==11.7.99 41 | nvidia-cuda-runtime-cu11==11.7.99 42 | nvidia-cudnn-cu11==8.5.0.96 43 | nvidia-ml-py==11.495.46 44 | oauthlib==3.2.2 45 | opt-einsum==3.3.0 46 | packaging==23.0 47 | pandas==1.3.5 48 | Pillow==9.4.0 49 | pkg_resources==0.0.0 50 | protobuf==3.19.6 51 | psutil==5.9.4 52 | pyasn1==0.4.8 53 | pyasn1-modules==0.2.8 54 | pyparsing==3.0.9 55 | python-dateutil==2.8.2 56 | pytorch-lightning==1.9.0 57 | pytz==2022.7.1 58 | PyYAML==6.0 59 | requests==2.28.2 60 | requests-oauthlib==1.3.1 61 | rsa==4.9 62 | scikit-learn==1.0.2 63 | scipy==1.1.0 64 | seaborn==0.12.2 65 | six==1.16.0 66 | tensorboard==2.11.2 67 | tensorboard-data-server==0.6.1 68 | tensorboard-plugin-wit==1.8.1 69 | tensorflow==2.11.0 70 | tensorflow-estimator==2.11.0 71 | tensorflow-io-gcs-filesystem==0.30.0 72 | termcolor==2.2.0 73 | threadpoolctl==3.1.0 74 | 
torch==1.13.1 75 | torchmetrics==0.11.1 76 | torchvision==0.14.1 77 | tqdm==4.64.1 78 | typing_extensions==4.4.0 79 | urllib3==1.26.14 80 | wcwidth==0.2.6 81 | Werkzeug==2.2.3 82 | wrapt==1.14.1 83 | yarl==1.8.2 84 | zipp==3.11.0 85 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from torchvision import datasets, transforms 9 | from scipy.ndimage.interpolation import rotate as scipyrotate 10 | from networks import MLP, ConvNet, LeNet, AlexNet, AlexNetBN, VGG11, VGG11BN, ResNet18, ResNet18BN_AP, ResNet18BN, ViTModel 11 | import tqdm 12 | import kornia as K 13 | from copy import deepcopy 14 | from torchvision.utils import save_image 15 | 16 | # Attention Module 17 | def get_attention(feature_set, param=0, exp=4, norm='l2'): 18 | if param==0: 19 | attention_map = torch.sum(torch.abs(feature_set), dim=1) 20 | 21 | elif param ==1: 22 | attention_map = torch.sum(torch.abs(feature_set)**exp, dim=1) 23 | 24 | elif param == 2: 25 | attention_map = torch.max(torch.abs(feature_set)**exp, dim=1) 26 | 27 | if norm == 'l2': 28 | # Dimension: [B x (H*W)] -- Vectorized 29 | vectorized_attention_map = attention_map.view(feature_set.size(0), -1) 30 | normalized_attention_maps = F.normalize(vectorized_attention_map, p=2.0) 31 | 32 | elif norm == 'fro': 33 | # Dimension: [B x H x W] -- Un-Vectorized 34 | un_vectorized_attention_map = attention_map 35 | # Dimension: [B] 36 | fro_norm = torch.sum(torch.sum(torch.abs(attention_map)**2, dim=1), dim=1) 37 | # Dimension: [B x H x W] -- Un-Vectorized) 38 | normalized_attention_maps = un_vectorized_attention_map / fro_norm.unsqueeze(dim=-1).unsqueeze(dim=-1) 39 | elif norm == 'l1': 40 | # Dimension: [B x (H*W)] -- Vectorized 41 | vectorized_attention_map = attention_map.view(feature_set.size(0), -1) 42 | normalized_attention_maps = F.normalize(vectorized_attention_map, p=1.0) 43 | 44 | elif norm =='none': 45 | normalized_attention_maps = attention_map 46 | 47 | elif norm == 'none-vectorized': 48 | normalized_attention_maps = attention_map.view(feature_set.size(0), -1) 49 | 50 | return normalized_attention_maps 51 | 52 | 53 | 54 | 55 | def get_dataset(dataset, data_path, args): 56 | if dataset == 'MNIST': 57 | channel = 1 58 | im_size = (28, 28) 59 | num_classes = 10 60 | mean = [0.1307] 61 | std = [0.3081] 62 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 63 | dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation 64 | dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform) 65 | class_names = [str(c) for c in range(num_classes)] 66 | 67 | elif dataset == 'CIFAR10': 68 | channel = 3 69 | im_size = (32, 32) 70 | num_classes = 10 71 | mean = [0.4914, 0.4822, 0.4465] 72 | std = [0.2023, 0.1994, 0.2010] 73 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 74 | dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation 75 | dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform) 76 | class_names = dst_train.classes 77 | 78 | elif dataset == 'CIFAR100': 79 | channel = 3 80 | im_size = (32, 32) 81 | num_classes = 100 82 
| mean = [0.5071, 0.4866, 0.4409] 83 | std = [0.2673, 0.2564, 0.2762] 84 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 85 | dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation 86 | dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform) 87 | class_names = dst_train.classes 88 | 89 | elif dataset == 'TinyImageNet': 90 | channel = 3 91 | im_size = (64, 64) 92 | num_classes = 200 93 | mean = [0.485, 0.456, 0.406] 94 | std = [0.229, 0.224, 0.225] 95 | data = torch.load(os.path.join(data_path, 'tinyimagenet.pt'), map_location='cpu') 96 | 97 | class_names = data['classes'] 98 | 99 | images_train = data['images_train'] 100 | labels_train = data['labels_train'] 101 | images_train = images_train.detach().float() / 255.0 102 | labels_train = labels_train.detach() 103 | for c in range(channel): 104 | images_train[:,c] = (images_train[:,c].clone() - mean[c])/std[c] 105 | dst_train = TensorDataset(images_train, labels_train) # no augmentation 106 | 107 | images_val = data['images_val'] 108 | labels_val = data['labels_val'] 109 | images_val = images_val.detach().float() / 255.0 110 | labels_val = labels_val.detach() 111 | 112 | for c in range(channel): 113 | images_val[:, c] = (images_val[:, c].clone() - mean[c]) / std[c] 114 | 115 | dst_test = TensorDataset(images_val, labels_val) # no augmentation 116 | 117 | elif dataset == 'ImageNette': 118 | channel = 3 119 | im_size = (128, 128) 120 | num_classes = 10 121 | 122 | class_names = ["Tench", "English Springer", "Cassette Player", "Chainsaw", "Church", "French Horn", "Garbage Truck", "Gas Pump","Golf Ball", "Parachute"] 123 | 124 | mean = [0.485, 0.456, 0.406] 125 | std = [0.229, 0.224, 0.225] 126 | 127 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), 128 | transforms.Resize(im_size), 129 | transforms.CenterCrop(im_size)]) 130 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation 131 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform) 132 | 133 | elif dataset == 'ImageWoof': 134 | channel = 3 135 | im_size = (128, 128) 136 | num_classes = 10 137 | 138 | class_names = ["Australian Terrier", "Border Terrier", "Samoyed", "Beagle", "Shih-Tzu" ,"English Foxhound", "Rhodesian Ridgeback", "Dingo", "Golden Retriever", "English Sheepdog"] 139 | 140 | mean = [0.485, 0.456, 0.406] 141 | std = [0.229, 0.224, 0.225] 142 | 143 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), 144 | transforms.Resize(im_size), 145 | transforms.CenterCrop(im_size)]) 146 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation 147 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform) 148 | 149 | 150 | elif dataset == 'ImageSquack': 151 | channel = 3 152 | im_size = (128, 128) 153 | num_classes = 10 154 | 155 | class_names = ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"] 156 | 157 | mean = [0.485, 0.456, 0.406] 158 | std = [0.229, 0.224, 0.225] 159 | 160 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), 161 | transforms.Resize(im_size), 162 | transforms.CenterCrop(im_size)]) 163 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), 
transform=transform) # no augmentation 164 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform) 165 | 166 | elif dataset == 'ImageFruit': 167 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948] 168 | imagefruitLabels = {j:i for i,j in enumerate(imagefruit)} 169 | class_names = ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"] 170 | channel = 3 171 | im_size = (128, 128) 172 | num_classes = 10 173 | 174 | mean = [0.485, 0.456, 0.406] 175 | std = [0.229, 0.224, 0.225] 176 | if args.zca: 177 | transform = transforms.Compose([transforms.ToTensor(), 178 | transforms.Resize(im_size), 179 | transforms.CenterCrop(im_size)]) 180 | else: 181 | transform = transforms.Compose([transforms.ToTensor(), 182 | transforms.Normalize(mean=mean, std=std), 183 | transforms.Resize(im_size), 184 | transforms.CenterCrop(im_size)]) 185 | 186 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) 187 | selected_dataset = [] 188 | for idx, (image, label) in enumerate(dst_train): 189 | if label in imagefruit: 190 | selected_dataset.append((image, imagefruitLabels[label])) 191 | # Create a new dataset using the selected classes (labels already remapped to 0..9) 192 | dst_train = torch.utils.data.Subset(selected_dataset, torch.arange(len(selected_dataset))) 193 | 194 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform) 195 | # Keep only the selected classes in the test set and remap their labels to 0..9 196 | dst_test.samples = [(path, imagefruitLabels[t]) for (path, t) in dst_test.samples if t in imagefruitLabels] 197 | dst_test.targets = [t for (_, t) in dst_test.samples] 198 | dst_test.imgs = dst_test.samples 199 | print('ImageFruit: %d train / %d test images'%(len(dst_train), len(dst_test))) 200 | 201 | 202 | 203 | class_map = {x: i for i, x in enumerate(imagefruit)} 204 | class_map_inv = {i: x for i, x in enumerate(imagefruit)} 205 | 206 | 207 | 208 | else: 209 | exit('unknown dataset: %s'%dataset) 210 | zca=None 211 | if args.zca: 212 | images = [] 213 | labels = [] 214 | print("Train ZCA") 215 | for i in tqdm.tqdm(range(len(dst_train))): 216 | im, lab = dst_train[i] 217 | images.append(im) 218 | labels.append(lab) 219 | images = torch.stack(images, dim=0).to(args.device) 220 | labels = torch.tensor(labels, dtype=torch.long, device="cpu") 221 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True) 222 | zca.fit(images) 223 | zca_images = zca(images).to("cpu") 224 | dst_train = TensorDataset(zca_images, labels) 225 | 226 | images = [] 227 | labels = [] 228 | print("Test ZCA") 229 | for i in tqdm.tqdm(range(len(dst_test))): 230 | im, lab = dst_test[i] 231 | images.append(im) 232 | labels.append(lab) 233 | images = torch.stack(images, dim=0).to(args.device) 234 | labels = torch.tensor(labels, dtype=torch.long, device="cpu") 235 | 236 | zca_images = zca(images).to("cpu") 237 | dst_test = TensorDataset(zca_images, labels) 238 | 239 | args.zca_trans = zca 240 | 241 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0) 242 | return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, zca 243 | 244 | 245 | 246 | class TensorDataset(Dataset): 247 | def __init__(self, images, labels): # images: n x c x h x w tensor 248 | self.images = images.detach().float() 249 | self.labels = labels.detach() 250 | 251 | def
__getitem__(self, index): 252 | return self.images[index], self.labels[index] 253 | 254 | def __len__(self): 255 | return self.images.shape[0] 256 | 257 | 258 | 259 | def get_default_convnet_setting(): 260 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling' 261 | return net_width, net_depth, net_act, net_norm, net_pooling 262 | 263 | 264 | 265 | def get_network(model, channel, num_classes, im_size=(32, 32)): 266 | torch.random.manual_seed(int(time.time() * 1000) % 100000) 267 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting() 268 | 269 | if model == 'MLP': 270 | net = MLP(channel=channel, num_classes=num_classes) 271 | elif model == 'ConvNet': 272 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size) 273 | 274 | elif model == 'ConvNet128IN': # Higher Resolution 275 | net_depth=6 276 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size) 277 | 278 | elif model == 'LeNet': 279 | net = LeNet(channel=channel, num_classes=num_classes) 280 | elif model == 'AlexNet': 281 | net = AlexNet(channel=channel, num_classes=num_classes) 282 | elif model == 'AlexNetBN': 283 | net = AlexNetBN(channel=channel, num_classes=num_classes) 284 | elif model == 'VGG11': 285 | net = VGG11( channel=channel, num_classes=num_classes) 286 | elif model == 'VGG11BN': 287 | net = VGG11BN(channel=channel, num_classes=num_classes) 288 | elif model == 'ResNet18': 289 | net = ResNet18(channel=channel, num_classes=num_classes) 290 | elif model == 'ResNet18BN_AP': 291 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes) 292 | elif model == 'ResNet18BN': 293 | net = ResNet18BN(channel=channel, num_classes=num_classes) 294 | elif model == 'ViT': 295 | net = ViTModel(im_size, num_classes) 296 | print("ViT Model") 297 | elif model == 'ConvNetD1': 298 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 299 | elif model == 'ConvNetD2': 300 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 301 | elif model == 'ConvNetD3': 302 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 303 | elif model == 'ConvNetD4': 304 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 305 | 306 | elif model == 'ConvNetW32': 307 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 308 | elif model == 'ConvNetW64': 309 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 310 | elif model == 'ConvNetW128': 311 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 312 | elif model == 'ConvNetW256': 313 | 
net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 314 | 315 | elif model == 'ConvNetAS': 316 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 317 | elif model == 'ConvNetAR': 318 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 319 | elif model == 'ConvNetAL': 320 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 321 | elif model == 'ConvNetASwish': 322 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 323 | elif model == 'ConvNetASwishBN': 324 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='swish', net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size) 325 | 326 | elif model == 'ConvNetNN': 327 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling, im_size=im_size) 328 | elif model == 'ConvNetBN': 329 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size) 330 | elif model == 'ConvNetBNImageNet': 331 | net_depth=4 332 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling, im_size=im_size) 333 | elif model == 'ConvNetLN': 334 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling, im_size=im_size) 335 | elif model == 'ConvNetIN': 336 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling, im_size=im_size) 337 | elif model == 'ConvNetGN': 338 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling, im_size=im_size) 339 | 340 | elif model == 'ConvNetNP': 341 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none', im_size=im_size) 342 | elif model == 'ConvNetMP': 343 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling', im_size=im_size) 344 | elif model == 'ConvNetAP': 345 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling', im_size=im_size) 346 | 347 | else: 348 | net = None 349 | exit('unknown model: %s'%model) 350 | 351 | gpu_num = torch.cuda.device_count() 352 | 353 | if gpu_num > 1: 354 | net = nn.DataParallel(net) 355 | net = net.cuda() 356 | 357 | return net 358 | 359 | 360 | 361 | def get_time(): 362 | return 
str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 363 | 364 | def get_loops(ipc): 365 | # Get the two hyper-parameters of outer-loop and inner-loop. 366 | # The following values are empirically good. 367 | if ipc == 1: 368 | outer_loop, inner_loop = 1, 1 369 | elif ipc == 10: 370 | outer_loop, inner_loop = 10, 50 371 | elif ipc == 20: 372 | outer_loop, inner_loop = 20, 25 373 | elif ipc == 30: 374 | outer_loop, inner_loop = 30, 20 375 | elif ipc == 40: 376 | outer_loop, inner_loop = 40, 15 377 | elif ipc == 50: 378 | outer_loop, inner_loop = 50, 10 379 | else: 380 | outer_loop, inner_loop = 0, 0 381 | exit('loop hyper-parameters are not defined for %d ipc'%ipc) 382 | return outer_loop, inner_loop 383 | 384 | 385 | 386 | def epoch(mode, dataloader, net, optimizer, criterion, args, aug): 387 | loss_avg, acc_avg, num_exp = 0, 0, 0 388 | net = net.to(args.device) 389 | criterion = criterion.to(args.device) 390 | 391 | if mode == 'train': 392 | net.train() 393 | else: 394 | net.eval() 395 | 396 | for i_batch, datum in enumerate(dataloader): 397 | img = datum[0].float().to(args.device) 398 | if aug: 399 | if args.dsa: 400 | img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param) 401 | lab = datum[1].long().to(args.device) 402 | n_b = lab.shape[0] 403 | 404 | output = net(img)[1] 405 | loss = criterion(output, lab) 406 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy())) 407 | 408 | loss_avg += loss.item()*n_b 409 | acc_avg += acc 410 | num_exp += n_b 411 | 412 | if mode == 'train': 413 | optimizer.zero_grad() 414 | loss.backward() 415 | optimizer.step() 416 | loss_avg /= num_exp 417 | acc_avg /= num_exp 418 | 419 | return loss_avg, acc_avg 420 | 421 | 422 | 423 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, skip=False): 424 | net = net.to(args.device) 425 | 426 | images_train = images_train.to(args.device) 427 | labels_train = labels_train.to(args.device) 428 | lr = float(args.lr_net) 429 | Epoch = int(args.epoch_eval_train) 430 | lr_schedule = [Epoch//2+1] 431 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 432 | criterion = nn.CrossEntropyLoss().to(args.device) 433 | 434 | dst_train = TensorDataset(images_train, labels_train) 435 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 436 | 437 | start = time.time() 438 | acc_test = 0 439 | loss_train = 0 440 | time_train = 0 441 | acc_train = 0 442 | if not skip: 443 | for ep in tqdm.tqdm(range(Epoch+1)): 444 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = True) 445 | if ep in lr_schedule: 446 | lr *= 0.1 447 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 448 | time_train = time.time() - start 449 | 450 | loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False) 451 | print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test)) 452 | 453 | return net, acc_train, acc_test 454 | 455 | 456 | 457 | def augment(images, dc_aug_param, device): 458 | # This can be sped up in the future. 
459 | print("In here, no dsa lol", dc_aug_param) 460 | 461 | 462 | 463 | if dc_aug_param != None and dc_aug_param['strategy'] != 'none': 464 | scale = dc_aug_param['scale'] 465 | crop = dc_aug_param['crop'] 466 | rotate = dc_aug_param['rotate'] 467 | noise = dc_aug_param['noise'] 468 | strategy = dc_aug_param['strategy'] 469 | 470 | shape = images.shape 471 | mean = [] 472 | for c in range(shape[1]): 473 | mean.append(float(torch.mean(images[:,c]))) 474 | 475 | def cropfun(i): 476 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device) 477 | for c in range(shape[1]): 478 | im_[c] = mean[c] 479 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i] 480 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0] 481 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]] 482 | 483 | def scalefun(i): 484 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 485 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 486 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0] 487 | mhw = max(h, w, shape[2], shape[3]) 488 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device) 489 | r = int((mhw - h) / 2) 490 | c = int((mhw - w) / 2) 491 | im_[:, r:r + h, c:c + w] = tmp 492 | r = int((mhw - shape[2]) / 2) 493 | c = int((mhw - shape[3]) / 2) 494 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]] 495 | 496 | def rotatefun(i): 497 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean)) 498 | r = int((im_.shape[-2] - shape[-2]) / 2) 499 | c = int((im_.shape[-1] - shape[-1]) / 2) 500 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device) 501 | 502 | def noisefun(i): 503 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device) 504 | 505 | 506 | augs = strategy.split('_') 507 | 508 | for i in range(shape[0]): 509 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation 510 | if choice == 'crop': 511 | cropfun(i) 512 | elif choice == 'scale': 513 | scalefun(i) 514 | elif choice == 'rotate': 515 | rotatefun(i) 516 | elif choice == 'noise': 517 | noisefun(i) 518 | 519 | return images 520 | 521 | 522 | 523 | def get_daparam(dataset, model, model_eval, ipc): 524 | # We find that augmentation doesn't always benefit the performance. 525 | # So we do augmentation for some of the settings. 526 | 527 | dc_aug_param = dict() 528 | dc_aug_param['crop'] = 4 529 | dc_aug_param['scale'] = 0.2 530 | dc_aug_param['rotate'] = 45 531 | dc_aug_param['noise'] = 0.001 532 | dc_aug_param['strategy'] = 'none' 533 | 534 | if dataset == 'MNIST': 535 | dc_aug_param['strategy'] = 'crop_scale_rotate' 536 | 537 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier. 
538 | dc_aug_param['strategy'] = 'crop_noise' 539 | 540 | return dc_aug_param 541 | 542 | 543 | def get_eval_pool(eval_mode, model, model_eval): 544 | if eval_mode == 'M': # multiple architectures 545 | model_eval_pool = ['MLP', 'ConvNet', 'LeNet', 'AlexNet', 'VGG11', 'ResNet18'] 546 | elif eval_mode == 'B': # multiple architectures with BatchNorm for DM experiments 547 | model_eval_pool = ['ConvNetBN', 'ConvNetASwishBN', 'AlexNetBN', 'VGG11BN', 'ResNet18BN'] 548 | elif eval_mode == 'W': # ablation study on network width 549 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256'] 550 | elif eval_mode == 'D': # ablation study on network depth 551 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4'] 552 | elif eval_mode == 'A': # ablation study on network activation function 553 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL', 'ConvNetASwish'] 554 | elif eval_mode == 'P': # ablation study on network pooling layer 555 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP'] 556 | elif eval_mode == 'N': # ablation study on network normalization layer 557 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN'] 558 | elif eval_mode == 'S': # itself 559 | if 'BN' in model: 560 | print('Attention: Here I will replace BN with IN in evaluation, as the synthetic set is too small to measure BN hyper-parameters.') 561 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model] 562 | elif eval_mode == 'SS': # itself 563 | model_eval_pool = [model] 564 | else: 565 | model_eval_pool = [model_eval] 566 | return model_eval_pool 567 | 568 | 569 | class ParamDiffAug(): 570 | def __init__(self): 571 | self.aug_mode = 'S' #'multiple or single' 572 | self.prob_flip = 0.5 573 | self.ratio_scale = 1.2 574 | self.ratio_rotate = 15.0 575 | self.ratio_crop_pad = 0.125 576 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 577 | self.brightness = 1.0 578 | self.saturation = 2.0 579 | self.contrast = 0.5 580 | 581 | 582 | def set_seed_DiffAug(param): 583 | if param.latestseed == -1: 584 | return 585 | else: 586 | torch.random.manual_seed(param.latestseed) 587 | param.latestseed += 1 588 | 589 | 590 | def DiffAugment(x, strategy='', seed = -1, param = None): 591 | if strategy == 'None' or strategy == 'none' or strategy == '': 592 | return x 593 | 594 | if seed == -1: 595 | param.Siamese = False 596 | else: 597 | param.Siamese = True 598 | 599 | param.latestseed = seed 600 | 601 | if strategy: 602 | if param.aug_mode == 'M': # original 603 | for p in strategy.split('_'): 604 | for f in AUGMENT_FNS[p]: 605 | x = f(x, param) 606 | elif param.aug_mode == 'S': 607 | pbties = strategy.split('_') 608 | set_seed_DiffAug(param) 609 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 610 | for f in AUGMENT_FNS[p]: 611 | x = f(x, param) 612 | else: 613 | exit('unknown augmentation mode: %s'%param.aug_mode) 614 | x = x.contiguous() 615 | return x 616 | 617 | 618 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans. 
619 | def rand_scale(x, param): 620 | # x>1, max scale 621 | # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times 622 | ratio = param.ratio_scale 623 | set_seed_DiffAug(param) 624 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 625 | set_seed_DiffAug(param) 626 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 627 | theta = [[[sx[i], 0, 0], 628 | [0, sy[i], 0],] for i in range(x.shape[0])] 629 | theta = torch.tensor(theta, dtype=torch.float) 630 | if param.Siamese: # Siamese augmentation: 631 | theta[:] = theta[0].clone() 632 | grid = F.affine_grid(theta, x.shape).to(x.device) 633 | x = F.grid_sample(x, grid) 634 | return x 635 | 636 | 637 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree 638 | ratio = param.ratio_rotate 639 | set_seed_DiffAug(param) 640 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 641 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 642 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 643 | theta = torch.tensor(theta, dtype=torch.float) 644 | if param.Siamese: # Siamese augmentation: 645 | theta[:] = theta[0].clone() 646 | grid = F.affine_grid(theta, x.shape).to(x.device) 647 | x = F.grid_sample(x, grid) 648 | return x 649 | 650 | 651 | def rand_flip(x, param): 652 | prob = param.prob_flip 653 | set_seed_DiffAug(param) 654 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 655 | if param.Siamese: # Siamese augmentation: 656 | randf[:] = randf[0].clone() 657 | return torch.where(randf < prob, x.flip(3), x) 658 | 659 | 660 | def rand_brightness(x, param): 661 | ratio = param.brightness 662 | set_seed_DiffAug(param) 663 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 664 | if param.Siamese: # Siamese augmentation: 665 | randb[:] = randb[0].clone() 666 | x = x + (randb - 0.5)*ratio 667 | return x 668 | 669 | 670 | def rand_saturation(x, param): 671 | ratio = param.saturation 672 | x_mean = x.mean(dim=1, keepdim=True) 673 | set_seed_DiffAug(param) 674 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 675 | if param.Siamese: # Siamese augmentation: 676 | rands[:] = rands[0].clone() 677 | x = (x - x_mean) * (rands * ratio) + x_mean 678 | return x 679 | 680 | 681 | def rand_contrast(x, param): 682 | ratio = param.contrast 683 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 684 | set_seed_DiffAug(param) 685 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 686 | if param.Siamese: # Siamese augmentation: 687 | randc[:] = randc[0].clone() 688 | x = (x - x_mean) * (randc + ratio) + x_mean 689 | return x 690 | 691 | 692 | def rand_crop(x, param): 693 | # The image is padded on its surrounding and then cropped. 
694 | ratio = param.ratio_crop_pad 695 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 696 | set_seed_DiffAug(param) 697 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 698 | set_seed_DiffAug(param) 699 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 700 | if param.Siamese: # Siamese augmentation: 701 | translation_x[:] = translation_x[0].clone() 702 | translation_y[:] = translation_y[0].clone() 703 | grid_batch, grid_x, grid_y = torch.meshgrid( 704 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 705 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 706 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 707 | ) 708 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 709 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 710 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 711 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 712 | return x 713 | 714 | 715 | def rand_cutout(x, param): 716 | ratio = param.ratio_cutout 717 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 718 | set_seed_DiffAug(param) 719 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 720 | set_seed_DiffAug(param) 721 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 722 | if param.Siamese: # Siamese augmentation: 723 | offset_x[:] = offset_x[0].clone() 724 | offset_y[:] = offset_y[0].clone() 725 | grid_batch, grid_x, grid_y = torch.meshgrid( 726 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 727 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 728 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 729 | ) 730 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 731 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 732 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 733 | mask[grid_batch, grid_x, grid_y] = 0 734 | x = x * mask.unsqueeze(1) 735 | return x 736 | 737 | 738 | AUGMENT_FNS = { 739 | 'color': [rand_brightness, rand_saturation, rand_contrast], 740 | 'crop': [rand_crop], 741 | 'cutout': [rand_cutout], 742 | 'flip': [rand_flip], 743 | 'scale': [rand_scale], 744 | 'rotate': [rand_rotate], 745 | } 746 | 747 | --------------------------------------------------------------------------------
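For reference, the sketch below shows how the helpers defined in utils.py could be combined into an attention-matching objective in the spirit of the paper. It is an illustrative example under assumptions, not the repository's implementation (the actual training loop lives in main_DataDAM.py): the names attention_matching_loss, real_feature_maps and syn_feature_maps are hypothetical, and the per-layer feature maps are assumed to come from the same randomly initialized network (e.g. collected with forward hooks) for a real batch and a synthetic batch of one class.

# Illustrative sketch only -- not the repository's training code.
import torch.nn.functional as F
from utils import get_attention

def attention_matching_loss(real_feature_maps, syn_feature_maps, exp=4, norm='l2'):
    # real_feature_maps / syn_feature_maps: lists of [B x C x H x W] activations,
    # one entry per intermediate layer of the same randomly initialized network.
    loss = 0.0
    for real_f, syn_f in zip(real_feature_maps, syn_feature_maps):
        # Spatial attention maps, L2-normalized over the flattened H*W dimension.
        real_att = get_attention(real_f, param=1, exp=exp, norm=norm)
        syn_att = get_attention(syn_f, param=1, exp=exp, norm=norm)
        # Match the mean attention map of the real batch against that of the synthetic batch.
        loss = loss + F.mse_loss(syn_att.mean(dim=0), real_att.mean(dim=0))
    return loss

In the full method the synthetic images are the optimization variable, so a loss of this form would be backpropagated into them; the choice of param=1 and exp=4 here simply mirrors the defaults exposed by get_attention above.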