├── .gitignore ├── README.md ├── buffer ├── buffer_FTD.py ├── cifar.py ├── gsam │ ├── __init__.py │ ├── gsam.py │ ├── scheduler.py │ ├── util.py │ └── wide_res_net.py └── utility │ ├── bypass_bn.py │ ├── cutout.py │ ├── initialize.py │ ├── loading_bar.py │ ├── log.py │ └── step_lr.py ├── configs ├── CIFAR-10 │ └── ConvIN │ │ ├── IPC1.yaml │ │ ├── IPC10.yaml │ │ ├── IPC1000.yaml │ │ ├── IPC50.yaml │ │ └── IPC500.yaml ├── CIFAR-100 │ └── ConvIN │ │ ├── IPC1.yaml │ │ ├── IPC10.yaml │ │ ├── IPC100.yaml │ │ └── IPC50.yaml └── TinyImageNet │ └── ConvIN │ ├── IPC1.yaml │ ├── IPC10.yaml │ └── IPC50.yaml ├── distill ├── DATM.py ├── DATM_tesla.py ├── baseline.py ├── distill_arch.py ├── evaluation.py └── model_ema.py ├── environment.yaml ├── figures ├── comparison.png ├── visualization.png └── visualization_ipc.png ├── networks.py ├── reparam_module.py └── utils ├── cfg.py ├── step_lr.py ├── utils_arch.py ├── utils_baseline.py ├── utils_baseline_backup.py ├── utils_buffer_sam.py ├── utils_eval_sam.py ├── utils_gsam.py ├── utils_mixup.py └── utils_vanilla_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | dataset 4 | buffer_storage 5 | distill/logged_files 6 | distill/wandb 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICLR 2024] Towards Lossless Dataset Distillation via Difficulty-Aligned Trajectory Matching 2 | 3 | ## [Project Page](https://gzyaftermath.github.io/DATM/) | [Paper](https://arxiv.org/abs/2310.05773) | [Distilled Datasets](https://drive.google.com/drive/folders/1kZlYgiVrmFEz0OUyxnww3II7FBPQe7W0) 4 | To achieve lossless dataset distillation, an intuitive idea is to increase the size of the synthetic dataset. 5 | However, previous dataset distillation methods tend to perform worse than random selection as IPC (i.e., data keep ratio) increases. 6 | 7 | To address this issue, we find the difficulty of the generated patterns should be aligned with the size of the synthetic dataset 8 | (avoid generating patterns that are too easy or too difficult). 9 | 10 | By doing so, our method remains effective in high IPC cases and achieves lossless dataset distillation for the very first time. 11 | ![image](figures/comparison.png) 12 | What do easy patterns and hard patterns look like? 13 | 14 | ![image](figures/visualization.png) 15 | 16 | 17 | ![image](figures/visualization_ipc.png) 18 | 19 | ## News 20 | 16 May. The implementation of DATM_with_[TESLA](https://github.com/justincui03/tesla) is merged. Thanks for the PR from [Yue XU](https://github.com/silicx)! 21 | 22 | ## Getting Started 23 | 1. Create environment as follows 24 | ``` 25 | conda env create -f environment.yaml 26 | conda activate distillation 27 | ``` 28 | 2. Generate expert trajectories 29 | ``` 30 | cd buffer 31 | python buffer_FTD.py --dataset=CIFAR10 --model=ConvNet --train_epochs=100 --num_experts=100 --zca --buffer_path=../buffer_storage/ --data_path=../dataset/ --rho_max=0.01 --rho_min=0.01 --alpha=0.3 --lr_teacher=0.01 --mom=0. --batch_train=256 32 | ``` 33 | 3. Perform the distillation 34 | ``` 35 | cd distill 36 | python DATM.py --cfg ../configs/xxxx.yaml 37 | ``` 38 | `DATM_tesla.py` is a [TESLA](https://github.com/justincui03/tesla) implementation of DATM, which could greatly reduce the VRAM usage, *e.g.* ~12G for CIFAR10 and IPC=1000. 
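A minimal usage sketch, assuming `DATM_tesla.py` accepts the same `--cfg` interface as `DATM.py` (the config path below is one of the provided high-IPC configs):
```
cd distill
python DATM_tesla.py --cfg ../configs/CIFAR-10/ConvIN/IPC1000.yaml
```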
39 | 40 | ## Evaluation 41 | We provide a simple script for evaluating the distilled datasets. 42 | ``` 43 | cd distill 44 | python evaluation.py --lr_dir=path_to_lr --data_dir=path_to_images --label_dir=path_to_labels --zca 45 | ``` 46 | ## Acknowledgement 47 | Our code is built upon [MTT](https://github.com/GeorgeCazenavette/mtt-distillation), [FTD](https://github.com/AngusDujw/FTD-distillation) and [TESLA](https://github.com/justincui03/tesla). 48 | ## Citation 49 | If you find our code useful for your research, please cite our paper. 50 | ``` 51 | @inproceedings{guo2024lossless, 52 | title={Towards Lossless Dataset Distillation via Difficulty-Aligned Trajectory Matching}, 53 | author={Ziyao Guo and Kai Wang and George Cazenavette and Hui Li and Kaipeng Zhang and Yang You}, 54 | year={2024}, 55 | booktitle={The Twelfth International Conference on Learning Representations} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /buffer/buffer_FTD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | sys.path.append("../") 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from utils.utils_gsam import get_dataset, get_network, get_daparam,\ 9 | TensorDataset, epoch, ParamDiffAug 10 | import copy 11 | 12 | import warnings 13 | warnings.filterwarnings("ignore", category=DeprecationWarning) 14 | 15 | def main(args): 16 | 17 | args.dsa = True if args.dsa == 'True' else False 18 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | args.dsa_param = ParamDiffAug() 20 | 21 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 22 | 23 | # print('\n================== Exp %d ==================\n '%exp) 24 | print('Hyper-parameters: \n', args.__dict__) 25 | 26 | save_dir = os.path.join(args.buffer_path, args.dataset) 27 | if args.dataset == "ImageNet": 28 | save_dir = os.path.join(save_dir, args.subset, str(args.res)) 29 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca: 30 | save_dir += "_NO_ZCA" 31 | save_dir = os.path.join(save_dir, args.model) 32 | if not os.path.exists(save_dir): 33 | os.makedirs(save_dir) 34 | 35 | 36 | ''' organize the real dataset ''' 37 | images_all = [] 38 | labels_all = [] 39 | indices_class = [[] for c in range(num_classes)] 40 | print("BUILDING DATASET") 41 | for i in tqdm(range(len(dst_train))): 42 | sample = dst_train[i] 43 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 44 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 45 | #print('num of training images',len(images_all)) 46 | len_dst_train = len(images_all) ##50000 47 | 48 | for i, lab in tqdm(enumerate(labels_all)): 49 | indices_class[lab].append(i) 50 | images_all = torch.cat(images_all, dim=0).to("cpu") 51 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 52 | 53 | for c in range(num_classes): 54 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 55 | 56 | for ch in range(channel): 57 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 58 | 59 | criterion = nn.CrossEntropyLoss().to(args.device) 60 | 61 | trajectories = [] 62 | 63 | dst_train = TensorDataset(copy.deepcopy(images_all.detach()), copy.deepcopy(labels_all.detach())) 64 | trainloader 
= torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 65 | 66 | ''' set augmentation for whole-dataset training ''' 67 | args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None) 68 | args.dc_aug_param['strategy'] = 'crop_scale_rotate' # for whole-dataset training 69 | print('DC augmentation parameters: \n', args.dc_aug_param) 70 | 71 | for it in range(0, args.num_experts): 72 | 73 | ''' Train synthetic data ''' 74 | teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model 75 | teacher_net.train() 76 | lr = args.lr_teacher 77 | 78 | 79 | ##modification: using FTD here 80 | from gsam import GSAM, LinearScheduler, CosineScheduler, ProportionScheduler 81 | base_optimizer = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2) 82 | # scheduler = CosineScheduler(T_max=args.train_epochs*len_dst_train, max_value=lr, min_value=0.0, 83 | # optimizer=base_optimizer) 84 | scheduler = torch.optim.lr_scheduler.StepLR(base_optimizer,step_size=args.train_epochs*len(trainloader),gamma=1) 85 | rho_scheduler = ProportionScheduler(pytorch_lr_scheduler=scheduler, max_lr=lr, min_lr=lr, 86 | max_value=args.rho_max, min_value=args.rho_min) 87 | teacher_optim = GSAM(params=teacher_net.parameters(), base_optimizer=base_optimizer, 88 | model=teacher_net, gsam_alpha=args.alpha, rho_scheduler=rho_scheduler, adaptive=args.adaptive) 89 | 90 | 91 | teacher_optim.zero_grad() 92 | 93 | timestamps = [] 94 | 95 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 96 | 97 | lr_schedule = [args.train_epochs // 2 + 1] 98 | for e in range(args.train_epochs): 99 | 100 | train_loss, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim, 101 | criterion=criterion, args=args, aug=True,scheduler=scheduler) 102 | 103 | test_loss, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None, 104 | criterion=criterion, args=args, aug=False, scheduler=scheduler) 105 | 106 | print("Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc)) 107 | 108 | timestamps.append([p.detach().cpu() for p in teacher_net.parameters()]) 109 | 110 | 111 | trajectories.append(timestamps) 112 | 113 | if len(trajectories) == args.save_interval: 114 | n = 0 115 | while os.path.exists(os.path.join(save_dir, "replay_buffer_{}.pt".format(n))): 116 | n += 1 117 | print("Saving {}".format(os.path.join(save_dir, "replay_buffer_{}.pt".format(n)))) 118 | torch.save(trajectories, os.path.join(save_dir, "replay_buffer_{}.pt".format(n))) 119 | trajectories = [] 120 | 121 | 122 | if __name__ == '__main__': 123 | parser = argparse.ArgumentParser(description='Parameter Processing') 124 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 125 | parser.add_argument('--subset', type=str, default='imagenette', help='subset') 126 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 127 | parser.add_argument('--num_experts', type=int, default=100, help='training iterations') 128 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters') 129 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 130 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real loader') 131 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 
'False'], 132 | help='whether to use differentiable Siamese augmentation.') 133 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 134 | help='differentiable Siamese augmentation strategy') 135 | parser.add_argument('--data_path', type=str, default='data', help='dataset path') 136 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 137 | parser.add_argument('--train_epochs', type=int, default=50) 138 | parser.add_argument('--zca', action='store_true') 139 | parser.add_argument('--decay', action='store_true') 140 | parser.add_argument('--mom', type=float, default=0, help='momentum') 141 | parser.add_argument('--l2', type=float, default=0, help='l2 regularization') 142 | parser.add_argument('--save_interval', type=int, default=10) 143 | #parser.add_argument('--rho', type=float, default=0.05) 144 | parser.add_argument("--rho_max", default=2.0, type=float, help="Rho parameter for SAM.") 145 | parser.add_argument("--rho_min", default=2.0, type=float, help="Rho parameter for SAM.") 146 | parser.add_argument("--alpha", default=0.4, type=float, help="Rho parameter for SAM.") 147 | parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.") 148 | 149 | args = parser.parse_args() 150 | main(args) 151 | 152 | 153 | -------------------------------------------------------------------------------- /buffer/cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | 6 | from utility.cutout import Cutout 7 | 8 | 9 | class Cifar: 10 | def __init__(self, batch_size, threads): 11 | mean, std = self._get_statistics() 12 | 13 | train_transform = transforms.Compose([ 14 | torchvision.transforms.RandomCrop(size=(32, 32), padding=4), 15 | torchvision.transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean, std), 18 | Cutout() 19 | ]) 20 | 21 | test_transform = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean, std) 24 | ]) 25 | 26 | train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) 27 | test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform) 28 | 29 | self.train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=threads, pin_memory=True) 30 | self.test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=threads, pin_memory=True) 31 | 32 | self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 33 | 34 | def _get_statistics(self): 35 | train_set = torchvision.datasets.CIFAR10(root='./cifar', train=True, download=True, transform=transforms.ToTensor()) 36 | 37 | data = torch.cat([d[0] for d in DataLoader(train_set)]) 38 | return data.mean(dim=[0, 2, 3]), data.std(dim=[0, 2, 3]) -------------------------------------------------------------------------------- /buffer/gsam/__init__.py: -------------------------------------------------------------------------------- 1 | from .gsam import GSAM 2 | from .scheduler import * 3 | -------------------------------------------------------------------------------- /buffer/gsam/gsam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
.util import enable_running_stats, disable_running_stats 3 | import contextlib 4 | from torch.distributed import ReduceOp 5 | 6 | class GSAM(torch.optim.Optimizer): 7 | def __init__(self, params, base_optimizer, model, gsam_alpha, rho_scheduler, adaptive=False, perturb_eps=1e-12, grad_reduce='mean', **kwargs): 8 | defaults = dict(adaptive=adaptive, **kwargs) 9 | super(GSAM, self).__init__(params, defaults) 10 | self.model = model 11 | self.base_optimizer = base_optimizer 12 | self.param_groups = self.base_optimizer.param_groups 13 | self.adaptive = adaptive 14 | self.rho_scheduler = rho_scheduler 15 | self.perturb_eps = perturb_eps 16 | self.alpha = gsam_alpha 17 | 18 | # initialize self.rho_t 19 | self.update_rho_t() 20 | 21 | # set up reduction for gradient across workers 22 | if grad_reduce.lower() == 'mean': 23 | if hasattr(ReduceOp, 'AVG'): 24 | self.grad_reduce = ReduceOp.AVG 25 | self.manual_average = False 26 | else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes 27 | self.grad_reduce = ReduceOp.SUM 28 | self.manual_average = True 29 | elif grad_reduce.lower() == 'sum': 30 | self.grad_reduce = ReduceOp.SUM 31 | self.manual_average = False 32 | else: 33 | raise ValueError('"grad_reduce" should be one of ["mean", "sum"].') 34 | 35 | @torch.no_grad() 36 | def update_rho_t(self): 37 | self.rho_t = self.rho_scheduler.step() 38 | return self.rho_t 39 | 40 | @torch.no_grad() 41 | def perturb_weights(self, rho=0.0): 42 | grad_norm = self._grad_norm( weight_adaptive = self.adaptive ) 43 | for group in self.param_groups: 44 | scale = rho / (grad_norm + self.perturb_eps) 45 | 46 | for p in group["params"]: 47 | if p.grad is None: continue 48 | self.state[p]["old_g"] = p.grad.data.clone() 49 | e_w = p.grad * scale.to(p) 50 | if self.adaptive: 51 | e_w *= torch.pow(p, 2) 52 | p.add_(e_w) # climb to the local maximum "w + e(w)" 53 | self.state[p]['e_w'] = e_w 54 | 55 | @torch.no_grad() 56 | def unperturb(self): 57 | for group in self.param_groups: 58 | for p in group['params']: 59 | if 'e_w' in self.state[p].keys(): 60 | p.data.sub_(self.state[p]['e_w']) 61 | 62 | @torch.no_grad() 63 | def gradient_decompose(self, alpha=0.0): 64 | # calculate inner product 65 | inner_prod = 0.0 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: continue 69 | inner_prod += torch.sum( 70 | self.state[p]['old_g'] * p.grad.data 71 | ) 72 | 73 | # get norm 74 | new_grad_norm = self._grad_norm() 75 | old_grad_norm = self._grad_norm(by='old_g') 76 | 77 | # get cosine 78 | cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps) 79 | 80 | # gradient decomposition 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: continue 84 | vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad.data / (new_grad_norm + self.perturb_eps) 85 | p.grad.data.add_( vertical, alpha=-alpha) 86 | 87 | @torch.no_grad() 88 | def _sync_grad(self): 89 | if torch.distributed.is_initialized(): # synchronize final gardients 90 | for group in self.param_groups: 91 | for p in group['params']: 92 | if p.grad is None: continue 93 | if self.manual_average: 94 | torch.distributed.all_reduce(p.grad, op=self.grad_reduce) 95 | world_size = torch.distributed.get_world_size() 96 | p.grad.div_(float(world_size)) 97 | else: 98 | torch.distributed.all_reduce(p.grad, op=self.grad_reduce) 99 | return 100 | 101 | @torch.no_grad() 102 | def _grad_norm(self, by=None, weight_adaptive=False): 103 | #shared_device = 
self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 104 | if not by: 105 | norm = torch.norm( 106 | torch.stack([ 107 | ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2) 108 | for group in self.param_groups for p in group["params"] 109 | if p.grad is not None 110 | ]), 111 | p=2 112 | ) 113 | else: 114 | norm = torch.norm( 115 | torch.stack([ 116 | ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2) 117 | for group in self.param_groups for p in group["params"] 118 | if p.grad is not None 119 | ]), 120 | p=2 121 | ) 122 | return norm 123 | 124 | def load_state_dict(self, state_dict): 125 | super().load_state_dict(state_dict) 126 | self.base_optimizer.param_groups = self.param_groups 127 | 128 | def maybe_no_sync(self): 129 | if torch.distributed.is_initialized(): 130 | return self.model.no_sync() 131 | else: 132 | return contextlib.ExitStack() 133 | 134 | @torch.no_grad() 135 | def set_closure(self, loss_fn, inputs, targets, **kwargs): 136 | # create self.forward_backward_func, which is a function such that 137 | # self.forward_backward_func() automatically performs forward and backward passes. 138 | # This function does not take any arguments, and the inputs and targets data 139 | # should be pre-set in the definition of partial-function 140 | 141 | def get_grad(): 142 | self.base_optimizer.zero_grad() 143 | with torch.enable_grad(): 144 | outputs = self.model(inputs) 145 | loss = loss_fn(outputs, targets, **kwargs) 146 | loss_value = loss.data.clone().detach() 147 | loss.backward() 148 | return outputs, loss_value 149 | 150 | self.forward_backward_func = get_grad 151 | 152 | @torch.no_grad() 153 | def step(self, closure=None): 154 | 155 | if closure: 156 | get_grad = closure 157 | else: 158 | get_grad = self.forward_backward_func 159 | 160 | with self.maybe_no_sync(): 161 | # get gradient 162 | outputs, loss_value = get_grad() 163 | 164 | # perturb weights 165 | self.perturb_weights(rho=self.rho_t) 166 | 167 | # disable running stats for second pass 168 | disable_running_stats(self.model) 169 | 170 | # get gradient at perturbed weights 171 | get_grad() 172 | 173 | # decompose and get new update direction 174 | self.gradient_decompose(self.alpha) 175 | 176 | # unperturb 177 | self.unperturb() 178 | 179 | # synchronize gradients across workers 180 | self._sync_grad() 181 | 182 | # update with new directions 183 | self.base_optimizer.step() 184 | 185 | # enable running stats 186 | enable_running_stats(self.model) 187 | 188 | return outputs, loss_value 189 | -------------------------------------------------------------------------------- /buffer/gsam/scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | 5 | class ProportionScheduler: 6 | def __init__(self, pytorch_lr_scheduler, max_lr, min_lr, max_value, min_value): 7 | """ 8 | This scheduler outputs a value that evolves proportional to pytorch_lr_scheduler, e.g. 9 | (value - min_value) / (max_value - min_value) = (lr - min_lr) / (max_lr - min_lr) 10 | """ 11 | self.t = 0 12 | self.pytorch_lr_scheduler = pytorch_lr_scheduler 13 | self.max_lr = max_lr 14 | self.min_lr = min_lr 15 | self.max_value = max_value 16 | self.min_value = min_value 17 | 18 | assert (max_lr > min_lr) or ((max_lr==min_lr) and (max_value==min_value)), "Current scheduler for `value` is scheduled to evolve proportionally to `lr`," \ 19 | "e.g. 
`(lr - min_lr) / (max_lr - min_lr) = (value - min_value) / (max_value - min_value)`. Please check `max_lr >= min_lr` and `max_value >= min_value`;" \ 20 | "if `max_lr==min_lr` hence `lr` is constant with step, please set 'max_value == min_value' so 'value' is constant with step." 21 | 22 | assert max_value >= min_value 23 | 24 | self.step() # take 1 step during initialization to get self._last_lr 25 | 26 | def lr(self): 27 | return self._last_lr[0] 28 | 29 | def step(self): 30 | self.t += 1 31 | if hasattr(self.pytorch_lr_scheduler, "_last_lr"): 32 | lr = self.pytorch_lr_scheduler._last_lr[0] 33 | else: 34 | lr = self.pytorch_lr_scheduler.optimizer.param_groups[0]['lr'] 35 | 36 | if self.max_lr > self.min_lr: 37 | value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / (self.max_lr - self.min_lr) 38 | else: 39 | value = self.max_value 40 | 41 | self._last_lr = [value] 42 | return value 43 | 44 | class SchedulerBase: 45 | def __init__(self, T_max, max_value, min_value=0.0, init_value=0.0, warmup_steps=0, optimizer=None): 46 | super(SchedulerBase, self).__init__() 47 | self.t = 0 48 | self.min_value = min_value 49 | self.max_value = max_value 50 | self.init_value = init_value 51 | self.warmup_steps = warmup_steps 52 | self.total_steps = T_max 53 | 54 | # record current value in self._last_lr to match API from torch.optim.lr_scheduler 55 | self._last_lr = [init_value] 56 | 57 | # If optimizer is not None, will set learning rate to all trainable parameters in optimizer. 58 | # If optimizer is None, only output the value of lr. 59 | self.optimizer = optimizer 60 | 61 | def step(self): 62 | if self.t < self.warmup_steps: 63 | value = self.init_value + (self.max_value - self.init_value) * self.t / self.warmup_steps 64 | elif self.t == self.warmup_steps: 65 | value = self.max_value 66 | else: 67 | value = self.step_func() 68 | self.t += 1 69 | 70 | # apply the lr to optimizer if it's provided 71 | if self.optimizer is not None: 72 | for param_group in self.optimizer.param_groups: 73 | param_group['lr'] = value 74 | 75 | self._last_lr = [value] 76 | return value 77 | 78 | def step_func(self): 79 | pass 80 | 81 | def lr(self): 82 | return self._last_lr[0] 83 | 84 | class LinearScheduler(SchedulerBase): 85 | def step_func(self): 86 | value = self.max_value + (self.min_value - self.max_value) * (self.t - self.warmup_steps) / ( 87 | self.total_steps - self.warmup_steps) 88 | return value 89 | 90 | class CosineScheduler(SchedulerBase): 91 | def step_func(self): 92 | phase = (self.t-self.warmup_steps) / (self.total_steps-self.warmup_steps) * math.pi 93 | value = self.min_value + (self.max_value-self.min_value) * (np.cos(phase) + 1.) 
/ 2.0 94 | return value 95 | 96 | class PolyScheduler(SchedulerBase): 97 | def __init__(self, poly_order=-0.5, *args, **kwargs): 98 | super(PolyScheduler, self).__init__(*args, **kwargs) 99 | self.poly_order = poly_order 100 | assert poly_order<=0, "Please check poly_order<=0 so that the scheduler decreases with steps" 101 | 102 | def step_func(self): 103 | value = self.min_value + (self.max_value-self.min_value) * (self.t - self.warmup_steps)**self.poly_order 104 | return value 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /buffer/gsam/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | def disable_running_stats(model): 6 | def _disable(module): 7 | if isinstance(module, _BatchNorm): 8 | module.backup_momentum = module.momentum 9 | module.momentum = 0 10 | 11 | model.apply(_disable) 12 | 13 | def enable_running_stats(model): 14 | def _enable(module): 15 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 16 | module.momentum = module.backup_momentum 17 | 18 | model.apply(_enable) 19 | -------------------------------------------------------------------------------- /buffer/gsam/wide_res_net.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BasicUnit(nn.Module): 9 | def __init__(self, channels: int, dropout: float): 10 | super(BasicUnit, self).__init__() 11 | self.block = nn.Sequential(OrderedDict([ 12 | ("0_normalization", nn.BatchNorm2d(channels)), 13 | ("1_activation", nn.ReLU(inplace=True)), 14 | ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 15 | ("3_normalization", nn.BatchNorm2d(channels)), 16 | ("4_activation", nn.ReLU(inplace=True)), 17 | ("5_dropout", nn.Dropout(dropout, inplace=True)), 18 | ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 19 | ])) 20 | 21 | def forward(self, x): 22 | return x + self.block(x) 23 | 24 | 25 | class DownsampleUnit(nn.Module): 26 | def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float): 27 | super(DownsampleUnit, self).__init__() 28 | self.norm_act = nn.Sequential(OrderedDict([ 29 | ("0_normalization", nn.BatchNorm2d(in_channels)), 30 | ("1_activation", nn.ReLU(inplace=True)), 31 | ])) 32 | self.block = nn.Sequential(OrderedDict([ 33 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)), 34 | ("1_normalization", nn.BatchNorm2d(out_channels)), 35 | ("2_activation", nn.ReLU(inplace=True)), 36 | ("3_dropout", nn.Dropout(dropout, inplace=True)), 37 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)), 38 | ])) 39 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False) 40 | 41 | def forward(self, x): 42 | x = self.norm_act(x) 43 | return self.block(x) + self.downsample(x) 44 | 45 | 46 | class Block(nn.Module): 47 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float): 48 | super(Block, self).__init__() 49 | self.block = nn.Sequential( 50 | DownsampleUnit(in_channels, out_channels, stride, dropout), 51 | *(BasicUnit(out_channels, dropout) for _ in 
range(depth)) 52 | ) 53 | 54 | def forward(self, x): 55 | return self.block(x) 56 | 57 | 58 | class WideResNet(nn.Module): 59 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, labels: int): 60 | super(WideResNet, self).__init__() 61 | 62 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor] 63 | self.block_depth = (depth - 4) // (3 * 2) 64 | 65 | self.f = nn.Sequential(OrderedDict([ 66 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)), 67 | ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)), 68 | ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)), 69 | ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)), 70 | ("4_normalization", nn.BatchNorm2d(self.filters[3])), 71 | ("5_activation", nn.ReLU(inplace=True)), 72 | ("6_pooling", nn.AvgPool2d(kernel_size=8)), 73 | ("7_flattening", nn.Flatten()), 74 | ("8_classification", nn.Linear(in_features=self.filters[3], out_features=labels)), 75 | ])) 76 | 77 | self._initialize() 78 | 79 | def _initialize(self): 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu") 83 | if m.bias is not None: 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.BatchNorm2d): 86 | m.weight.data.fill_(1) 87 | m.bias.data.zero_() 88 | elif isinstance(m, nn.Linear): 89 | m.weight.data.zero_() 90 | m.bias.data.zero_() 91 | 92 | def forward(self, x): 93 | return self.f(x) 94 | -------------------------------------------------------------------------------- /buffer/utility/bypass_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def disable_running_stats(model): 5 | def _disable(module): 6 | if isinstance(module, nn.BatchNorm2d): 7 | module.backup_momentum = module.momentum 8 | module.momentum = 0 9 | 10 | model.apply(_disable) 11 | 12 | def enable_running_stats(model): 13 | def _enable(module): 14 | if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"): 15 | module.momentum = module.backup_momentum 16 | 17 | model.apply(_enable) 18 | -------------------------------------------------------------------------------- /buffer/utility/cutout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Cutout: 5 | def __init__(self, size=16, p=0.5): 6 | self.size = size 7 | self.half_size = size // 2 8 | self.p = p 9 | 10 | def __call__(self, image): 11 | if torch.rand([1]).item() > self.p: 12 | return image 13 | 14 | left = torch.randint(-self.half_size, image.size(1) - self.half_size, [1]).item() 15 | top = torch.randint(-self.half_size, image.size(2) - self.half_size, [1]).item() 16 | right = min(image.size(1), left + self.size) 17 | bottom = min(image.size(2), top + self.size) 18 | 19 | image[:, max(0, left): right, max(0, top): bottom] = 0 20 | return image 21 | -------------------------------------------------------------------------------- /buffer/utility/initialize.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | def initialize(args, seed: int): 6 | random.seed(seed) 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed(seed) 9 | torch.cuda.manual_seed_all(seed) 10 | 11 | torch.backends.cudnn.enabled = True 12 | 
torch.backends.cudnn.benchmark = True 13 | torch.backends.cudnn.deterministic = False 14 | -------------------------------------------------------------------------------- /buffer/utility/loading_bar.py: -------------------------------------------------------------------------------- 1 | class LoadingBar: 2 | def __init__(self, length: int = 40): 3 | self.length = length 4 | self.symbols = ['┈', '░', '▒', '▓'] 5 | 6 | def __call__(self, progress: float) -> str: 7 | p = int(progress * self.length*4 + 0.5) 8 | d, r = p // 4, p % 4 9 | return '┠┈' + d * '█' + ((self.symbols[r]) + max(0, self.length-1-d) * '┈' if p < self.length*4 else '') + "┈┨" 10 | -------------------------------------------------------------------------------- /buffer/utility/log.py: -------------------------------------------------------------------------------- 1 | from utility.loading_bar import LoadingBar 2 | import time 3 | 4 | 5 | class Log: 6 | def __init__(self, log_each: int, initial_epoch=-1): 7 | self.loading_bar = LoadingBar(length=27) 8 | self.best_accuracy = 0.0 9 | self.log_each = log_each 10 | self.epoch = initial_epoch 11 | 12 | def train(self, len_dataset: int) -> None: 13 | self.epoch += 1 14 | if self.epoch == 0: 15 | self._print_header() 16 | else: 17 | self.flush() 18 | 19 | self.is_train = True 20 | self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0} 21 | self._reset(len_dataset) 22 | 23 | def eval(self, len_dataset: int) -> None: 24 | self.flush() 25 | self.is_train = False 26 | self._reset(len_dataset) 27 | 28 | def __call__(self, model, loss, accuracy, learning_rate: float = None) -> None: 29 | if self.is_train: 30 | self._train_step(model, loss, accuracy, learning_rate) 31 | else: 32 | self._eval_step(loss, accuracy) 33 | 34 | def flush(self) -> None: 35 | if self.is_train: 36 | loss = self.epoch_state["loss"] / self.epoch_state["steps"] 37 | accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"] 38 | 39 | print( 40 | f"\r┃{self.epoch:12d} ┃{loss:12.4f} │{100*accuracy:10.2f} % ┃{self.learning_rate:12.3e} │{self._time():>12} ┃", 41 | end="", 42 | flush=True, 43 | ) 44 | 45 | else: 46 | loss = self.epoch_state["loss"] / self.epoch_state["steps"] 47 | accuracy = self.epoch_state["accuracy"] / self.epoch_state["steps"] 48 | 49 | print(f"{loss:12.4f} │{100*accuracy:10.2f} % ┃", flush=True) 50 | 51 | if accuracy > self.best_accuracy: 52 | self.best_accuracy = accuracy 53 | 54 | def _train_step(self, model, loss, accuracy, learning_rate: float) -> None: 55 | self.learning_rate = learning_rate 56 | self.last_steps_state["loss"] += loss.sum().item() 57 | self.last_steps_state["accuracy"] += accuracy.sum().item() 58 | self.last_steps_state["steps"] += loss.numel() 59 | self.epoch_state["loss"] += loss.sum().item() 60 | self.epoch_state["accuracy"] += accuracy.sum().item() 61 | self.epoch_state["steps"] += loss.numel() 62 | self.step += 1 63 | 64 | if self.step % self.log_each == self.log_each - 1: 65 | loss = self.last_steps_state["loss"] / self.last_steps_state["steps"] 66 | accuracy = self.last_steps_state["accuracy"] / self.last_steps_state["steps"] 67 | 68 | self.last_steps_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0} 69 | progress = self.step / self.len_dataset 70 | 71 | print( 72 | f"\r┃{self.epoch:12d} ┃{loss:12.4f} │{100*accuracy:10.2f} % ┃{learning_rate:12.3e} │{self._time():>12} {self.loading_bar(progress)}", 73 | end="", 74 | flush=True, 75 | ) 76 | 77 | def _eval_step(self, loss, accuracy) -> None: 78 | self.epoch_state["loss"] += loss.sum().item() 79 | 
self.epoch_state["accuracy"] += accuracy.sum().item() 80 | self.epoch_state["steps"] += loss.size(0) 81 | 82 | def _reset(self, len_dataset: int) -> None: 83 | self.start_time = time.time() 84 | self.step = 0 85 | self.len_dataset = len_dataset 86 | self.epoch_state = {"loss": 0.0, "accuracy": 0.0, "steps": 0} 87 | 88 | def _time(self) -> str: 89 | time_seconds = int(time.time() - self.start_time) 90 | return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min" 91 | 92 | def _print_header(self) -> None: 93 | print(f"┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓") 94 | print(f"┃ ┃ ╷ ┃ ╷ ┃ ╷ ┃") 95 | print(f"┃ epoch ┃ loss │ accuracy ┃ l.r. │ elapsed ┃ loss │ accuracy ┃") 96 | print(f"┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨") 97 | -------------------------------------------------------------------------------- /buffer/utility/step_lr.py: -------------------------------------------------------------------------------- 1 | class StepLR: 2 | def __init__(self, optimizer, learning_rate: float, total_epochs: int): 3 | self.optimizer = optimizer 4 | self.total_epochs = total_epochs 5 | self.base = learning_rate 6 | 7 | def __call__(self, epoch): 8 | if epoch < self.total_epochs * 3/10: 9 | lr = self.base 10 | elif epoch < self.total_epochs * 6/10: 11 | lr = self.base * 0.2 12 | elif epoch < self.total_epochs * 8/10: 13 | lr = self.base * 0.2 ** 2 14 | else: 15 | lr = self.base * 0.2 ** 3 16 | 17 | for param_group in self.optimizer.param_groups: 18 | param_group["lr"] = lr 19 | 20 | self._last_lr = [lr] 21 | 22 | def lr(self) -> float: 23 | return self.optimizer.param_groups[0]["lr"] 24 | -------------------------------------------------------------------------------- /configs/CIFAR-10/ConvIN/IPC1.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | device: [3] 3 | 4 | ipc: 1 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 100 8 | lr_teacher: 0.01 9 | 10 | buffer_path: ../buffer_storage/ 11 | 12 | data_path: ../dataset 13 | ema_decay: 0.995 14 | Iteration: 10000 15 | batch_syn: 10 16 | 17 | # wandb 18 | project: CIFAR10_ipc1 19 | 20 | num_eval: 1 21 | eval_it: 500 22 | skip_first_eva: True 23 | 24 | lr_y: 5. 25 | Momentum_y: 0.9 26 | threshold: 1.1 27 | pix_init: samples_predicted_correctly 28 | Sequential_Generation: False 29 | batch_train: 128 30 | min_start_epoch: 0 31 | max_start_epoch: 4 32 | lr_lr: 0.0000001 33 | zca: True 34 | -------------------------------------------------------------------------------- /configs/CIFAR-10/ConvIN/IPC10.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | device: [3] 3 | 4 | ipc: 10 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 100 8 | lr_teacher: 0.01 9 | buffer_path: ../buffer_storage/ 10 | data_path: ../dataset 11 | ema_decay: 0.9995 12 | Iteration: 10000 13 | batch_syn: 100 14 | 15 | # wandb 16 | project: CIFAR10_ipc10 17 | 18 | num_eval: 1 19 | eval_it: 500 20 | skip_first_eva: True 21 | 22 | lr_y: 2.0 23 | Momentum_y: 0.9 24 | threshold: 1. 
25 | pix_init: samples_predicted_correctly 26 | batch_train: 128 27 | 28 | min_start_epoch: 0 29 | current_max_start_epoch: 10 30 | max_start_epoch: 20 31 | expansion_end_epoch: 1000 32 | 33 | lr_lr: 0.00001 34 | zca: True 35 | 36 | -------------------------------------------------------------------------------- /configs/CIFAR-10/ConvIN/IPC1000.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | device: [0,1,2,3] 3 | 4 | ipc: 1000 5 | syn_steps: 100 6 | expert_epochs: 2 7 | lr_img: 50 8 | lr_teacher: 0.01 9 | 10 | buffer_path: ../buffer_storage/ 11 | 12 | data_path: ../dataset 13 | ema_decay: 0.995 14 | Iteration: 10000 15 | batch_syn: 1000 16 | 17 | # wandb 18 | project: CIFAR10_ipc1000 19 | 20 | num_eval: 1 21 | eval_it: 1000 22 | skip_first_eva: True 23 | 24 | lr_y: 10. 25 | Momentum_y: 0.9 26 | threshold: 1. 27 | pix_init: samples_predicted_correctly 28 | Sequential_Generation: False 29 | batch_train: 128 30 | min_start_epoch: 40 31 | max_start_epoch: 60 32 | lr_lr: 0.00001 33 | zca: True 34 | 35 | -------------------------------------------------------------------------------- /configs/CIFAR-10/ConvIN/IPC50.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | device: [0,1] 3 | 4 | ipc: 50 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 1000 8 | lr_teacher: 0.01 9 | buffer_path: ../buffer_storage/ 10 | data_path: ../dataset 11 | ema_decay: 0.995 12 | Iteration: 10000 13 | batch_syn: 500 14 | 15 | # wandb 16 | project: CIFAR10_ipc50 17 | 18 | num_eval: 1 19 | eval_it: 500 20 | skip_first_eva: False 21 | 22 | lr_y: 2. 23 | Momentum_y: 0.9 24 | threshold: 1. 25 | pix_init: samples_predicted_correctly 26 | expansion_end_epoch: 2000 27 | batch_train: 128 28 | min_start_epoch: 0 29 | current_max_start_epoch: 20 30 | max_start_epoch: 40 31 | lr_lr: 0.00001 32 | zca: True 33 | -------------------------------------------------------------------------------- /configs/CIFAR-10/ConvIN/IPC500.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR10 2 | device: [2,3] 3 | 4 | ipc: 500 5 | syn_steps: 40 6 | expert_epochs: 2 7 | lr_img: 50 8 | lr_teacher: 0.01 9 | 10 | buffer_path: ../buffer_storage/ 11 | 12 | data_path: ../dataset 13 | ema_decay: 0.995 14 | Iteration: 10000 15 | batch_syn: 1000 16 | 17 | # wandb 18 | project: CIFAR10_ipc500 19 | 20 | num_eval: 1 21 | eval_it: 1000 22 | skip_first_eva: False 23 | 24 | lr_y: 10. 25 | Momentum_y: 0.9 26 | threshold: 1. 27 | pix_init: samples_predicted_correctly 28 | batch_train: 128 29 | 30 | Sequential_Generation: False 31 | min_start_epoch: 40 32 | max_start_epoch: 60 33 | lr_lr: 0.00001 34 | zca: True 35 | 36 | -------------------------------------------------------------------------------- /configs/CIFAR-100/ConvIN/IPC1.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR100 2 | model: ConvNet 3 | device: [0] 4 | 5 | ipc: 1 6 | syn_steps: 40 7 | expert_epochs: 3 8 | zca: True 9 | lr_img: 1000 10 | lr_teacher: 0.01 11 | buffer_path: ../buffer_storage 12 | data_path: ../dataset 13 | ema_decay: 0.9995 14 | Iteration: 10000 15 | batch_syn: 1000 16 | 17 | # wandb 18 | project: CIFAR100_ipc1 19 | 20 | num_eval: 1 21 | eval_it: 500 22 | skip_first_eva: False 23 | 24 | lr_y: 10.0 25 | Momentum_y: 0.9 26 | threshold: 1. 
27 | pix_init: samples_predicted_correctly 28 | expansion_end_epoch: 1000 29 | batch_train: 128 30 | min_start_epoch: 0 31 | current_max_start_epoch: 10 32 | max_start_epoch: 20 33 | lr_lr: 0.00001 34 | 35 | 36 | -------------------------------------------------------------------------------- /configs/CIFAR-100/ConvIN/IPC10.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR100 2 | device: [0,1,2,3] 3 | 4 | ipc: 10 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 1000 8 | lr_teacher: 0.01 9 | buffer_path: ../buffer_storage/ 10 | data_path: ../dataset 11 | ema_decay: 0.9995 12 | Iteration: 10000 13 | batch_syn: 1000 14 | 15 | # wandb 16 | project: CIFAR100_ipc10 17 | 18 | num_eval: 1 19 | eval_it: 500 20 | skip_first_eva: True 21 | 22 | lr_y: 10.0 23 | Momentum_y: 0.9 24 | threshold: 1. 25 | pix_init: samples_predicted_correctly 26 | expansion_end_epoch: 2000 27 | batch_train: 128 28 | min_start_epoch: 0 29 | current_max_start_epoch: 30 30 | max_start_epoch: 50 31 | lr_lr: 0.00001 32 | zca: True 33 | 34 | -------------------------------------------------------------------------------- /configs/CIFAR-100/ConvIN/IPC100.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR100 2 | device: [0,1,2,3] 3 | 4 | ipc: 100 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 50 8 | lr_teacher: 0.01 9 | buffer_path: ../buffer_storage/ 10 | data_path: ../dataset 11 | ema_decay: 0.9995 12 | Iteration: 10000 13 | batch_syn: 1000 14 | 15 | # wandb 16 | project: CIFAR100_ipc100 17 | 18 | num_eval: 1 19 | eval_it: 500 20 | skip_first_eva: True 21 | 22 | lr_y: 10.0 23 | Momentum_y: 0.9 24 | threshold: 1. 25 | pix_init: samples_predicted_correctly 26 | Sequential_Generation: False 27 | batch_train: 128 28 | min_start_epoch: 30 29 | max_start_epoch: 70 30 | lr_lr: 0.00001 31 | zca: True 32 | 33 | -------------------------------------------------------------------------------- /configs/CIFAR-100/ConvIN/IPC50.yaml: -------------------------------------------------------------------------------- 1 | dataset: CIFAR100 2 | device: [0,1,2,3] 3 | 4 | ipc: 50 5 | syn_steps: 80 6 | expert_epochs: 2 7 | lr_img: 1000 8 | lr_teacher: 0.01 9 | buffer_path: ../buffer_storage/ 10 | data_path: ../dataset 11 | ema_decay: 0.9995 12 | Iteration: 10000 13 | batch_syn: 1000 14 | 15 | # wandb 16 | project: CIFAR100_ipc50 17 | 18 | num_eval: 1 19 | eval_it: 500 20 | skip_first_eva: False 21 | 22 | lr_y: 10.0 23 | Momentum_y: 0.9 24 | threshold: 1. 25 | pix_init: real 26 | Sequential_Generation: False 27 | batch_train: 128 28 | min_start_epoch: 20 29 | max_start_epoch: 70 30 | lr_lr: 0.00001 31 | zca: True 32 | 33 | -------------------------------------------------------------------------------- /configs/TinyImageNet/ConvIN/IPC1.yaml: -------------------------------------------------------------------------------- 1 | dataset: Tiny 2 | res: 64 3 | model: ConvNetD4 4 | device: [0,1] 5 | 6 | ipc: 1 7 | syn_steps: 40 8 | expert_epochs: 2 9 | lr_img: 10000 10 | lr_teacher: 0.01 11 | buffer_path: ../buffer_storage/ 12 | data_path: ../dataset/tiny-imagenet-200 13 | ema_decay: 0.9995 14 | Iteration: 10000 15 | batch_syn: 1000 16 | 17 | # wandb 18 | project: Tiny 19 | 20 | num_eval: 1 21 | eval_it: 500 22 | skip_first_eva: True 23 | 24 | lr_y: 10.0 25 | Momentum_y: 0.9 26 | threshold: 1. 
27 | pix_init: samples_predicted_correctly 28 | expansion_end_epoch: 2000 29 | batch_train: 128 30 | min_start_epoch: 0 31 | current_max_start_epoch: 15 32 | max_start_epoch: 20 33 | lr_lr: 0.0001 34 | 35 | zca: True 36 | 37 | -------------------------------------------------------------------------------- /configs/TinyImageNet/ConvIN/IPC10.yaml: -------------------------------------------------------------------------------- 1 | dataset: Tiny 2 | res: 64 3 | model: ConvNetD4 4 | device: [0,1,2,3] 5 | 6 | ipc: 10 7 | syn_steps: 40 8 | expert_epochs: 2 9 | lr_img: 100 10 | lr_teacher: 0.01 11 | buffer_path: ../buffer_storage/ 12 | data_path: ../dataset/tiny-imagenet-200 13 | ema_decay: 0.9995 14 | Iteration: 10000 15 | batch_syn: 500 16 | 17 | # wandb 18 | project: Tiny_ipc10 19 | 20 | num_eval: 1 21 | eval_it: 500 22 | skip_first_eva: False 23 | 24 | lr_y: 10.0 25 | Momentum_y: 0.9 26 | threshold: 1. 27 | pix_init: samples_predicted_correctly 28 | batch_train: 128 29 | 30 | Sequential_Generation: False 31 | min_start_epoch: 10 32 | max_start_epoch: 50 33 | 34 | lr_lr: 0.0001 35 | 36 | zca: True 37 | 38 | -------------------------------------------------------------------------------- /configs/TinyImageNet/ConvIN/IPC50.yaml: -------------------------------------------------------------------------------- 1 | dataset: Tiny 2 | res: 64 3 | model: ConvNetD4 4 | device: [0,1,2,3] 5 | 6 | ipc: 50 7 | syn_steps: 20 8 | expert_epochs: 2 9 | lr_img: 100 10 | lr_teacher: 0.01 11 | buffer_path: ../buffer_storage/ 12 | data_path: ../dataset/tiny-imagenet-200 13 | ema_decay: 0.995 14 | Iteration: 10000 15 | batch_syn: 1000 16 | 17 | # wandb 18 | project: Tiny_ipc50 19 | 20 | num_eval: 1 21 | eval_it: 500 22 | skip_first_eva: True 23 | 24 | lr_y: 10.0 25 | Momentum_y: 0.9 26 | threshold: 1. 
27 | pix_init: samples_predicted_correctly 28 | expansion_end_epoch: 1 29 | batch_train: 128 30 | 31 | Sequential_Generation: False 32 | min_start_epoch: 40 33 | max_start_epoch: 70 34 | 35 | lr_lr: 0.0001 36 | 37 | zca: True 38 | 39 | -------------------------------------------------------------------------------- /distill/DATM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | from tqdm import tqdm 11 | from utils.utils_baseline import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug 12 | import wandb 13 | import copy 14 | import random 15 | from reparam_module import ReparamModule 16 | # from kmeans_pytorch import kmeans 17 | from utils.cfg import CFG as cfg 18 | import warnings 19 | import yaml 20 | 21 | warnings.filterwarnings("ignore", category=DeprecationWarning) 22 | 23 | def manual_seed(seed=0): 24 | random.seed(seed) 25 | os.environ['PYTHONHASHSEED'] = str(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | def main(args): 32 | 33 | manual_seed() 34 | 35 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(x) for x in args.device]) 36 | 37 | if args.max_experts is not None and args.max_files is not None: 38 | args.total_experts = args.max_experts * args.max_files 39 | 40 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 41 | 42 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | 44 | if args.skip_first_eva==False: 45 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 46 | else: 47 | eval_it_pool = np.arange(args.eval_it, args.Iteration + 1, args.eval_it).tolist() 48 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 49 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 50 | 51 | im_res = im_size[0] 52 | 53 | args.im_size = im_size 54 | 55 | accs_all_exps = dict() # record performances of all experiments 56 | for key in model_eval_pool: 57 | accs_all_exps[key] = [] 58 | 59 | data_save = [] 60 | 61 | if args.dsa: 62 | # args.epoch_eval_train = 1000 63 | args.dc_aug_param = None 64 | 65 | args.dsa_param = ParamDiffAug() 66 | 67 | dsa_params = args.dsa_param 68 | if args.zca: 69 | zca_trans = args.zca_trans 70 | else: 71 | zca_trans = None 72 | 73 | wandb.init(sync_tensorboard=False, 74 | project=args.project, 75 | job_type="CleanRepo", 76 | config=args, 77 | ) 78 | 79 | args = type('', (), {})() 80 | 81 | for key in wandb.config._items: 82 | setattr(args, key, wandb.config._items[key]) 83 | 84 | args.dsa_param = dsa_params 85 | args.zca_trans = zca_trans 86 | 87 | if args.batch_syn is None: 88 | args.batch_syn = num_classes * args.ipc 89 | 90 | args.distributed = torch.cuda.device_count() > 1 91 | 92 | 93 | print('Hyper-parameters: \n', args.__dict__) 94 | print('Evaluation model pool: ', model_eval_pool) 95 | 96 | ''' organize the real dataset ''' 97 | images_all = [] 98 | labels_all = [] 99 | indices_class = [[] for c in range(num_classes)] 100 | print("BUILDING DATASET") 101 | if args.dataset == 'ImageNet1K' and os.path.exists('images_all.pt') and 
os.path.exists('labels_all.pt'): 102 | images_all = torch.load('images_all.pt') 103 | labels_all = torch.load('labels_all.pt') 104 | else: 105 | for i in tqdm(range(len(dst_train))): 106 | sample = dst_train[i] 107 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 108 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 109 | images_all = torch.cat(images_all, dim=0).to("cpu") 110 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 111 | if args.dataset == 'ImageNet1K': 112 | torch.save(images_all, 'images_all.pt') 113 | torch.save(labels_all, 'labels_all.pt') 114 | 115 | for i, lab in tqdm(enumerate(labels_all)): 116 | indices_class[lab].append(i) 117 | 118 | 119 | 120 | for c in range(num_classes): 121 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 122 | 123 | for ch in range(channel): 124 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 125 | 126 | 127 | def get_images(c, n): # get random n images from class c 128 | idx_shuffle = np.random.permutation(indices_class[c])[:n] 129 | return images_all[idx_shuffle] 130 | 131 | 132 | ''' initialize the synthetic data ''' 133 | label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 134 | 135 | 136 | image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float) 137 | 138 | syn_lr = torch.tensor(args.lr_teacher).to(args.device) 139 | expert_dir = os.path.join(args.buffer_path, args.dataset) 140 | if args.dataset == "ImageNet": 141 | expert_dir = os.path.join(expert_dir, args.subset, str(args.res)) 142 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca: 143 | expert_dir += "_NO_ZCA" 144 | expert_dir = os.path.join(expert_dir, args.model) 145 | print("Expert Dir: {}".format(expert_dir)) 146 | if args.load_all: 147 | buffer = [] 148 | n = 0 149 | while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))): 150 | buffer = buffer + torch.load(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))) 151 | n += 1 152 | if n == 0: 153 | raise AssertionError("No buffers detected at {}".format(expert_dir)) 154 | 155 | else: 156 | expert_files = [] 157 | n = 0 158 | while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))): 159 | expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))) 160 | n += 1 161 | if n == 0: 162 | raise AssertionError("No buffers detected at {}".format(expert_dir)) 163 | file_idx = 0 164 | expert_idx = 0 165 | # random.shuffle(expert_files) 166 | if args.max_files is not None: 167 | expert_files = expert_files[:args.max_files] 168 | 169 | expert_id = [i for i in range(len(expert_files))] 170 | random.shuffle(expert_id) 171 | 172 | print("loading file {}".format(expert_files[expert_id[file_idx]])) 173 | buffer = torch.load(expert_files[expert_id[file_idx]]) 174 | if args.max_experts is not None: 175 | buffer = buffer[:args.max_experts] 176 | buffer_id = [i for i in range(len(buffer))] 177 | random.shuffle(buffer_id) 178 | 179 | if args.pix_init == 'real': 180 | print('initialize synthetic data from random real images') 181 | for c in range(num_classes): 182 | image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data 183 | 184 | 185 | elif args.pix_init == 'samples_predicted_correctly': 186 | if args.parall_eva==False: 187 | device = 
torch.device("cuda:0") 188 | else: 189 | device = args.device 190 | if cfg.Initialize_Label_With_Another_Model: 191 | Temp_net = get_network(args.Initialize_Label_Model, channel, num_classes, im_size, dist=False).to(device) # get a random model 192 | else: 193 | Temp_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(device) # get a random model 194 | Temp_net.eval() 195 | Temp_net = ReparamModule(Temp_net) 196 | if args.distributed and args.parall_eva==True: 197 | Temp_net = torch.nn.DataParallel(Temp_net) 198 | Temp_net.eval() 199 | logits=[] 200 | label_expert_files = expert_files 201 | temp_params = torch.load(label_expert_files[0])[0][args.Label_Model_Timestamp] 202 | temp_params = torch.cat([p.data.to(device).reshape(-1) for p in temp_params], 0) 203 | if args.distributed and args.parall_eva==True: 204 | temp_params = temp_params.unsqueeze(0).expand(torch.cuda.device_count(), -1) 205 | for c in range(num_classes): 206 | data_for_class_c = get_images(c, len(indices_class[c])).detach().data 207 | n, _, w, h = data_for_class_c.shape 208 | selected_num = 0 209 | select_times = 0 210 | cur=0 211 | temp_img = None 212 | Wrong_Predicted_Img = None 213 | batch_size = 256 214 | index = [] 215 | while len(index) len(data_for_class_c): 219 | select_times = 0 220 | cur+=1 221 | temp_params = torch.load(label_expert_files[int(cur/10)%10])[cur%10][args.Label_Model_Timestamp] 222 | temp_params = torch.cat([p.data.to(device).reshape(-1) for p in temp_params], 0).to(device) 223 | if args.distributed and args.parall_eva==True: 224 | temp_params = temp_params.unsqueeze(0).expand(torch.cuda.device_count(), -1) 225 | continue 226 | logits = Temp_net(current_data_batch, flat_param=temp_params).detach() 227 | prediction_class = np.argmax(logits.cpu().data.numpy(), axis=-1) 228 | for i in range(len(prediction_class)): 229 | if prediction_class[i]==c and len(index) best_acc[model_eval]: 380 | best_acc[model_eval] = acc_test_mean 381 | best_std[model_eval] = acc_test_std 382 | save_this_it = True 383 | print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), model_eval, acc_test_mean, acc_test_std)) 384 | wandb.log({'Accuracy/{}'.format(model_eval): acc_test_mean}, step=it) 385 | wandb.log({'Max_Accuracy/{}'.format(model_eval): best_acc[model_eval]}, step=it) 386 | wandb.log({'Std/{}'.format(model_eval): acc_test_std}, step=it) 387 | wandb.log({'Max_Std/{}'.format(model_eval): best_std[model_eval]}, step=it) 388 | 389 | if it in eval_it_pool and (save_this_it or it % 1000 == 0): 390 | with torch.no_grad(): 391 | image_save = image_syn.cuda() 392 | save_dir = os.path.join(".", "logged_files", args.dataset, str(args.ipc), args.model, wandb.run.name) 393 | 394 | if not os.path.exists(save_dir): 395 | os.makedirs(os.path.join(save_dir,'Normal')) 396 | 397 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal',"images_{}.pt".format(it))) 398 | torch.save(label_syn.cpu(), os.path.join(save_dir, 'Normal', "labels_{}.pt".format(it))) 399 | torch.save(syn_lr.detach().cpu(), os.path.join(save_dir, 'Normal', "lr_{}.pt".format(it))) 400 | 401 | if save_this_it: 402 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal', "images_best.pt".format(it))) 403 | torch.save(label_syn.cpu(), os.path.join(save_dir, 'Normal', "labels_best.pt".format(it))) 404 | torch.save(syn_lr.detach().cpu(), os.path.join(save_dir, 'Normal', "lr_best.pt".format(it))) 405 | 406 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, 
step=it) 407 | 408 | if args.ipc < 50 or args.force_save: 409 | upsampled = image_save 410 | if args.dataset != "ImageNet": 411 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 412 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 413 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 414 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 415 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 416 | 417 | for clip_val in [2.5]: 418 | std = torch.std(image_save) 419 | mean = torch.mean(image_save) 420 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std) 421 | if args.dataset != "ImageNet": 422 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 423 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 424 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 425 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 426 | 427 | if args.zca: 428 | image_save = image_save.to(args.device) 429 | image_save = args.zca_trans.inverse_transform(image_save) 430 | image_save.cpu() 431 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal', "images_zca_{}.pt".format(it))) 432 | upsampled = image_save 433 | if args.dataset != "ImageNet": 434 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 435 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 436 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 437 | wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 438 | wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 439 | for clip_val in [2.5]: 440 | std = torch.std(image_save) 441 | mean = torch.mean(image_save) 442 | upsampled = torch.clip(image_save, min=mean - clip_val * std, max=mean + clip_val * std) 443 | if args.dataset != "ImageNet": 444 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 445 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 446 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 447 | wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image( 448 | torch.nan_to_num(grid.detach().cpu()))}, step=it) 449 | 450 | 451 | 452 | wandb.log({"Synthetic_LR": syn_lr.detach().cpu()}, step=it) 453 | 454 | student_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(args.device) # get a random model 455 | 456 | student_net = ReparamModule(student_net) 457 | 458 | if args.distributed: 459 | student_net = torch.nn.DataParallel(student_net) 460 | 461 | student_net.train() 462 | 463 | num_params = sum([np.prod(p.size()) for p in (student_net.parameters())]) 464 | 465 | if args.load_all: 466 | expert_trajectory = buffer[np.random.randint(0, len(buffer))] 467 | else: 468 | expert_trajectory = buffer[buffer_id[expert_idx]] 469 | expert_idx += 1 470 | if expert_idx == len(buffer): 471 | expert_idx = 0 472 | file_idx += 1 473 | if file_idx == len(expert_files): 474 | file_idx = 0 475 | random.shuffle(expert_id) 476 | print("loading file {}".format(expert_files[expert_id[file_idx]])) 477 | if args.max_files != 1: 478 | del buffer 479 | buffer = 
torch.load(expert_files[expert_id[file_idx]]) 480 | if args.max_experts is not None: 481 | buffer = buffer[:args.max_experts] 482 | random.shuffle(buffer_id) 483 | 484 | # Only match easy traj. in the early stage 485 | if args.Sequential_Generation: 486 | Upper_Bound = args.current_max_start_epoch + int((args.max_start_epoch-args.current_max_start_epoch) * it/(args.expansion_end_epoch)) 487 | Upper_Bound = min(Upper_Bound, args.max_start_epoch) 488 | else: 489 | Upper_Bound = args.max_start_epoch 490 | 491 | start_epoch = np.random.randint(args.min_start_epoch, Upper_Bound) 492 | 493 | starting_params = expert_trajectory[start_epoch] 494 | target_params = expert_trajectory[start_epoch+args.expert_epochs] 495 | target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in target_params], 0) 496 | student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0).requires_grad_(True)] 497 | starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0) 498 | 499 | syn_images = image_syn 500 | y_hat = label_syn 501 | 502 | param_loss_list = [] 503 | param_dist_list = [] 504 | indices_chunks = [] 505 | 506 | 507 | 508 | 509 | for step in range(args.syn_steps): 510 | if not indices_chunks: 511 | indices = torch.randperm(len(syn_images)) 512 | indices_chunks = list(torch.split(indices, args.batch_syn)) 513 | 514 | these_indices = indices_chunks.pop() 515 | 516 | x = syn_images[these_indices] 517 | this_y = y_hat[these_indices] 518 | 519 | 520 | if args.dsa and (not args.no_aug): 521 | x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param) 522 | 523 | if args.distributed: 524 | forward_params = student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 525 | else: 526 | forward_params = student_params[-1] 527 | x = student_net(x, flat_param=forward_params) 528 | ce_loss = criterion(x, this_y) 529 | 530 | grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0] 531 | 532 | student_params.append(student_params[-1] - syn_lr * grad) 533 | 534 | param_loss = torch.tensor(0.0).to(args.device) 535 | param_dist = torch.tensor(0.0).to(args.device) 536 | 537 | param_loss += torch.nn.functional.mse_loss(student_params[-1], target_params, reduction="sum") 538 | param_dist += torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum") 539 | 540 | param_loss_list.append(param_loss) 541 | param_dist_list.append(param_dist) 542 | 543 | param_loss /= num_params 544 | param_dist /= num_params 545 | 546 | param_loss /= param_dist 547 | 548 | grand_loss = param_loss 549 | 550 | optimizer_img.zero_grad() 551 | optimizer_lr.zero_grad() 552 | optimizer_y.zero_grad() 553 | 554 | grand_loss.backward() 555 | 556 | if grand_loss<=args.threshold: 557 | optimizer_y.step() 558 | optimizer_img.step() 559 | optimizer_lr.step() 560 | else: 561 | wandb.log({"falts": start_epoch}, step=it) 562 | 563 | 564 | 565 | wandb.log({"Grand_Loss": param_loss.detach().cpu(), 566 | "Start_Epoch": start_epoch}) 567 | 568 | for _ in student_params: 569 | del _ 570 | 571 | if it%10 == 0: 572 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item())) 573 | 574 | 575 | wandb.finish() 576 | 577 | 578 | if __name__ == '__main__': 579 | parser = argparse.ArgumentParser(description='Parameter Processing') 580 | 581 | parser.add_argument("--cfg", type=str, default="") 582 | args = parser.parse_args() 583 | 584 | cfg.merge_from_file(args.cfg) 585 | for key, value in cfg.items(): 586 | arg_name = '--' + key 587 | 
parser.add_argument(arg_name, type=type(value), default=value) 588 | args = parser.parse_args() 589 | main(args) 590 | 591 | 592 | 593 | -------------------------------------------------------------------------------- /distill/DATM_tesla.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | from tqdm import tqdm 11 | from utils.utils_baseline import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug 12 | import wandb 13 | import copy 14 | import random 15 | from reparam_module import ReparamModule 16 | # from kmeans_pytorch import kmeans 17 | from utils.cfg import CFG as cfg 18 | import warnings 19 | import yaml 20 | 21 | warnings.filterwarnings("ignore", category=DeprecationWarning) 22 | 23 | def manual_seed(seed=0): 24 | random.seed(seed) 25 | os.environ['PYTHONHASHSEED'] = str(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | def main(args): 32 | 33 | manual_seed() 34 | 35 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(x) for x in args.device]) 36 | 37 | if args.max_experts is not None and args.max_files is not None: 38 | args.total_experts = args.max_experts * args.max_files 39 | 40 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 41 | 42 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | 44 | if args.skip_first_eva==False: 45 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 46 | else: 47 | eval_it_pool = np.arange(args.eval_it, args.Iteration + 1, args.eval_it).tolist() 48 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 49 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 50 | 51 | im_res = im_size[0] 52 | 53 | args.im_size = im_size 54 | 55 | accs_all_exps = dict() # record performances of all experiments 56 | for key in model_eval_pool: 57 | accs_all_exps[key] = [] 58 | 59 | data_save = [] 60 | 61 | if args.dsa: 62 | # args.epoch_eval_train = 1000 63 | args.dc_aug_param = None 64 | 65 | args.dsa_param = ParamDiffAug() 66 | 67 | dsa_params = args.dsa_param 68 | if args.zca: 69 | zca_trans = args.zca_trans 70 | else: 71 | zca_trans = None 72 | 73 | wandb.init(sync_tensorboard=False, 74 | project=args.project, 75 | job_type="CleanRepo", 76 | config=args, 77 | ) 78 | 79 | args = type('', (), {})() 80 | 81 | for key in wandb.config._items: 82 | setattr(args, key, wandb.config._items[key]) 83 | 84 | args.dsa_param = dsa_params 85 | args.zca_trans = zca_trans 86 | 87 | if args.batch_syn is None: 88 | args.batch_syn = num_classes * args.ipc 89 | 90 | args.distributed = torch.cuda.device_count() > 1 91 | 92 | 93 | print('Hyper-parameters: \n', args.__dict__) 94 | print('Evaluation model pool: ', model_eval_pool) 95 | 96 | ''' organize the real dataset ''' 97 | images_all = [] 98 | labels_all = [] 99 | indices_class = [[] for c in range(num_classes)] 100 | print("BUILDING DATASET") 101 | if args.dataset == 'ImageNet1K' and os.path.exists('images_all.pt') and os.path.exists('labels_all.pt'): 102 | images_all = torch.load('images_all.pt') 103 | labels_all = 
torch.load('labels_all.pt') 104 | else: 105 | for i in tqdm(range(len(dst_train))): 106 | sample = dst_train[i] 107 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 108 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 109 | images_all = torch.cat(images_all, dim=0).to("cpu") 110 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 111 | if args.dataset == 'ImageNet1K': 112 | torch.save(images_all, 'images_all.pt') 113 | torch.save(labels_all, 'labels_all.pt') 114 | 115 | for i, lab in tqdm(enumerate(labels_all)): 116 | indices_class[lab].append(i) 117 | 118 | 119 | 120 | for c in range(num_classes): 121 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 122 | 123 | for ch in range(channel): 124 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 125 | 126 | 127 | def get_images(c, n): # get random n images from class c 128 | idx_shuffle = np.random.permutation(indices_class[c])[:n] 129 | return images_all[idx_shuffle] 130 | 131 | 132 | ''' initialize the synthetic data ''' 133 | label_syn = torch.tensor([ [i] * args.ipc for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 134 | 135 | 136 | image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float) 137 | 138 | syn_lr = torch.tensor(args.lr_teacher).to(args.device) 139 | expert_dir = os.path.join(args.buffer_path, args.dataset) 140 | if args.dataset == "ImageNet": 141 | expert_dir = os.path.join(expert_dir, args.subset, str(args.res)) 142 | if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca: 143 | expert_dir += "_NO_ZCA" 144 | expert_dir = os.path.join(expert_dir, args.model) 145 | print("Expert Dir: {}".format(expert_dir)) 146 | if args.load_all: 147 | expert_files = [] 148 | buffer = [] 149 | n = 0 150 | while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))): 151 | expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))) 152 | buffer = buffer + torch.load(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))) 153 | n += 1 154 | if n == 0: 155 | raise AssertionError("No buffers detected at {}".format(expert_dir)) 156 | 157 | else: 158 | expert_files = [] 159 | n = 0 160 | while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))): 161 | expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))) 162 | n += 1 163 | if n == 0: 164 | raise AssertionError("No buffers detected at {}".format(expert_dir)) 165 | file_idx = 0 166 | expert_idx = 0 167 | # random.shuffle(expert_files) 168 | if args.max_files is not None: 169 | expert_files = expert_files[:args.max_files] 170 | 171 | expert_id = [i for i in range(len(expert_files))] 172 | random.shuffle(expert_id) 173 | 174 | print("loading file {}".format(expert_files[expert_id[file_idx]])) 175 | buffer = torch.load(expert_files[expert_id[file_idx]]) 176 | if args.max_experts is not None: 177 | buffer = buffer[:args.max_experts] 178 | buffer_id = [i for i in range(len(buffer))] 179 | random.shuffle(buffer_id) 180 | 181 | if args.pix_init == 'real': 182 | print('initialize synthetic data from random real images') 183 | for c in range(num_classes): 184 | image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data 185 | 186 | 187 | elif args.pix_init == 'samples_predicted_correctly': 188 | if args.parall_eva==False: 189 | device = 
torch.device("cuda:0") 190 | else: 191 | device = args.device 192 | if cfg.Initialize_Label_With_Another_Model: 193 | Temp_net = get_network(args.Initialize_Label_Model, channel, num_classes, im_size, dist=False).to(device) # get a random model 194 | else: 195 | Temp_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(device) # get a random model 196 | Temp_net.eval() 197 | Temp_net = ReparamModule(Temp_net) 198 | if args.distributed and args.parall_eva==True: 199 | Temp_net = torch.nn.DataParallel(Temp_net) 200 | Temp_net.eval() 201 | logits=[] 202 | label_expert_files = expert_files 203 | temp_params = torch.load(label_expert_files[0])[0][args.Label_Model_Timestamp] 204 | temp_params = torch.cat([p.data.to(device).reshape(-1) for p in temp_params], 0) 205 | if args.distributed and args.parall_eva==True: 206 | temp_params = temp_params.unsqueeze(0).expand(torch.cuda.device_count(), -1) 207 | for c in range(num_classes): 208 | data_for_class_c = get_images(c, len(indices_class[c])).detach().data 209 | n, _, w, h = data_for_class_c.shape 210 | selected_num = 0 211 | select_times = 0 212 | cur=0 213 | temp_img = None 214 | Wrong_Predicted_Img = None 215 | batch_size = 256 216 | index = [] 217 | while len(index) len(data_for_class_c): 221 | select_times = 0 222 | cur+=1 223 | temp_params = torch.load(label_expert_files[int(cur/10)%10])[cur%10][args.Label_Model_Timestamp] 224 | temp_params = torch.cat([p.data.to(device).reshape(-1) for p in temp_params], 0).to(device) 225 | if args.distributed and args.parall_eva==True: 226 | temp_params = temp_params.unsqueeze(0).expand(torch.cuda.device_count(), -1) 227 | continue 228 | logits = Temp_net(current_data_batch, flat_param=temp_params).detach() 229 | prediction_class = np.argmax(logits.cpu().data.numpy(), axis=-1) 230 | for i in range(len(prediction_class)): 231 | if prediction_class[i]==c and len(index) best_acc[model_eval]: 382 | best_acc[model_eval] = acc_test_mean 383 | best_std[model_eval] = acc_test_std 384 | save_this_it = True 385 | print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), model_eval, acc_test_mean, acc_test_std)) 386 | wandb.log({'Accuracy/{}'.format(model_eval): acc_test_mean}, step=it) 387 | wandb.log({'Max_Accuracy/{}'.format(model_eval): best_acc[model_eval]}, step=it) 388 | wandb.log({'Std/{}'.format(model_eval): acc_test_std}, step=it) 389 | wandb.log({'Max_Std/{}'.format(model_eval): best_std[model_eval]}, step=it) 390 | 391 | if it in eval_it_pool and (save_this_it or it % 1000 == 0): 392 | with torch.no_grad(): 393 | image_save = image_syn.cuda() 394 | save_dir = os.path.join(".", "logged_files", args.dataset, str(args.ipc), args.model, wandb.run.name) 395 | 396 | if not os.path.exists(save_dir): 397 | os.makedirs(os.path.join(save_dir,'Normal')) 398 | 399 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal',"images_{}.pt".format(it))) 400 | torch.save(label_syn.cpu(), os.path.join(save_dir, 'Normal', "labels_{}.pt".format(it))) 401 | torch.save(syn_lr.detach().cpu(), os.path.join(save_dir, 'Normal', "lr_{}.pt".format(it))) 402 | 403 | if save_this_it: 404 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal', "images_best.pt".format(it))) 405 | torch.save(label_syn.cpu(), os.path.join(save_dir, 'Normal', "labels_best.pt".format(it))) 406 | torch.save(syn_lr.detach().cpu(), os.path.join(save_dir, 'Normal', "lr_best.pt".format(it))) 407 | 408 | wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, 
step=it) 409 | 410 | if args.ipc < 50 or args.force_save: 411 | upsampled = image_save 412 | if args.dataset != "ImageNet": 413 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 414 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 415 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 416 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 417 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 418 | 419 | for clip_val in [2.5]: 420 | std = torch.std(image_save) 421 | mean = torch.mean(image_save) 422 | upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std) 423 | if args.dataset != "ImageNet": 424 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 425 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 426 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 427 | wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 428 | 429 | if args.zca: 430 | image_save = image_save.to(args.device) 431 | image_save = args.zca_trans.inverse_transform(image_save) 432 | image_save.cpu() 433 | torch.save(image_save.cpu(), os.path.join(save_dir, 'Normal', "images_zca_{}.pt".format(it))) 434 | upsampled = image_save 435 | if args.dataset != "ImageNet": 436 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 437 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 438 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 439 | wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it) 440 | wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it) 441 | for clip_val in [2.5]: 442 | std = torch.std(image_save) 443 | mean = torch.mean(image_save) 444 | upsampled = torch.clip(image_save, min=mean - clip_val * std, max=mean + clip_val * std) 445 | if args.dataset != "ImageNet": 446 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2) 447 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3) 448 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True) 449 | wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image( 450 | torch.nan_to_num(grid.detach().cpu()))}, step=it) 451 | 452 | 453 | 454 | wandb.log({"Synthetic_LR": syn_lr.detach().cpu()}, step=it) 455 | 456 | student_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(args.device) # get a random model 457 | 458 | student_net = ReparamModule(student_net) 459 | 460 | if args.distributed: 461 | student_net = torch.nn.DataParallel(student_net) 462 | 463 | student_net.train() 464 | 465 | num_params = sum([np.prod(p.size()) for p in (student_net.parameters())]) 466 | 467 | if args.load_all: 468 | expert_trajectory = buffer[np.random.randint(0, len(buffer))] 469 | else: 470 | expert_trajectory = buffer[buffer_id[expert_idx]] 471 | expert_idx += 1 472 | if expert_idx == len(buffer): 473 | expert_idx = 0 474 | file_idx += 1 475 | if file_idx == len(expert_files): 476 | file_idx = 0 477 | random.shuffle(expert_id) 478 | print("loading file {}".format(expert_files[expert_id[file_idx]])) 479 | if args.max_files != 1: 480 | del buffer 481 | buffer = 
torch.load(expert_files[expert_id[file_idx]]) 482 | if args.max_experts is not None: 483 | buffer = buffer[:args.max_experts] 484 | random.shuffle(buffer_id) 485 | 486 | # Only match easy traj. in the early stage 487 | if args.Sequential_Generation: 488 | Upper_Bound = args.current_max_start_epoch + int((args.max_start_epoch-args.current_max_start_epoch) * it/(args.expansion_end_epoch)) 489 | Upper_Bound = min(Upper_Bound, args.max_start_epoch) 490 | else: 491 | Upper_Bound = args.max_start_epoch 492 | 493 | start_epoch = np.random.randint(args.min_start_epoch, Upper_Bound) 494 | 495 | starting_params = expert_trajectory[start_epoch] 496 | target_params = expert_trajectory[start_epoch+args.expert_epochs] 497 | target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in target_params], 0) 498 | student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0).requires_grad_(True)] 499 | starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0) 500 | param_dist = torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum") 501 | syn_images = image_syn 502 | y_hat = label_syn 503 | 504 | syn_image_gradients = torch.zeros(image_syn.shape).to(args.device) 505 | syn_label_gradients = torch.zeros(label_syn.shape).to(args.device) 506 | x_list = [] 507 | original_x_list = [] 508 | y_list = [] 509 | original_y_list = [] 510 | indices_chunks = [] 511 | gradient_sum = torch.zeros(student_params[-1].shape).to(args.device) 512 | indices_chunks_copy = [] 513 | 514 | 515 | 516 | 517 | for step in range(args.syn_steps): 518 | if not indices_chunks: 519 | indices = torch.randperm(len(syn_images)) 520 | indices_chunks = list(torch.split(indices, args.batch_syn)) 521 | 522 | these_indices = indices_chunks.pop() 523 | indices_chunks_copy.append(these_indices) 524 | 525 | x = syn_images[these_indices] 526 | this_y = y_hat[these_indices] 527 | original_x_list.append(x) 528 | original_y_list.append(this_y) 529 | if args.dsa and (not args.no_aug): 530 | x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param) 531 | x_list.append(x.clone()) 532 | y_list.append(this_y.clone()) 533 | 534 | if args.distributed: 535 | forward_params = student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 536 | else: 537 | forward_params = student_params[-1] 538 | x = student_net(x, flat_param=forward_params) 539 | ce_loss = criterion(x, this_y) 540 | 541 | grad = torch.autograd.grad(ce_loss, forward_params, create_graph=True, retain_graph=True)[0] 542 | 543 | detached_grad = grad.detach().clone() 544 | student_params.append(student_params[-1] - syn_lr.item() * detached_grad) 545 | gradient_sum += detached_grad 546 | 547 | del grad 548 | 549 | # --------Compute the gradients regarding input image and learning rate--------- 550 | # compute gradients invoving 2 gradients 551 | for i in range(args.syn_steps): 552 | # compute gradients for w_i 553 | w_i = student_params[i] 554 | output_i = student_net(x_list[i], flat_param = w_i) 555 | if args.batch_syn: 556 | ce_loss_i = criterion(output_i, y_list[i]) 557 | else: 558 | ce_loss_i = criterion(output_i, y_hat) 559 | 560 | grad_i = torch.autograd.grad(ce_loss_i, w_i, create_graph=True, retain_graph=True)[0] 561 | single_term = syn_lr.item() * (target_params - starting_params) 562 | square_term = (syn_lr.item() ** 2) * gradient_sum 563 | 564 | total_term = 2 * (single_term + square_term) @ grad_i / param_dist 565 | 566 | gradients_x, gradients_y = torch.autograd.grad(total_term, 
[original_x_list[i], original_y_list[i]] ) 567 | with torch.no_grad(): 568 | syn_image_gradients[indices_chunks_copy[i]] += gradients_x 569 | syn_label_gradients[indices_chunks_copy[i]] += gradients_y 570 | # ---------end of computing input image gradients and learning rates-------------- 571 | 572 | image_syn.grad = syn_image_gradients 573 | label_syn.grad = syn_label_gradients 574 | 575 | grand_loss = starting_params - syn_lr * gradient_sum - target_params 576 | grand_loss = grand_loss.dot(grand_loss) / param_dist 577 | 578 | lr_grad, = torch.autograd.grad(grand_loss, syn_lr) 579 | syn_lr.grad = lr_grad 580 | 581 | if grand_loss<=args.threshold: 582 | optimizer_y.step() 583 | optimizer_img.step() 584 | optimizer_lr.step() 585 | else: 586 | wandb.log({"falts": start_epoch}, step=it) 587 | 588 | 589 | 590 | wandb.log({"Grand_Loss": grand_loss.detach().cpu(), 591 | "Start_Epoch": start_epoch}) 592 | 593 | for _ in student_params: 594 | del _ 595 | 596 | if it%10 == 0: 597 | print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item())) 598 | 599 | 600 | wandb.finish() 601 | 602 | 603 | if __name__ == '__main__': 604 | parser = argparse.ArgumentParser(description='Parameter Processing') 605 | 606 | parser.add_argument("--cfg", type=str, default="") 607 | args = parser.parse_args() 608 | 609 | cfg.merge_from_file(args.cfg) 610 | for key, value in cfg.items(): 611 | arg_name = '--' + key 612 | parser.add_argument(arg_name, type=type(value), default=value) 613 | args = parser.parse_args() 614 | main(args) 615 | 616 | 617 | 618 | -------------------------------------------------------------------------------- /distill/baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | from tqdm import tqdm 11 | from utils.utils_baseline import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug, epoch, evaluate_baseline, reduce_dataset 12 | import wandb 13 | import copy 14 | import random 15 | from reparam_module import ReparamModule 16 | from torchvision import datasets, transforms 17 | import warnings 18 | from torch.utils.data import DataLoader 19 | warnings.filterwarnings("ignore", category=DeprecationWarning) 20 | 21 | 22 | 23 | def main(args): 24 | 25 | 26 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 27 | 28 | args.dsa = True if args.dsa == 'True' else False 29 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | 31 | 32 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args,baseline=True) 33 | 34 | dst_train = reduce_dataset(dst_train, rate=args.ipc/500, class_num=num_classes, num_per_class = 500) 35 | train_loader = DataLoader( 36 | dst_train, batch_size=args.batch_train, shuffle=True, num_workers=2 37 | ) 38 | 39 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 40 | 41 | im_res = im_size[0] 42 | 43 | args.im_size = im_size 44 | 45 | accs_all_exps = dict() # record performances of all experiments 46 | for key in model_eval_pool: 47 | accs_all_exps[key] = [] 48 | 49 | data_save = [] 50 | 51 | if args.dsa: 52 | # args.epoch_eval_train = 1000 53 | args.dc_aug_param = None 54 | 55 | 
args.dsa_param = ParamDiffAug() 56 | 57 | dsa_params = args.dsa_param 58 | if args.zca: 59 | zca_trans = args.zca_trans 60 | else: 61 | zca_trans = None 62 | 63 | 64 | args.dsa_param = dsa_params 65 | args.zca_trans = zca_trans 66 | 67 | args.distributed = torch.cuda.device_count() > 1 68 | 69 | print('Hyper-parameters: \n', args.__dict__) 70 | print('Evaluation model pool: ', model_eval_pool) 71 | 72 | 73 | criterion = nn.CrossEntropyLoss().to(args.device) 74 | 75 | 76 | 77 | args.lr_net = torch.tensor(args.lr_teacher).to(args.device) 78 | 79 | 80 | 81 | for model_eval in model_eval_pool: 82 | print('Evaluating: '+model_eval) 83 | network = get_network(model_eval, channel, num_classes, im_size, dist=False).to(args.device) # get a random model 84 | _, acc_train, acc_test = evaluate_baseline(0, copy.deepcopy(network), train_loader, testloader, args, texture=False) 85 | 86 | 87 | 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser(description='Parameter Processing') 91 | 92 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 93 | 94 | parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet') 95 | 96 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 97 | 98 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 99 | 100 | parser.add_argument('--eval_mode', type=str, default='S', 101 | help='eval_mode, check utils.py for more info') 102 | 103 | parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on') 104 | 105 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate') 106 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') 107 | parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform') 108 | parser.add_argument('--lr_init', type=float, default=0.01, help='how to init lr (alpha)') 109 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 110 | parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM') 111 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 112 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 113 | help='whether to use differentiable Siamese augmentation.') 114 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 115 | help='differentiable Siamese augmentation strategy') 116 | parser.add_argument('--data_path', type=str, default='data', help='dataset path') 117 | parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path') 118 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 119 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate') 120 | parser.add_argument('--no_aug', type=bool, default=False, help='this turns off diff aug during distillation') 121 | parser.add_argument('--data_dir', type=str, default='path', help='dataset') 122 | parser.add_argument('--label_dir', type=str, default='path', help='dataset') 123 | parser.add_argument('--lr_dir', type=str, default='path', help='dataset') 124 | parser.add_argument('--parall_eva', type=bool, default=False, help='dataset') 125 | 
parser.add_argument('--ASL_model', type=str, default=None) 126 | parser.add_argument('--ASL_model_dir', type=str, default=None) 127 | parser.add_argument('--method', type=str, default='') 128 | 129 | args = parser.parse_args() 130 | main(args) 131 | 132 | 133 | -------------------------------------------------------------------------------- /distill/distill_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | from tqdm import tqdm 11 | from utils.utils_arch import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug 12 | import wandb 13 | import copy 14 | import random 15 | from reparam_module import ReparamModule 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore", category=DeprecationWarning) 19 | 20 | def main(args): 21 | 22 | if args.zca and args.texture: 23 | raise AssertionError("Cannot use zca and texture together") 24 | 25 | if args.texture and args.pix_init == "real": 26 | print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.") 27 | 28 | if args.max_experts is not None and args.max_files is not None: 29 | args.total_experts = args.max_experts * args.max_files 30 | 31 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 32 | 33 | args.dsa = True if args.dsa == 'True' else False 34 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 35 | 36 | eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist() 37 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 38 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 39 | 40 | im_res = im_size[0] 41 | 42 | args.im_size = im_size 43 | 44 | accs_all_exps = dict() # record performances of all experiments 45 | for key in model_eval_pool: 46 | accs_all_exps[key] = [] 47 | 48 | data_save = [] 49 | 50 | if args.dsa: 51 | # args.epoch_eval_train = 1000 52 | args.dc_aug_param = None 53 | 54 | args.dsa_param = ParamDiffAug() 55 | 56 | dsa_params = args.dsa_param 57 | if args.zca: 58 | zca_trans = args.zca_trans 59 | else: 60 | zca_trans = None 61 | 62 | wandb.init(sync_tensorboard=False, 63 | project="DatasetDistillation", 64 | job_type="CleanRepo", 65 | config=args, 66 | ) 67 | 68 | args = type('', (), {})() 69 | 70 | for key in wandb.config._items: 71 | setattr(args, key, wandb.config._items[key]) 72 | 73 | args.dsa_param = dsa_params 74 | args.zca_trans = zca_trans 75 | 76 | if args.batch_syn is None: 77 | args.batch_syn = num_classes * args.ipc 78 | 79 | args.distributed = torch.cuda.device_count() > 1 80 | 81 | 82 | print('Hyper-parameters: \n', args.__dict__) 83 | print('Evaluation model pool: ', model_eval_pool) 84 | 85 | ''' organize the real dataset ''' 86 | images_all = [] 87 | labels_all = [] 88 | indices_class = [[] for c in range(num_classes)] 89 | print("BUILDING DATASET") 90 | for i in tqdm(range(len(dst_train))): 91 | sample = dst_train[i] 92 | images_all.append(torch.unsqueeze(sample[0], dim=0)) 93 | labels_all.append(class_map[torch.tensor(sample[1]).item()]) 94 | 95 | for i, lab in tqdm(enumerate(labels_all)): 96 | indices_class[lab].append(i) 
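# indices_class[c] collects the dataset positions of every real image with label c; the
# get_images(c, n) helper defined below relies on it to draw a random class-conditional
# batch, e.g. images_all[np.random.permutation(indices_class[c])[:n]].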
97 | images_all = torch.cat(images_all, dim=0).to("cpu") 98 | labels_all = torch.tensor(labels_all, dtype=torch.long, device="cpu") 99 | 100 | for c in range(num_classes): 101 | print('class c = %d: %d real images'%(c, len(indices_class[c]))) 102 | 103 | for ch in range(channel): 104 | print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch]))) 105 | 106 | 107 | def get_images(c, n): # get random n images from class c 108 | idx_shuffle = np.random.permutation(indices_class[c])[:n] 109 | return images_all[idx_shuffle] 110 | 111 | 112 | ''' initialize the synthetic data ''' 113 | label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9] 114 | 115 | if args.texture: 116 | image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0]*args.canvas_size, im_size[1]*args.canvas_size), dtype=torch.float) 117 | else: 118 | image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float) 119 | 120 | syn_lr = torch.tensor(args.lr_teacher).to(args.device) 121 | 122 | if args.pix_init == 'real': 123 | print('initialize synthetic data from random real images') 124 | if args.texture: 125 | for c in range(num_classes): 126 | for i in range(args.canvas_size): 127 | for j in range(args.canvas_size): 128 | image_syn.data[c * args.ipc:(c + 1) * args.ipc, :, i * im_size[0]:(i + 1) * im_size[0], 129 | j * im_size[1]:(j + 1) * im_size[1]] = torch.cat( 130 | [get_images(c, 1).detach().data for s in range(args.ipc)]) 131 | for c in range(num_classes): 132 | image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data 133 | else: 134 | print('initialize synthetic data from random noise') 135 | 136 | 137 | ''' training ''' 138 | image_syn = image_syn.detach().to(args.device).requires_grad_(True) 139 | syn_lr = syn_lr.detach().to(args.device).requires_grad_(True) 140 | optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5) 141 | optimizer_lr = torch.optim.SGD([syn_lr], lr=args.lr_lr, momentum=0.5) 142 | optimizer_img.zero_grad() 143 | 144 | criterion = nn.CrossEntropyLoss().to(args.device) 145 | print('%s training begins'%get_time()) 146 | 147 | best_acc = {m: 0 for m in model_eval_pool} 148 | 149 | best_std = {m: 0 for m in model_eval_pool} 150 | 151 | for it in range(0, args.Iteration+1): 152 | save_this_it = False 153 | 154 | # writer.add_scalar('Progress', it, it) 155 | #wandb.log({"Progress": it}, step=it) 156 | ''' Evaluate synthetic data ''' 157 | if it in eval_it_pool: 158 | for model_eval in model_eval_pool: 159 | print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it)) 160 | if args.dsa: 161 | print('DSA augmentation strategy: \n', args.dsa_strategy) 162 | print('DSA augmentation parameters: \n', args.dsa_param.__dict__) 163 | else: 164 | print('DC augmentation parameters: \n', args.dc_aug_param) 165 | 166 | accs_test = [] 167 | accs_train = [] 168 | for it_eval in range(args.num_eval): 169 | net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model 170 | 171 | eval_labs = label_syn 172 | with torch.no_grad(): 173 | image_save = image_syn 174 | #image_syn_eval, label_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(eval_labs.detach()) # avoid any unaware modification 175 | 176 | image_syn_eval = 
torch.load(os.path.join(args.syn_image_path,'images_best_base.pt')) 177 | label_syn_eval = torch.load(os.path.join(args.syn_image_path,'labels_best_base.pt')) 178 | 179 | args.lr_net = 0.04331 ##sam base 180 | #args.lr_net = syn_lr.item() 181 | _, acc_train, acc_test,acc_test_list = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture) 182 | 183 | ################################################################ 184 | print('best',max(acc_test_list)) 185 | print('last',acc_test_list[-1]) 186 | from matplotlib import pyplot as plt 187 | plt.plot(range(0,1001),acc_test_list) 188 | plt.show() 189 | 190 | accs_test.append(acc_test) 191 | accs_train.append(acc_train) 192 | accs_test = np.array(accs_test) 193 | accs_train = np.array(accs_train) 194 | acc_test_mean = np.mean(accs_test) 195 | acc_test_std = np.std(accs_test) 196 | if acc_test_mean > best_acc[model_eval]: 197 | best_acc[model_eval] = acc_test_mean 198 | best_std[model_eval] = acc_test_std 199 | save_this_it = True 200 | print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), model_eval, acc_test_mean, acc_test_std)) 201 | #wandb.log({'Accuracy/{}'.format(model_eval): acc_test_mean}, step=it) 202 | #wandb.log({'Max_Accuracy/{}'.format(model_eval): best_acc[model_eval]}, step=it) 203 | #wandb.log({'Std/{}'.format(model_eval): acc_test_std}, step=it) 204 | #wandb.log({'Max_Std/{}'.format(model_eval): best_std[model_eval]}, step=it) 205 | 206 | 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser(description='Parameter Processing') 211 | 212 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 213 | 214 | parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet') 215 | 216 | parser.add_argument('--model', type=str, default='VGG11', help='model') 217 | 218 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class') 219 | 220 | parser.add_argument('--eval_mode', type=str, default='S', 221 | help='eval_mode, check utils.py for more info') 222 | 223 | parser.add_argument('--num_eval', type=int, default=1, help='how many networks to evaluate on') 224 | 225 | parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate') 226 | 227 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') 228 | parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform') 229 | 230 | parser.add_argument('--lr_img', type=float, default=100, help='learning rate for updating synthetic images') 231 | parser.add_argument('--lr_lr', type=float, default=1e-07, help='learning rate for updating... 
learning rate') 232 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate') 233 | 234 | parser.add_argument('--lr_init', type=float, default=0.01, help='how to init lr (alpha)') 235 | 236 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 237 | parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM') 238 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks') 239 | 240 | parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"], 241 | help='noise/real: initialize synthetic images from random noise or randomly sampled real images.') 242 | 243 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 244 | help='whether to use differentiable Siamese augmentation.') 245 | 246 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 247 | help='differentiable Siamese augmentation strategy') 248 | 249 | parser.add_argument('--data_path', type=str, default='../dataset/', help='dataset path') 250 | parser.add_argument('--buffer_path', type=str, default='../buffer_storage/sam_rho0.02/', help='buffer path') 251 | 252 | parser.add_argument('--expert_epochs', type=int, default=2, help='how many expert epochs the target params are') 253 | parser.add_argument('--syn_steps', type=int, default=50, help='how many steps to take on synthetic data') 254 | parser.add_argument('--max_start_epoch', type=int, default=2, help='max epoch we can start at') 255 | 256 | #parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 257 | parser.add_argument('--zca',default=True, help="do ZCA whitening") 258 | 259 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM") 260 | 261 | parser.add_argument('--no_aug', type=bool, default=False, help='this turns off diff aug during distillation') 262 | 263 | parser.add_argument('--texture', action='store_true', help="will distill textures instead") 264 | parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas') 265 | parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration') 266 | 267 | 268 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)') 269 | parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)') 270 | 271 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc') 272 | parser.add_argument('--syn_image_path',type=str,default='./logged_files/CIFAR10/cifar10_50ipc/') 273 | 274 | #args = parser.parse_args() 275 | args, unknown = parser.parse_known_args() 276 | 277 | main(args) 278 | 279 | 280 | -------------------------------------------------------------------------------- /distill/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | from tqdm import tqdm 11 | from utils.utils_baseline import get_dataset, get_network, get_eval_pool, evaluate_synset, ParamDiffAug 
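# Note: the distilled labels produced by DATM are soft (label_syn is optimized jointly with
# the images), so the evaluation below trains networks with the SoftCrossEntropy criterion
# defined in main() rather than hard-label nn.CrossEntropyLoss.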
12 | import copy 13 | from reparam_module import ReparamModule 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore", category=DeprecationWarning) 17 | 18 | 19 | 20 | def main(args): 21 | 22 | 23 | print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled)) 24 | 25 | args.dsa = True if args.dsa == 'True' else False 26 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | 29 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args) 30 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 31 | 32 | args.im_size = im_size 33 | data_save = [] 34 | 35 | if args.dsa: 36 | # args.epoch_eval_train = 1000 37 | args.dc_aug_param = None 38 | 39 | args.dsa_param = ParamDiffAug() 40 | 41 | dsa_params = args.dsa_param 42 | if args.zca: 43 | zca_trans = args.zca_trans 44 | else: 45 | zca_trans = None 46 | 47 | 48 | args.dsa_param = dsa_params 49 | args.zca_trans = zca_trans 50 | 51 | args.distributed = torch.cuda.device_count() > 1 52 | 53 | print('Hyper-parameters: \n', args.__dict__) 54 | print('Evaluation model pool: ', model_eval_pool) 55 | 56 | def SoftCrossEntropy(inputs, target, reduction='average'): 57 | input_log_likelihood = -F.log_softmax(inputs, dim=1) 58 | target_log_likelihood = F.softmax(target, dim=1) 59 | batch = inputs.shape[0] 60 | loss = torch.sum(torch.mul(input_log_likelihood, target_log_likelihood)) / batch 61 | return loss 62 | 63 | soft_cri = SoftCrossEntropy 64 | 65 | image_syn_eval = torch.load(args.data_dir) 66 | label_syn_eval = torch.load(args.label_dir) 67 | args.lr_net = torch.load(args.lr_dir) 68 | 69 | for model_eval in model_eval_pool: 70 | print('Evaluating: '+model_eval) 71 | network = get_network(model_eval, channel, num_classes, im_size, dist=False).to(args.device) # get a random model 72 | _, acc_train, acc_test = evaluate_synset(0, copy.deepcopy(network), image_syn_eval, label_syn_eval, testloader, args, texture=False, train_criterion=soft_cri) 73 | 74 | 75 | 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser(description='Parameter Processing') 79 | 80 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset') 81 | 82 | parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. 
This only does anything when --dataset=ImageNet') 83 | 84 | parser.add_argument('--model', type=str, default='ConvNet', help='model') 85 | 86 | parser.add_argument('--eval_mode', type=str, default='S', 87 | help='eval_mode, check utils.py for more info') 88 | 89 | parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') 90 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data') 91 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], 92 | help='whether to use differentiable Siamese augmentation.') 93 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', 94 | help='differentiable Siamese augmentation strategy') 95 | parser.add_argument('--data_path', type=str, default='data', help='dataset path') 96 | parser.add_argument('--zca', action='store_true', help="do ZCA whitening") 97 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate') 98 | parser.add_argument('--no_aug', type=bool, default=False, help='this turns off diff aug during distillation') 99 | parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks') 100 | 101 | parser.add_argument('--parall_eva', type=bool, default=False, help='dataset') 102 | 103 | parser.add_argument('--data_dir', type=str, default='path', help='dataset') 104 | parser.add_argument('--label_dir', type=str, default='path', help='dataset') 105 | parser.add_argument('--lr_dir', type=str, default='path', help='dataset') 106 | 107 | args = parser.parse_args() 108 | main(args) 109 | -------------------------------------------------------------------------------- /distill/model_ema.py: -------------------------------------------------------------------------------- 1 | """ Exponential Moving Average (EMA) of model updates 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | from copy import deepcopy 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | 15 | class ModelEma: 16 | """ Model Exponential Moving Average (DEPRECATED) 17 | 18 | Keep a moving average of everything in the model state_dict (parameters and buffers). 19 | This version is deprecated, it does not work with scripted models. Will be removed eventually. 20 | 21 | This is intended to allow functionality like 22 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 23 | 24 | A smoothed version of the weights is necessary for some training schemes to perform well. 25 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 26 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 27 | smoothing of weights to match results. Pay attention to the decay constant you are using 28 | relative to your update count per epoch. 29 | 30 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 31 | disable validation of the EMA weights. Validation will have to be done manually in a separate 32 | process, or after the training stops converging. 33 | 34 | This class is sensitive where it is initialized in the sequence of model init, 35 | GPU assignment and distributed training wrappers. 
36 | """ 37 | def __init__(self, model, decay=0.9999, device='', resume=''): 38 | # make a copy of the model for accumulating moving average of weights 39 | self.ema = deepcopy(model) 40 | self.ema.eval() 41 | self.decay = decay 42 | self.device = device # perform ema on different device from model if set 43 | if device: 44 | self.ema.to(device=device) 45 | self.ema_has_module = hasattr(self.ema, 'module') 46 | if resume: 47 | self._load_checkpoint(resume) 48 | for p in self.ema.parameters(): 49 | p.requires_grad_(False) 50 | 51 | def _load_checkpoint(self, checkpoint_path): 52 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 53 | assert isinstance(checkpoint, dict) 54 | if 'state_dict_ema' in checkpoint: 55 | new_state_dict = OrderedDict() 56 | for k, v in checkpoint['state_dict_ema'].items(): 57 | # ema model may have been wrapped by DataParallel, and need module prefix 58 | if self.ema_has_module: 59 | name = 'module.' + k if not k.startswith('module') else k 60 | else: 61 | name = k 62 | new_state_dict[name] = v 63 | self.ema.load_state_dict(new_state_dict) 64 | _logger.info("Loaded state_dict_ema") 65 | else: 66 | _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") 67 | 68 | def update(self, model): 69 | # correct a mismatch in state dict keys 70 | needs_module = hasattr(model, 'module') and not self.ema_has_module 71 | with torch.no_grad(): 72 | msd = model.state_dict() 73 | for k, ema_v in self.ema.state_dict().items(): 74 | if needs_module: 75 | k = 'module.' + k 76 | model_v = msd[k].detach() 77 | if self.device: 78 | model_v = model_v.to(device=self.device) 79 | ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) 80 | 81 | 82 | class ModelEmaV2(nn.Module): 83 | """ Model Exponential Moving Average V2 84 | 85 | Keep a moving average of everything in the model state_dict (parameters and buffers). 86 | V2 of this module is simpler, it does not match params/buffers based on name but simply 87 | iterates in order. It works with torchscript (JIT of full model). 88 | 89 | This is intended to allow functionality like 90 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 91 | 92 | A smoothed version of the weights is necessary for some training schemes to perform well. 93 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 94 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 95 | smoothing of weights to match results. Pay attention to the decay constant you are using 96 | relative to your update count per epoch. 97 | 98 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 99 | disable validation of the EMA weights. Validation will have to be done manually in a separate 100 | process, or after the training stops converging. 101 | 102 | This class is sensitive where it is initialized in the sequence of model init, 103 | GPU assignment and distributed training wrappers. 
104 | """ 105 | def __init__(self, params, decay=0.9999, device=None,updating_list=None): 106 | super(ModelEmaV2, self).__init__() 107 | # make a copy of the params for accumulating moving average of weights 108 | self.module = deepcopy(params) 109 | #self.module.eval() 110 | self.decay = decay 111 | self.updating_list = updating_list 112 | self.device = device # perform ema on different device from params if set 113 | if self.device is not None: 114 | self.module.to(device=device) 115 | 116 | def _update(self, params, update_fn): 117 | with torch.no_grad(): 118 | count = 0 119 | for ema_v, params_v in zip(self.module, params): 120 | if self.updating_list is not None and ema_v.dim() > 1: 121 | count += 1 122 | if not count in self.updating_list: 123 | decay = 0 124 | else: 125 | decay = self.decay 126 | # this is for the updating part EMA use, comment it for using set method 127 | # update_fn = lambda e, m: decay * e + (1. - decay) * m 128 | if self.device is not None: 129 | params_v = params_v.to(device=self.device) 130 | ema_v.copy_(update_fn(ema_v, params_v)) 131 | 132 | def update(self, params): 133 | self._update(params, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 134 | 135 | def set(self, params): 136 | self._update(params, update_fn=lambda e, m: m) 137 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: distillation 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - anyio=3.6.1=py39hf3d152e_0 10 | - argon2-cffi=21.3.0=pyhd8ed1ab_0 11 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 12 | - asttokens=2.0.5=pyhd8ed1ab_0 13 | - attrs=22.1.0=pyh71513ae_1 14 | - babel=2.10.3=pyhd8ed1ab_0 15 | - backcall=0.2.0=pyh9f0ad1d_0 16 | - backports=1.0=py_2 17 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 18 | - beautifulsoup4=4.11.1=pyha770c72_0 19 | - blas=1.0=mkl 20 | - bleach=5.0.1=pyhd8ed1ab_0 21 | - brotlipy=0.7.0=py39h27cfd23_1003 22 | - bzip2=1.0.8=h7b6447c_0 23 | - ca-certificates=2022.6.15=ha878542_0 24 | - certifi=2022.6.15=py39hf3d152e_0 25 | - cffi=1.15.0=py39hd667e15_1 26 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 27 | - click=7.1.2=pyh9f0ad1d_0 28 | - colorama=0.4.4=pyh9f0ad1d_0 29 | - configparser=5.2.0=pyhd8ed1ab_0 30 | - cryptography=36.0.0=py39h9ce1e76_0 31 | - cudatoolkit=11.3.1=h2bc3f7f_2 32 | - dataclasses=0.8=pyhc8e2a94_3 33 | - decorator=5.1.1=pyhd8ed1ab_0 34 | - defusedxml=0.7.1=pyhd8ed1ab_0 35 | - docker-pycreds=0.4.0=py_0 36 | - entrypoints=0.4=pyhd8ed1ab_0 37 | - executing=0.9.1=pyhd8ed1ab_0 38 | - ffmpeg=4.3=hf484d3e_0 39 | - flit-core=3.7.1=pyhd8ed1ab_0 40 | - freetype=2.11.0=h70c0345_0 41 | - giflib=5.2.1=h7b6447c_0 42 | - gitdb=4.0.9=pyhd8ed1ab_0 43 | - gitpython=3.1.27=pyhd8ed1ab_0 44 | - gmp=6.2.1=h2531618_2 45 | - gnutls=3.6.15=he1e5248_0 46 | - idna=3.3=pyhd3eb1b0_0 47 | - importlib-metadata=4.11.4=py39hf3d152e_0 48 | - importlib_metadata=4.11.4=hd8ed1ab_0 49 | - importlib_resources=5.9.0=pyhd8ed1ab_0 50 | - intel-openmp=2021.4.0=h06a4308_3561 51 | - ipykernel=5.5.5=py39hef51801_0 52 | - ipython=8.4.0=py39hf3d152e_0 53 | - ipython_genutils=0.2.0=py_1 54 | - jedi=0.18.1=py39hf3d152e_1 55 | - jinja2=3.1.2=pyhd8ed1ab_1 56 | - jpeg=9d=h7f8727e_0 57 | - json5=0.9.5=pyh9f0ad1d_0 58 | - jsonschema=4.9.0=pyhd8ed1ab_0 59 | - jupyter_client=7.0.6=pyhd8ed1ab_0 60 | - jupyter_core=4.11.1=py39hf3d152e_0 61 | - 
jupyter_server=1.18.1=pyhd8ed1ab_0 62 | - jupyterlab=3.4.4=pyhd8ed1ab_0 63 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 64 | - jupyterlab_server=2.15.0=pyhd8ed1ab_0 65 | - kornia=0.6.3=pyhd8ed1ab_0 66 | - lame=3.100=h7b6447c_0 67 | - lcms2=2.12=h3be6417_0 68 | - ld_impl_linux-64=2.35.1=h7274673_9 69 | - libffi=3.3=he6710b0_2 70 | - libgcc-ng=9.3.0=h5101ec6_17 71 | - libgfortran-ng=7.5.0=ha8ba4b0_17 72 | - libgfortran4=7.5.0=ha8ba4b0_17 73 | - libgomp=9.3.0=h5101ec6_17 74 | - libiconv=1.15=h63c8f33_5 75 | - libidn2=2.3.2=h7f8727e_0 76 | - libpng=1.6.37=hbc83047_0 77 | - libprotobuf=3.15.8=h780b84a_0 78 | - libsodium=1.0.18=h36c2ea0_1 79 | - libstdcxx-ng=9.3.0=hd4cf53a_17 80 | - libtasn1=4.16.0=h27cfd23_0 81 | - libtiff=4.2.0=h85742a9_0 82 | - libunistring=0.9.10=h27cfd23_0 83 | - libuv=1.40.0=h7b6447c_0 84 | - libwebp=1.2.2=h55f646e_0 85 | - libwebp-base=1.2.2=h7f8727e_0 86 | - lz4-c=1.9.3=h295c915_1 87 | - markupsafe=2.1.1=py39h7f8727e_0 88 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 89 | - mistune=0.8.4=py39h3811e60_1004 90 | - mkl=2021.4.0=h06a4308_640 91 | - mkl-service=2.4.0=py39h7f8727e_0 92 | - mkl_fft=1.3.1=py39hd3c417c_0 93 | - mkl_random=1.2.2=py39h51133e4_0 94 | - nbclassic=0.4.3=pyhd8ed1ab_0 95 | - nbclient=0.6.6=pyhd8ed1ab_0 96 | - nbconvert=6.5.0=pyhd8ed1ab_0 97 | - nbconvert-core=6.5.0=pyhd8ed1ab_0 98 | - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0 99 | - nbformat=5.4.0=pyhd8ed1ab_0 100 | - ncurses=6.3=h7f8727e_2 101 | - nest-asyncio=1.5.5=pyhd8ed1ab_0 102 | - nettle=3.7.3=hbbd107a_1 103 | - notebook=6.4.12=pyha770c72_0 104 | - notebook-shim=0.1.0=pyhd8ed1ab_0 105 | - numpy=1.21.2=py39h20f2e39_0 106 | - numpy-base=1.21.2=py39h79a1101_0 107 | - openh264=2.1.1=h4ff587b_0 108 | - openssl=1.1.1q=h7f8727e_0 109 | - packaging=21.3=pyhd8ed1ab_0 110 | - pandoc=2.18=ha770c72_0 111 | - pandocfilters=1.5.0=pyhd8ed1ab_0 112 | - parso=0.8.3=pyhd8ed1ab_0 113 | - pathtools=0.1.2=py_1 114 | - pexpect=4.8.0=pyh9f0ad1d_2 115 | - pickleshare=0.7.5=py_1003 116 | - pillow=9.0.1=py39h22f2fdc_0 117 | - pip=21.2.4=py39h06a4308_0 118 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 119 | - prometheus_client=0.14.1=pyhd8ed1ab_0 120 | - promise=2.3=py39hf3d152e_5 121 | - prompt-toolkit=3.0.30=pyha770c72_0 122 | - protobuf=3.15.8=py39he80948d_0 123 | - psutil=5.8.0=py39h27cfd23_1 124 | - ptyprocess=0.7.0=pyhd3deb0d_0 125 | - pure_eval=0.2.2=pyhd8ed1ab_0 126 | - pycparser=2.21=pyhd3eb1b0_0 127 | - pygments=2.12.0=pyhd8ed1ab_0 128 | - pyopenssl=22.0.0=pyhd3eb1b0_0 129 | - pyparsing=3.0.7=pyhd8ed1ab_0 130 | - pyrsistent=0.18.0=py39heee7806_0 131 | - pysocks=1.7.1=py39h06a4308_0 132 | - python=3.9.7=h12debd9_1 133 | - python-dateutil=2.8.2=pyhd8ed1ab_0 134 | - python-fastjsonschema=2.16.1=pyhd8ed1ab_0 135 | - python_abi=3.9=2_cp39 136 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 137 | - pytorch-mutex=1.0=cuda 138 | - pytz=2022.1=pyhd8ed1ab_0 139 | - pyyaml=5.4.1=py39h3811e60_0 140 | - pyzmq=19.0.2=py39hb69f2a1_2 141 | - readline=8.1.2=h7f8727e_1 142 | - requests=2.27.1=pyhd3eb1b0_0 143 | - scipy=1.7.3=py39hc147768_0 144 | - send2trash=1.8.0=pyhd8ed1ab_0 145 | - sentry-sdk=1.5.7=pyhd8ed1ab_0 146 | - setproctitle=1.2.2=py39h3811e60_0 147 | - setuptools=58.0.4=py39h06a4308_0 148 | - shortuuid=1.0.8=py39hf3d152e_0 149 | - six=1.16.0=pyhd3eb1b0_1 150 | - smmap=3.0.5=pyh44b312d_0 151 | - sniffio=1.2.0=py39hf3d152e_3 152 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0 153 | - sqlite=3.38.0=hc218d9a_0 154 | - stack_data=0.3.0=pyhd8ed1ab_0 155 | - termcolor=1.1.0=py_2 156 | - terminado=0.15.0=py39hf3d152e_0 157 | - 
tinycss2=1.1.1=pyhd8ed1ab_0 158 | - tk=8.6.11=h1ccaba5_0 159 | - torchaudio=0.11.0=py39_cu113 160 | - torchvision=0.12.0=py39_cu113 161 | - tornado=6.1=py39h3811e60_1 162 | - tqdm=4.63.0=pyhd8ed1ab_0 163 | - traitlets=5.3.0=pyhd8ed1ab_0 164 | - typing_extensions=3.10.0.2=pyh06a4308_0 165 | - tzdata=2021e=hda174b7_0 166 | - urllib3=1.26.8=pyhd3eb1b0_0 167 | - wandb=0.12.11=pyhd8ed1ab_0 168 | - wcwidth=0.2.5=pyh9f0ad1d_2 169 | - webencodings=0.5.1=py_1 170 | - websocket-client=1.3.3=pyhd8ed1ab_0 171 | - wheel=0.37.1=pyhd3eb1b0_0 172 | - xz=5.2.5=h7b6447c_0 173 | - yaml=0.2.5=h516909a_0 174 | - yaspin=2.1.0=pyhd8ed1ab_0 175 | - zeromq=4.3.4=h9c3ff4c_0 176 | - zipp=3.8.0=pyhd8ed1ab_0 177 | - zlib=1.2.11=h7f8727e_4 178 | - zstd=1.4.9=haebb681_0 179 | - pip: 180 | - cycler==0.11.0 181 | - efficientnet-pytorch==0.7.1 182 | - filelock==3.9.0 183 | - fonttools==4.34.4 184 | - huggingface-hub==0.11.1 185 | - kiwisolver==1.4.4 186 | - matplotlib==3.5.2 187 | - prettytable==3.6.0 188 | - timm==0.6.12 189 | -------------------------------------------------------------------------------- /figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DATM/97ae5109d98e749e897bd1af29bccd0515f8137e/figures/comparison.png -------------------------------------------------------------------------------- /figures/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DATM/97ae5109d98e749e897bd1af29bccd0515f8137e/figures/visualization.png -------------------------------------------------------------------------------- /figures/visualization_ipc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DATM/97ae5109d98e749e897bd1af29bccd0515f8137e/figures/visualization_ipc.png -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | # Acknowledgement to 5 | # https://github.com/kuangliu/pytorch-cifar, 6 | # https://github.com/BIGBALLON/CIFAR-ZOO, 7 | 8 | # adapted from 9 | # https://github.com/VICO-UoE/DatasetCondensation 10 | 11 | ''' MLP ''' 12 | class MLP(nn.Module): 13 | def __init__(self, channel, num_classes, res=32): 14 | super(MLP, self).__init__() 15 | self.fc_1 = nn.Linear(28*28*1 if channel==1 else res*res*3, 128) 16 | self.fc_2 = nn.Linear(128, 128) 17 | self.fc_3 = nn.Linear(128, num_classes) 18 | 19 | def forward(self, x): 20 | out = x.view(x.size(0), -1) 21 | out = F.relu(self.fc_1(out)) 22 | out = F.relu(self.fc_2(out)) 23 | out = self.fc_3(out) 24 | return out 25 | 26 | 27 | 28 | ''' ConvNet ''' 29 | class ConvNet(nn.Module): 30 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)): 31 | super(ConvNet, self).__init__() 32 | 33 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 34 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 35 | self.classifier = nn.Linear(num_feat, num_classes) 36 | 37 | def forward(self, x): 38 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 39 | out = self.features(x) 40 | out = out.view(out.size(0), -1) 41 | out = 
self.classifier(out) 42 | return out 43 | 44 | def _get_activation(self, net_act): 45 | if net_act == 'sigmoid': 46 | return nn.Sigmoid() 47 | elif net_act == 'relu': 48 | return nn.ReLU(inplace=True) 49 | elif net_act == 'leakyrelu': 50 | return nn.LeakyReLU(negative_slope=0.01) 51 | else: 52 | exit('unknown activation function: %s'%net_act) 53 | 54 | def _get_pooling(self, net_pooling): 55 | if net_pooling == 'maxpooling': 56 | return nn.MaxPool2d(kernel_size=2, stride=2) 57 | elif net_pooling == 'avgpooling': 58 | return nn.AvgPool2d(kernel_size=2, stride=2) 59 | elif net_pooling == 'none': 60 | return None 61 | else: 62 | exit('unknown net_pooling: %s'%net_pooling) 63 | 64 | def _get_normlayer(self, net_norm, shape_feat): 65 | # shape_feat = (c*h*w) 66 | if net_norm == 'batchnorm': 67 | return nn.BatchNorm2d(shape_feat[0], affine=True) 68 | elif net_norm == 'layernorm': 69 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 70 | elif net_norm == 'instancenorm': 71 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 72 | elif net_norm == 'groupnorm': 73 | return nn.GroupNorm(4, shape_feat[0], affine=True) 74 | elif net_norm == 'none': 75 | return None 76 | else: 77 | exit('unknown net_norm: %s'%net_norm) 78 | 79 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 80 | layers = [] 81 | in_channels = channel 82 | if im_size[0] == 28: 83 | im_size = (32, 32) 84 | shape_feat = [in_channels, im_size[0], im_size[1]] 85 | for d in range(net_depth): 86 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 87 | shape_feat[0] = net_width 88 | if net_norm != 'none': 89 | layers += [self._get_normlayer(net_norm, shape_feat)] 90 | layers += [self._get_activation(net_act)] 91 | in_channels = net_width 92 | if net_pooling != 'none': 93 | layers += [self._get_pooling(net_pooling)] 94 | shape_feat[1] //= 2 95 | shape_feat[2] //= 2 96 | 97 | 98 | return nn.Sequential(*layers), shape_feat 99 | 100 | 101 | ''' ConvNet ''' 102 | class ConvNetGAP(nn.Module): 103 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size = (32,32)): 104 | super(ConvNetGAP, self).__init__() 105 | 106 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 107 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 108 | # self.classifier = nn.Linear(num_feat, num_classes) 109 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 110 | self.classifier = nn.Linear(shape_feat[0], num_classes) 111 | 112 | def forward(self, x): 113 | out = self.features(x) 114 | out = self.avgpool(out) 115 | out = out.view(out.size(0), -1) 116 | out = self.classifier(out) 117 | return out 118 | 119 | def _get_activation(self, net_act): 120 | if net_act == 'sigmoid': 121 | return nn.Sigmoid() 122 | elif net_act == 'relu': 123 | return nn.ReLU(inplace=True) 124 | elif net_act == 'leakyrelu': 125 | return nn.LeakyReLU(negative_slope=0.01) 126 | else: 127 | exit('unknown activation function: %s'%net_act) 128 | 129 | def _get_pooling(self, net_pooling): 130 | if net_pooling == 'maxpooling': 131 | return nn.MaxPool2d(kernel_size=2, stride=2) 132 | elif net_pooling == 'avgpooling': 133 | return nn.AvgPool2d(kernel_size=2, stride=2) 134 | elif net_pooling == 'none': 135 | return None 136 | else: 137 | exit('unknown net_pooling: %s'%net_pooling) 138 | 139 | def _get_normlayer(self, net_norm, shape_feat): 140 | # shape_feat = (c*h*w) 141 | 
if net_norm == 'batchnorm': 142 | return nn.BatchNorm2d(shape_feat[0], affine=True) 143 | elif net_norm == 'layernorm': 144 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 145 | elif net_norm == 'instancenorm': 146 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 147 | elif net_norm == 'groupnorm': 148 | return nn.GroupNorm(4, shape_feat[0], affine=True) 149 | elif net_norm == 'none': 150 | return None 151 | else: 152 | exit('unknown net_norm: %s'%net_norm) 153 | 154 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 155 | layers = [] 156 | in_channels = channel 157 | if im_size[0] == 28: 158 | im_size = (32, 32) 159 | shape_feat = [in_channels, im_size[0], im_size[1]] 160 | for d in range(net_depth): 161 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 162 | shape_feat[0] = net_width 163 | if net_norm != 'none': 164 | layers += [self._get_normlayer(net_norm, shape_feat)] 165 | layers += [self._get_activation(net_act)] 166 | in_channels = net_width 167 | if net_pooling != 'none': 168 | layers += [self._get_pooling(net_pooling)] 169 | shape_feat[1] //= 2 170 | shape_feat[2] //= 2 171 | 172 | return nn.Sequential(*layers), shape_feat 173 | 174 | 175 | ''' LeNet ''' 176 | class LeNet(nn.Module): 177 | def __init__(self, channel, num_classes, res=32): 178 | super(LeNet, self).__init__() 179 | self.features = nn.Sequential( 180 | nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0, stride=1 if res==32 else 2), 181 | nn.ReLU(inplace=True), 182 | nn.MaxPool2d(kernel_size=2, stride=2), 183 | nn.Conv2d(6, 16, kernel_size=5), 184 | nn.ReLU(inplace=True), 185 | nn.MaxPool2d(kernel_size=2, stride=2), 186 | ) 187 | self.fc_1 = nn.Linear(16 * 5 * 5, 120) 188 | self.fc_2 = nn.Linear(120, 84) 189 | self.fc_3 = nn.Linear(84, num_classes) 190 | 191 | def forward(self, x): 192 | x = self.features(x) 193 | x = x.view(x.size(0), -1) 194 | x = F.relu(self.fc_1(x)) 195 | x = F.relu(self.fc_2(x)) 196 | x = self.fc_3(x) 197 | return x 198 | 199 | 200 | 201 | ''' AlexNet ''' 202 | class AlexNet(nn.Module): 203 | def __init__(self, channel, num_classes, res=32): 204 | super(AlexNet, self).__init__() 205 | self.features = nn.Sequential( 206 | nn.Conv2d(channel, 128, kernel_size=5, stride=1 if res==32 else 2, padding=4 if channel==1 else 2), 207 | nn.ReLU(inplace=True), 208 | nn.MaxPool2d(kernel_size=2, stride=2), 209 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 210 | nn.ReLU(inplace=True), 211 | nn.MaxPool2d(kernel_size=2, stride=2), 212 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 213 | nn.ReLU(inplace=True), 214 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 215 | nn.ReLU(inplace=True), 216 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 217 | nn.ReLU(inplace=True), 218 | nn.MaxPool2d(kernel_size=2, stride=2), 219 | ) 220 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 221 | 222 | def forward(self, x): 223 | x = self.features(x) 224 | x = x.view(x.size(0), -1) 225 | x = self.fc(x) 226 | return x 227 | 228 | 229 | 230 | ''' VGG ''' 231 | cfg_vgg = { 232 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 233 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 234 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 235 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 236 | } 237 | class 
VGG(nn.Module): 238 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm', res=32): 239 | super(VGG, self).__init__() 240 | self.channel = channel 241 | self.features = self._make_layers(cfg_vgg[vgg_name], norm, res) 242 | self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes) 243 | 244 | def forward(self, x): 245 | x = self.features(x) 246 | x = x.view(x.size(0), -1) 247 | x = self.classifier(x) 248 | return x 249 | 250 | def _make_layers(self, cfg, norm, res): 251 | layers = [] 252 | in_channels = self.channel 253 | for ic, x in enumerate(cfg): 254 | if x == 'M': 255 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 256 | else: 257 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 258 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 259 | nn.ReLU(inplace=True)] 260 | in_channels = x 261 | layers += [nn.AvgPool2d(kernel_size=1, stride=1 if res==32 else 2)] 262 | return nn.Sequential(*layers) 263 | 264 | 265 | def VGG11(channel, num_classes): 266 | return VGG('VGG11', channel, num_classes) 267 | def VGG11_Tiny(channel, num_classes): 268 | return VGG('VGG11', channel, num_classes,res=64) 269 | def VGG11BN(channel, num_classes): 270 | return VGG('VGG11', channel, num_classes, norm='batchnorm') 271 | def VGG13(channel, num_classes): 272 | return VGG('VGG13', channel, num_classes) 273 | def VGG16(channel, num_classes): 274 | return VGG('VGG16', channel, num_classes) 275 | def VGG19(channel, num_classes): 276 | return VGG('VGG19', channel, num_classes) 277 | 278 | 279 | ''' ResNet_AP ''' 280 | # The conv(stride=2) is replaced by conv(stride=1) + avgpool(kernel_size=2, stride=2) 281 | 282 | class BasicBlock_AP(nn.Module): 283 | expansion = 1 284 | 285 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 286 | super(BasicBlock_AP, self).__init__() 287 | self.norm = norm 288 | self.stride = stride 289 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 290 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 291 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 292 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 293 | 294 | self.shortcut = nn.Sequential() 295 | if stride != 1 or in_planes != self.expansion * planes: 296 | self.shortcut = nn.Sequential( 297 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 298 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 299 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 300 | ) 301 | 302 | def forward(self, x): 303 | out = F.relu(self.bn1(self.conv1(x))) 304 | if self.stride != 1: # modification 305 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 306 | out = self.bn2(self.conv2(out)) 307 | out += self.shortcut(x) 308 | out = F.relu(out) 309 | return out 310 | 311 | 312 | class Bottleneck_AP(nn.Module): 313 | expansion = 4 314 | 315 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 316 | super(Bottleneck_AP, self).__init__() 317 | self.norm = norm 318 | self.stride = stride 319 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 320 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 
'instancenorm' else nn.BatchNorm2d(planes) 321 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # modification 322 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 323 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 324 | self.bn3 = nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 325 | 326 | self.shortcut = nn.Sequential() 327 | if stride != 1 or in_planes != self.expansion * planes: 328 | self.shortcut = nn.Sequential( 329 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=1, bias=False), 330 | nn.AvgPool2d(kernel_size=2, stride=2), # modification 331 | nn.GroupNorm(self.expansion * planes, self.expansion * planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion * planes) 332 | ) 333 | 334 | def forward(self, x): 335 | out = F.relu(self.bn1(self.conv1(x))) 336 | out = F.relu(self.bn2(self.conv2(out))) 337 | if self.stride != 1: # modification 338 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 339 | out = self.bn3(self.conv3(out)) 340 | out += self.shortcut(x) 341 | out = F.relu(out) 342 | return out 343 | 344 | 345 | class ResNet_AP(nn.Module): 346 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 347 | super(ResNet_AP, self).__init__() 348 | self.in_planes = 64 349 | self.norm = norm 350 | 351 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 352 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 353 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 354 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 355 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 356 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 357 | self.classifier = nn.Linear(512 * block.expansion * 3 * 3 if channel==1 else 512 * block.expansion * 4 * 4, num_classes) # modification 358 | 359 | def _make_layer(self, block, planes, num_blocks, stride): 360 | strides = [stride] + [1] * (num_blocks - 1) 361 | layers = [] 362 | for stride in strides: 363 | layers.append(block(self.in_planes, planes, stride, self.norm)) 364 | self.in_planes = planes * block.expansion 365 | return nn.Sequential(*layers) 366 | 367 | def forward(self, x): 368 | out = F.relu(self.bn1(self.conv1(x))) 369 | out = self.layer1(out) 370 | out = self.layer2(out) 371 | out = self.layer3(out) 372 | out = self.layer4(out) 373 | out = F.avg_pool2d(out, kernel_size=1, stride=1) # modification 374 | out = out.view(out.size(0), -1) 375 | out = self.classifier(out) 376 | return out 377 | 378 | 379 | def ResNet18BN_AP(channel, num_classes): 380 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 381 | 382 | def ResNet18_AP(channel, num_classes): 383 | return ResNet_AP(BasicBlock_AP, [2,2,2,2], channel=channel, num_classes=num_classes) 384 | 385 | 386 | ''' ResNet ''' 387 | 388 | class BasicBlock(nn.Module): 389 | expansion = 1 390 | 391 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 392 | super(BasicBlock, self).__init__() 393 | self.norm = norm 394 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 395 | self.bn1 = nn.GroupNorm(planes, planes, 
affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 396 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 397 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 398 | 399 | self.shortcut = nn.Sequential() 400 | if stride != 1 or in_planes != self.expansion*planes: 401 | self.shortcut = nn.Sequential( 402 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 403 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 404 | ) 405 | 406 | def forward(self, x): 407 | out = F.relu(self.bn1(self.conv1(x))) 408 | out = self.bn2(self.conv2(out)) 409 | out += self.shortcut(x) 410 | out = F.relu(out) 411 | return out 412 | 413 | 414 | class Bottleneck(nn.Module): 415 | expansion = 4 416 | 417 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 418 | super(Bottleneck, self).__init__() 419 | self.norm = norm 420 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 421 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 422 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 423 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 424 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 425 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 426 | 427 | self.shortcut = nn.Sequential() 428 | if stride != 1 or in_planes != self.expansion*planes: 429 | self.shortcut = nn.Sequential( 430 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 431 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 432 | ) 433 | 434 | def forward(self, x): 435 | out = F.relu(self.bn1(self.conv1(x))) 436 | out = F.relu(self.bn2(self.conv2(out))) 437 | out = self.bn3(self.conv3(out)) 438 | out += self.shortcut(x) 439 | out = F.relu(out) 440 | return out 441 | 442 | 443 | class ResNet(nn.Module): 444 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm', res=32): 445 | super(ResNet, self).__init__() 446 | self.in_planes = 64 447 | self.norm = norm 448 | if res==64: 449 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=2, padding=1, bias=False) 450 | else: 451 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 452 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 453 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 454 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 455 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 456 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 457 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 458 | self.classifier = nn.Linear(512*block.expansion, num_classes) 459 | 460 | def _make_layer(self, block, planes, num_blocks, stride): 461 | strides = [stride] + [1]*(num_blocks-1) 462 | layers = [] 463 | for stride in strides: 464 | layers.append(block(self.in_planes, planes, stride, self.norm)) 
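                # track the expanded channel count (planes * block.expansion) so the
                # next block in this layer receives matching input planes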
465 | self.in_planes = planes * block.expansion 466 | return nn.Sequential(*layers) 467 | 468 | def forward(self, x): 469 | out = F.relu(self.bn1(self.conv1(x))) 470 | out = self.layer1(out) 471 | out = self.layer2(out) 472 | out = self.layer3(out) 473 | out = self.layer4(out) 474 | out = F.avg_pool2d(out, 4) 475 | # out = self.avgpool(out) 476 | out = out.view(out.size(0), -1) 477 | out = self.classifier(out) 478 | return out 479 | 480 | 481 | class ResNetImageNet(nn.Module): 482 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 483 | super(ResNetImageNet, self).__init__() 484 | self.in_planes = 64 485 | self.norm = norm 486 | 487 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 488 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 489 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 490 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 491 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 492 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 493 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 494 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 495 | self.classifier = nn.Linear(512*block.expansion, num_classes) 496 | 497 | def _make_layer(self, block, planes, num_blocks, stride): 498 | strides = [stride] + [1]*(num_blocks-1) 499 | layers = [] 500 | for stride in strides: 501 | layers.append(block(self.in_planes, planes, stride, self.norm)) 502 | self.in_planes = planes * block.expansion 503 | return nn.Sequential(*layers) 504 | 505 | def forward(self, x): 506 | out = F.relu(self.bn1(self.conv1(x))) 507 | out = self.maxpool(out) 508 | out = self.layer1(out) 509 | out = self.layer2(out) 510 | out = self.layer3(out) 511 | out = self.layer4(out) 512 | # out = F.avg_pool2d(out, 4) 513 | # out = out.view(out.size(0), -1) 514 | out = self.avgpool(out) 515 | out = torch.flatten(out, 1) 516 | out = self.classifier(out) 517 | return out 518 | 519 | 520 | def ResNet18BN(channel, num_classes): 521 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 522 | def ResNet18BN_Tiny(channel, num_classes): 523 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm',res=64) 524 | 525 | def ResNet18(channel, num_classes): 526 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 527 | def ResNet18_Tiny(channel, num_classes): 528 | return ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes,res=64) 529 | 530 | def ResNet34(channel, num_classes): 531 | return ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 532 | 533 | def ResNet50(channel, num_classes): 534 | return ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 535 | 536 | def ResNet101(channel, num_classes): 537 | return ResNet(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 538 | 539 | def ResNet152(channel, num_classes): 540 | return ResNet(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes) 541 | 542 | def ResNet18ImageNet(channel, num_classes): 543 | return ResNetImageNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 544 | 545 | def ResNet6ImageNet(channel, num_classes): 546 | return ResNetImageNet(BasicBlock, [1,1,1,1], channel=channel, num_classes=num_classes) 547 | 
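# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not invoked by the training pipeline):
# the ConvNet hyper-parameters below mirror get_default_convnet_setting()
# used elsewhere in this repo (width 128, depth 3, ReLU, instance norm,
# average pooling); treat them as assumed defaults rather than requirements.
if __name__ == "__main__":
    x = torch.randn(4, 3, 32, 32)  # dummy CIFAR-10-sized batch

    convnet = ConvNet(channel=3, num_classes=10, net_width=128, net_depth=3,
                      net_act='relu', net_norm='instancenorm',
                      net_pooling='avgpooling', im_size=(32, 32))
    print('ConvNet logits:', convnet(x).shape)   # torch.Size([4, 10])

    resnet = ResNet18(channel=3, num_classes=10)  # instance-norm ResNet by default
    print('ResNet18 logits:', resnet(x).shape)    # torch.Size([4, 10])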
-------------------------------------------------------------------------------- /reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def _get_module_from_name(self, mn): 11 | if mn == '': 12 | return self 13 | m = self 14 | for p in mn.split('.'): 15 | m = getattr(m, p) 16 | return m 17 | 18 | def __init__(self, module): 19 | super(ReparamModule, self).__init__() 20 | self.module = module 21 | 22 | param_infos = [] # (module name/path, param name) 23 | shared_param_memo = {} 24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 25 | params = [] 26 | param_numels = [] 27 | param_shapes = [] 28 | for mn, m in self.named_modules(): 29 | for n, p in m.named_parameters(recurse=False): 30 | if p is not None: 31 | if p in shared_param_memo: 32 | shared_mn, shared_n = shared_param_memo[p] 33 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 | else: 35 | shared_param_memo[p] = (mn, n) 36 | param_infos.append((mn, n)) 37 | params.append(p.detach()) 38 | param_numels.append(p.numel()) 39 | param_shapes.append(p.size()) 40 | 41 | assert len(set(p.dtype for p in params)) <= 1, \ 42 | "expects all parameters in module to have same dtype" 43 | 44 | # store the info for unflatten 45 | self._param_infos = tuple(param_infos) 46 | self._shared_param_infos = tuple(shared_param_infos) 47 | self._param_numels = tuple(param_numels) 48 | self._param_shapes = tuple(param_shapes) 49 | 50 | # flatten 51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. 
just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstruct the split+view 125 | # graph 126 | for (mn, n), p in zip(self._param_infos, saved_views): 127 | setattr(self._get_module_from_name(mn), n, p) 128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 | 131 | @contextmanager 132 | def replaced_buffers(self, buffers): 133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 | setattr(self._get_module_from_name(mn), n, new_b) 135 | yield 136 | for mn, n, old_b in self._buffer_infos: 137 | setattr(self._get_module_from_name(mn), n, old_b) 138 | 139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 | with self.unflattened_param(flat_param): 141 | with self.replaced_buffers(buffers): 142 | return self.module(*inputs, **kwinputs) 143 | 144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 | with self.unflattened_param(flat_param): 146 | return self.module(*inputs, **kwinputs) 147 | 148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 | flat_param = torch.squeeze(flat_param) 150 | # print("PARAMS ON DEVICE: ", flat_param.get_device()) 151 | # print("DATA ON DEVICE: ", inputs[0].get_device()) 152 | # flat_param.to("cuda:{}".format(inputs[0].get_device())) 153 | # self.module.to("cuda:{}".format(inputs[0].get_device())) 154 | if flat_param is None: 155 | flat_param = self.flat_param 156 | if buffers is None: 157 | return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 | else: 159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /utils/cfg.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | def show_cfg(cfg): 5 | dump_cfg = CN() 6 | dump_cfg.EXPERIMENT = cfg.EXPERIMENT 7 | dump_cfg.DATASET = cfg.DATASET 8 | dump_cfg.DISTILLER = cfg.DISTILLER 9 | dump_cfg.SOLVER = cfg.SOLVER 10 | dump_cfg.LOG = cfg.LOG 11 | if cfg.DISTILLER.TYPE in cfg: 12 | dump_cfg.update({cfg.DISTILLER.TYPE: cfg.get(cfg.DISTILLER.TYPE)}) 13 | print(log_msg("CONFIG:\n{}".format(dump_cfg.dump()), "INFO")) 14 | 15 | 16 | CFG = CN() 17 | 18 | # Configuration Settings 19 | 20 | # dataset 21 | CFG.dataset = 'CIFAR10' 22 | 23 | # ImageNet subset. 
This only does anything when --dataset=ImageNet 24 | CFG.subset = 'imagenette' 25 | 26 | # model 27 | CFG.model = 'ConvNet' 28 | 29 | # image(s) per class 30 | CFG.ipc = 1 31 | 32 | # eval_mode, check utils.py for more info 33 | CFG.eval_mode = 'S' 34 | 35 | # how many networks to evaluate on 36 | CFG.num_eval = 5 37 | 38 | # how often to evaluate 39 | CFG.eval_it = 100 40 | 41 | # epochs to train a model with synthetic data 42 | CFG.epoch_eval_train = 1000 43 | 44 | # how many distillation steps to perform 45 | CFG.Iteration = 5000 46 | 47 | # Learning rates 48 | CFG.lr_img = 1000 # learning rate for updating synthetic images 49 | CFG.lr_teacher = 0.01 # initialization for synthetic learning rate 50 | CFG.lr_init = 0.01 # how to init lr (alpha) 51 | 52 | # Batch sizes 53 | CFG.batch_real = 256 # batch size for real data 54 | CFG.batch_syn = None # should only use this if you run out of VRAM 55 | CFG.batch_train = 256 # batch size for training networks 56 | 57 | # Initialization for synthetic images 58 | CFG.pix_init = 'samples_predicted_correctly' # initialize synthetic images from random noise or real images 59 | 60 | # Differentiable Siamese Augmentation 61 | CFG.dsa = True # whether to use differentiable Siamese augmentation 62 | CFG.dsa_strategy = 'color_crop_cutout_flip_scale_rotate' # differentiable Siamese augmentation strategy 63 | 64 | # Paths 65 | CFG.data_path = '../dataset/' # dataset path 66 | CFG.buffer_path = '../buffer_storage/' # buffer path 67 | 68 | # Expert epochs and synthetic data steps 69 | CFG.expert_epochs = 2 # how many expert epochs the target params are 70 | CFG.syn_steps = 80 # how many steps to take on synthetic data 71 | 72 | # Start epochs 73 | CFG.max_start_epoch = 25 # max epoch we can start at 74 | CFG.min_start_epoch = 0 # min epoch we can start at 75 | 76 | # ZCA whitening 77 | CFG.zca = True # do ZCA whitening (use True if action='store_true') 78 | 79 | # Load all expert trajectories into RAM 80 | CFG.load_all = False # only use if you can fit all expert trajectories into RAM (use True if action='store_true') 81 | 82 | # Turn off differential augmentation during distillation 83 | CFG.no_aug = False # this turns off diff aug during distillation 84 | 85 | # Distill textures instead 86 | CFG.texture = False # will distill textures instead (use True if action='store_true') 87 | CFG.canvas_size = 2 # size of synthetic canvas 88 | CFG.canvas_samples = 1 # number of canvas samples per iteration 89 | 90 | # Number of expert files to read (leave as None unless doing ablations) 91 | CFG.max_files = None 92 | 93 | # Number of experts to read per file (leave as None unless doing ablations) 94 | CFG.max_experts = None 95 | 96 | # Force saving images for 50ipc 97 | CFG.force_save = False # this will save images for 50ipc (use True if action='store_true') 98 | CFG.ema_decay = 0.999 99 | 100 | 101 | # Learning rate for 'y' 102 | CFG.lr_y = 2. 
103 | # Momentum for 'y' 104 | CFG.Momentum_y = 0.9 105 | 106 | # WanDB Project Name 107 | CFG.project = 'TEST' 108 | 109 | # Threshold 110 | CFG.threshold = 1.0 111 | 112 | # Record loss 113 | CFG.record_loss = False 114 | 115 | # Sequential Generation 116 | CFG.Sequential_Generation = True 117 | CFG.expansion_end_epoch = 3000 118 | CFG.current_max_start_epoch = 20 119 | 120 | 121 | # Skip first evaluation 122 | CFG.skip_first_eva = False # If skip first eva 123 | 124 | # Parallel evaluation 125 | CFG.parall_eva = False # If parallel eva 126 | 127 | CFG.lr_lr = 0.00001 128 | 129 | CFG.res = 32 130 | 131 | CFG.device = [0] 132 | 133 | CFG.Initialize_Label_With_Another_Model = False 134 | CFG.Initialize_Label_Model = "" 135 | CFG.Initialize_Label_Model_Dir = "" 136 | CFG.Label_Model_Timestamp = -1 -------------------------------------------------------------------------------- /utils/step_lr.py: -------------------------------------------------------------------------------- 1 | class StepLR: 2 | def __init__(self, optimizer, learning_rate: float, total_epochs: int): 3 | self.optimizer = optimizer 4 | self.total_epochs = total_epochs 5 | self.base = learning_rate 6 | 7 | def __call__(self, epoch): 8 | if epoch < self.total_epochs * 3/10: 9 | lr = self.base 10 | elif epoch < self.total_epochs * 6/10: 11 | lr = self.base * 0.2 12 | elif epoch < self.total_epochs * 8/10: 13 | lr = self.base * 0.2 ** 2 14 | else: 15 | lr = self.base * 0.2 ** 3 16 | 17 | for param_group in self.optimizer.param_groups: 18 | param_group["lr"] = lr 19 | 20 | def lr(self) -> float: 21 | return self.optimizer.param_groups[0]["lr"] 22 | -------------------------------------------------------------------------------- /utils/utils_baseline_backup.py: -------------------------------------------------------------------------------- 1 | # adapted from 2 | # https://github.com/VICO-UoE/DatasetCondensation 3 | 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import os 10 | import kornia as K 11 | import tqdm 12 | from torch.utils.data import Dataset 13 | from torchvision import datasets, transforms 14 | from scipy.ndimage.interpolation import rotate as scipyrotate 15 | from networks import MLP, ConvNet, LeNet, AlexNet, VGG11BN, VGG11, ResNet18, ResNet18BN_AP, ResNet18_AP 16 | 17 | class Config: 18 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701] 19 | 20 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"] 21 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229] 22 | 23 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"] 24 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287] 25 | 26 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"] 27 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89] 28 | 29 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"] 30 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948] 31 | 32 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"] 33 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11] 34 | 35 | dict = { 36 | "imagenette" : imagenette, 37 | "imagewoof" 
: imagewoof, 38 | "imagefruit": imagefruit, 39 | "imageyellow": imageyellow, 40 | "imagemeow": imagemeow, 41 | "imagesquawk": imagesquawk, 42 | } 43 | 44 | config = Config() 45 | 46 | def get_dataset(dataset, data_path, batch_size=1, subset="imagenette", args=None): 47 | 48 | class_map = None 49 | loader_train_dict = None 50 | class_map_inv = None 51 | 52 | if dataset == 'CIFAR10': 53 | channel = 3 54 | im_size = (32, 32) 55 | num_classes = 10 56 | mean = [0.4914, 0.4822, 0.4465] 57 | std = [0.2023, 0.1994, 0.2010] 58 | if args.zca: 59 | transform = transforms.Compose([transforms.ToTensor()]) 60 | else: 61 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 62 | dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation 63 | dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform) 64 | class_names = dst_train.classes 65 | class_map = {x:x for x in range(num_classes)} 66 | 67 | 68 | elif dataset == 'Tiny': 69 | channel = 3 70 | im_size = (64, 64) 71 | num_classes = 200 72 | mean = [0.485, 0.456, 0.406] 73 | std = [0.229, 0.224, 0.225] 74 | if args.zca: 75 | transform = transforms.Compose([transforms.ToTensor()]) 76 | else: 77 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 78 | dst_train = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform) # no augmentation 79 | dst_test = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform) 80 | class_names = dst_train.classes 81 | class_map = {x:x for x in range(num_classes)} 82 | 83 | 84 | elif dataset == 'ImageNet': 85 | channel = 3 86 | im_size = (128, 128) 87 | num_classes = 10 88 | 89 | config.img_net_classes = config.dict[subset] 90 | 91 | mean = [0.485, 0.456, 0.406] 92 | std = [0.229, 0.224, 0.225] 93 | if args.zca: 94 | transform = transforms.Compose([transforms.ToTensor(), 95 | transforms.Resize(im_size), 96 | transforms.CenterCrop(im_size)]) 97 | else: 98 | transform = transforms.Compose([transforms.ToTensor(), 99 | transforms.Normalize(mean=mean, std=std), 100 | transforms.Resize(im_size), 101 | transforms.CenterCrop(im_size)]) 102 | 103 | dst_train = datasets.ImageNet(data_path, split="train", transform=transform) # no augmentation 104 | dst_train_dict = {c : torch.utils.data.Subset(dst_train, np.squeeze(np.argwhere(np.equal(dst_train.targets, config.img_net_classes[c])))) for c in range(len(config.img_net_classes))} 105 | dst_train = torch.utils.data.Subset(dst_train, np.squeeze(np.argwhere(np.isin(dst_train.targets, config.img_net_classes)))) 106 | loader_train_dict = {c : torch.utils.data.DataLoader(dst_train_dict[c], batch_size=batch_size, shuffle=True, num_workers=16) for c in range(len(config.img_net_classes))} 107 | dst_test = datasets.ImageNet(data_path, split="val", transform=transform) 108 | dst_test = torch.utils.data.Subset(dst_test, np.squeeze(np.argwhere(np.isin(dst_test.targets, config.img_net_classes)))) 109 | for c in range(len(config.img_net_classes)): 110 | dst_test.dataset.targets[dst_test.dataset.targets == config.img_net_classes[c]] = c 111 | dst_train.dataset.targets[dst_train.dataset.targets == config.img_net_classes[c]] = c 112 | print(dst_test.dataset) 113 | class_map = {x: i for i, x in enumerate(config.img_net_classes)} 114 | class_map_inv = {i: x for i, x in enumerate(config.img_net_classes)} 115 | class_names = None 116 | 117 | 118 | elif dataset.startswith('CIFAR100'): 119 | channel = 3 120 | 
im_size = (32, 32) 121 | num_classes = 100 122 | mean = [0.4914, 0.4822, 0.4465] 123 | std = [0.2023, 0.1994, 0.2010] 124 | 125 | if args.zca: 126 | transform = transforms.Compose([transforms.ToTensor()]) 127 | else: 128 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 129 | dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform) # no augmentation 130 | dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform) 131 | class_names = dst_train.classes 132 | class_map = {x: x for x in range(num_classes)} 133 | 134 | else: 135 | exit('unknown dataset: %s'%dataset) 136 | 137 | if args.zca: 138 | images = [] 139 | labels = [] 140 | print("Train ZCA") 141 | for i in tqdm.tqdm(range(len(dst_train))): 142 | im, lab = dst_train[i] 143 | images.append(im) 144 | labels.append(lab) 145 | images = torch.stack(images, dim=0).to(args.device) 146 | labels = torch.tensor(labels, dtype=torch.long, device="cpu") 147 | zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True) 148 | zca.fit(images) 149 | zca_images = zca(images).to("cpu") 150 | dst_train = TensorDataset(zca_images, labels) 151 | 152 | images = [] 153 | labels = [] 154 | print("Test ZCA") 155 | for i in tqdm.tqdm(range(len(dst_test))): 156 | im, lab = dst_test[i] 157 | images.append(im) 158 | labels.append(lab) 159 | images = torch.stack(images, dim=0).to(args.device) 160 | labels = torch.tensor(labels, dtype=torch.long, device="cpu") 161 | 162 | zca_images = zca(images).to("cpu") 163 | dst_test = TensorDataset(zca_images, labels) 164 | 165 | args.zca_trans = zca 166 | 167 | 168 | testloader = torch.utils.data.DataLoader(dst_test, batch_size=128, shuffle=False, num_workers=2) 169 | 170 | 171 | return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv 172 | 173 | 174 | 175 | class TensorDataset(Dataset): 176 | def __init__(self, images, labels): # images: n x c x h x w tensor 177 | self.images = images.detach().float() 178 | self.labels = labels.detach() 179 | 180 | def __getitem__(self, index): 181 | return self.images[index], self.labels[index] 182 | 183 | def __len__(self): 184 | return self.images.shape[0] 185 | 186 | 187 | 188 | def get_default_convnet_setting(): 189 | net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling' 190 | return net_width, net_depth, net_act, net_norm, net_pooling 191 | 192 | 193 | 194 | def get_network(model, channel, num_classes, im_size=(32, 32), dist=True): 195 | torch.random.manual_seed(int(time.time() * 1000) % 100000) 196 | net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting() 197 | 198 | if model == 'MLP': 199 | net = MLP(channel=channel, num_classes=num_classes) 200 | elif model == 'ConvNet': 201 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 202 | elif model == 'LeNet': 203 | net = LeNet(channel=channel, num_classes=num_classes) 204 | elif model == 'AlexNet': 205 | net = AlexNet(channel=channel, num_classes=num_classes) 206 | elif model == 'VGG11': 207 | net = VGG11( channel=channel, num_classes=num_classes) 208 | elif model == 'VGG11BN': 209 | net = VGG11BN(channel=channel, num_classes=num_classes) 210 | elif model == 'ResNet18': 211 | net = ResNet18(channel=channel, num_classes=num_classes) 212 | elif model == 
'ResNet18BN_AP': 213 | net = ResNet18BN_AP(channel=channel, num_classes=num_classes) 214 | elif model == 'ResNet18_AP': 215 | net = ResNet18_AP(channel=channel, num_classes=num_classes) 216 | 217 | elif model == 'ConvNetD1': 218 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=1, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 219 | elif model == 'ConvNetD2': 220 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=2, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 221 | elif model == 'ConvNetD3': 222 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=3, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 223 | elif model == 'ConvNetD4': 224 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=4, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 225 | elif model == 'ConvNetD5': 226 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=5, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 227 | elif model == 'ConvNetD6': 228 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=6, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 229 | elif model == 'ConvNetD7': 230 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=7, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 231 | elif model == 'ConvNetD8': 232 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=8, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size) 233 | 234 | 235 | elif model == 'ConvNetW32': 236 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=32, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 237 | elif model == 'ConvNetW64': 238 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=64, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 239 | elif model == 'ConvNetW128': 240 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 241 | elif model == 'ConvNetW256': 242 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=256, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 243 | elif model == 'ConvNetW512': 244 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=512, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 245 | elif model == 'ConvNetW1024': 246 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling) 247 | 248 | elif model == "ConvNetKIP": 249 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=1024, net_depth=net_depth, net_act=net_act, 250 | net_norm="none", net_pooling=net_pooling) 251 | 252 | elif model == 'ConvNetAS': 253 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='sigmoid', net_norm=net_norm, net_pooling=net_pooling) 254 | elif model == 'ConvNetAR': 255 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, 
net_depth=net_depth, net_act='relu', net_norm=net_norm, net_pooling=net_pooling) 256 | elif model == 'ConvNetAL': 257 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act='leakyrelu', net_norm=net_norm, net_pooling=net_pooling) 258 | 259 | elif model == 'ConvNetNN': 260 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='none', net_pooling=net_pooling) 261 | elif model == 'ConvNetBN': 262 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='batchnorm', net_pooling=net_pooling) 263 | elif model == 'ConvNetLN': 264 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='layernorm', net_pooling=net_pooling) 265 | elif model == 'ConvNetIN': 266 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='instancenorm', net_pooling=net_pooling) 267 | elif model == 'ConvNetGN': 268 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm='groupnorm', net_pooling=net_pooling) 269 | 270 | elif model == 'ConvNetNP': 271 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='none') 272 | elif model == 'ConvNetMP': 273 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='maxpooling') 274 | elif model == 'ConvNetAP': 275 | net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling='avgpooling') 276 | 277 | 278 | else: 279 | net = None 280 | exit('DC error: unknown model') 281 | 282 | if dist: 283 | gpu_num = torch.cuda.device_count() 284 | if gpu_num>0: 285 | device = 'cuda' 286 | if gpu_num>1: 287 | net = nn.DataParallel(net) 288 | else: 289 | device = 'cpu' 290 | net = net.to(device) 291 | 292 | return net 293 | 294 | 295 | 296 | def get_time(): 297 | return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())) 298 | 299 | 300 | def epoch(mode, dataloader, net, optimizer, criterion, args, aug, texture=False): 301 | loss_avg, acc_avg, num_exp = 0, 0, 0 302 | net = net.to(args.device) 303 | 304 | if args.dataset == "ImageNet": 305 | class_map = {x: i for i, x in enumerate(config.img_net_classes)} 306 | 307 | if mode == 'train': 308 | net.train() 309 | else: 310 | net.eval() 311 | 312 | for i_batch, datum in enumerate(dataloader): 313 | img = datum[0].float().to(args.device) 314 | lab = datum[1].long().to(args.device) 315 | 316 | if mode == "train" and texture: 317 | img = torch.cat([torch.stack([torch.roll(im, (torch.randint(args.im_size[0]*args.canvas_size, (1,)), torch.randint(args.im_size[0]*args.canvas_size, (1,))), (1,2))[:,:args.im_size[0],:args.im_size[1]] for im in img]) for _ in range(args.canvas_samples)]) 318 | lab = torch.cat([lab for _ in range(args.canvas_samples)]) 319 | 320 | if aug: 321 | if args.dsa: 322 | img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param) 323 | else: 324 | img = augment(img, args.dc_aug_param, device=args.device) 325 | 326 | if args.dataset == "ImageNet" and mode != "train": 327 | lab = torch.tensor([class_map[x.item()] for x in lab]).to(args.device) 328 | 329 | n_b = 
lab.shape[0] 330 | 331 | output = net(img) 332 | loss = criterion(output, lab) 333 | 334 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy())) 335 | 336 | loss_avg += loss.item()*n_b 337 | acc_avg += acc 338 | num_exp += n_b 339 | 340 | if mode == 'train': 341 | optimizer.zero_grad() 342 | loss.backward() 343 | optimizer.step() 344 | 345 | loss_avg /= num_exp 346 | acc_avg /= num_exp 347 | 348 | return loss_avg, acc_avg 349 | 350 | 351 | 352 | def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args, return_loss=False, texture=False): 353 | net = net.to(args.device) 354 | images_train = images_train.to(args.device) 355 | labels_train = labels_train.to(args.device) 356 | lr = float(args.lr_net) 357 | Epoch = int(args.epoch_eval_train) 358 | lr_schedule = [Epoch//2+1] 359 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 360 | 361 | criterion = nn.CrossEntropyLoss().to(args.device) 362 | 363 | dst_train = TensorDataset(images_train, labels_train) 364 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0) 365 | 366 | start = time.time() 367 | acc_train_list = [] 368 | loss_train_list = [] 369 | 370 | for ep in tqdm.tqdm(range(Epoch+1)): 371 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=True, texture=texture) 372 | acc_train_list.append(acc_train) 373 | loss_train_list.append(loss_train) 374 | if ep == Epoch: 375 | with torch.no_grad(): 376 | loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False) 377 | if ep in lr_schedule: 378 | lr *= 0.1 379 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 380 | 381 | 382 | time_train = time.time() - start 383 | 384 | print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test)) 385 | 386 | if return_loss: 387 | return net, acc_train_list, acc_test, loss_train_list, loss_test 388 | else: 389 | return net, acc_train_list, acc_test 390 | 391 | 392 | def augment(images, dc_aug_param, device): 393 | # This can be sped up in the future. 
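    # NOTE: applies exactly one randomly chosen transform from
    # dc_aug_param['strategy'] (crop / scale / rotate / noise) to each image,
    # modifying `images` in place before returning the batch.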
394 | 395 | if dc_aug_param != None and dc_aug_param['strategy'] != 'none': 396 | scale = dc_aug_param['scale'] 397 | crop = dc_aug_param['crop'] 398 | rotate = dc_aug_param['rotate'] 399 | noise = dc_aug_param['noise'] 400 | strategy = dc_aug_param['strategy'] 401 | 402 | shape = images.shape 403 | mean = [] 404 | for c in range(shape[1]): 405 | mean.append(float(torch.mean(images[:,c]))) 406 | 407 | def cropfun(i): 408 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device) 409 | for c in range(shape[1]): 410 | im_[c] = mean[c] 411 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i] 412 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0] 413 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]] 414 | 415 | def scalefun(i): 416 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 417 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 418 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0] 419 | mhw = max(h, w, shape[2], shape[3]) 420 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device) 421 | r = int((mhw - h) / 2) 422 | c = int((mhw - w) / 2) 423 | im_[:, r:r + h, c:c + w] = tmp 424 | r = int((mhw - shape[2]) / 2) 425 | c = int((mhw - shape[3]) / 2) 426 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]] 427 | 428 | def rotatefun(i): 429 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean)) 430 | r = int((im_.shape[-2] - shape[-2]) / 2) 431 | c = int((im_.shape[-1] - shape[-1]) / 2) 432 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device) 433 | 434 | def noisefun(i): 435 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device) 436 | 437 | 438 | augs = strategy.split('_') 439 | 440 | for i in range(shape[0]): 441 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation 442 | if choice == 'crop': 443 | cropfun(i) 444 | elif choice == 'scale': 445 | scalefun(i) 446 | elif choice == 'rotate': 447 | rotatefun(i) 448 | elif choice == 'noise': 449 | noisefun(i) 450 | 451 | return images 452 | 453 | 454 | 455 | def get_daparam(dataset, model, model_eval, ipc): 456 | # We find that augmentation doesn't always benefit the performance. 457 | # So we do augmentation for some of the settings. 458 | 459 | dc_aug_param = dict() 460 | dc_aug_param['crop'] = 4 461 | dc_aug_param['scale'] = 0.2 462 | dc_aug_param['rotate'] = 45 463 | dc_aug_param['noise'] = 0.001 464 | dc_aug_param['strategy'] = 'none' 465 | 466 | if dataset == 'MNIST': 467 | dc_aug_param['strategy'] = 'crop_scale_rotate' 468 | 469 | if model_eval in ['ConvNetBN']: # Data augmentation makes model training with Batch Norm layer easier. 
470 | dc_aug_param['strategy'] = 'crop_noise' 471 | 472 | return dc_aug_param 473 | 474 | 475 | def get_eval_pool(eval_mode, model, model_eval): 476 | if eval_mode == 'M': # multiple architectures 477 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18', 'LeNet'] 478 | model_eval_pool = ['ConvNet', 'AlexNet', 'VGG11', 'ResNet18_AP', 'ResNet18'] 479 | # model_eval_pool = ['MLP', 'ConvNet', 'AlexNet', 'VGG11', 'ResNet18'] 480 | elif eval_mode == 'W': # ablation study on network width 481 | model_eval_pool = ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256'] 482 | elif eval_mode == 'D': # ablation study on network depth 483 | model_eval_pool = ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4'] 484 | elif eval_mode == 'A': # ablation study on network activation function 485 | model_eval_pool = ['ConvNetAS', 'ConvNetAR', 'ConvNetAL'] 486 | elif eval_mode == 'P': # ablation study on network pooling layer 487 | model_eval_pool = ['ConvNetNP', 'ConvNetMP', 'ConvNetAP'] 488 | elif eval_mode == 'N': # ablation study on network normalization layer 489 | model_eval_pool = ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN'] 490 | elif eval_mode == 'S': # itself 491 | model_eval_pool = [model[:model.index('BN')]] if 'BN' in model else [model] 492 | elif eval_mode == 'C': 493 | model_eval_pool = [model, 'ConvNet'] 494 | else: 495 | model_eval_pool = [model_eval] 496 | return model_eval_pool 497 | 498 | 499 | class ParamDiffAug(): 500 | def __init__(self): 501 | self.aug_mode = 'S' #'multiple or single' 502 | self.prob_flip = 0.5 503 | self.ratio_scale = 1.2 504 | self.ratio_rotate = 15.0 505 | self.ratio_crop_pad = 0.125 506 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 507 | self.ratio_noise = 0.05 508 | self.brightness = 1.0 509 | self.saturation = 2.0 510 | self.contrast = 0.5 511 | 512 | 513 | def set_seed_DiffAug(param): 514 | if param.latestseed == -1: 515 | return 516 | else: 517 | torch.random.manual_seed(param.latestseed) 518 | param.latestseed += 1 519 | 520 | 521 | def DiffAugment(x, strategy='', seed = -1, param = None): 522 | if seed == -1: 523 | param.batchmode = False 524 | else: 525 | param.batchmode = True 526 | 527 | param.latestseed = seed 528 | 529 | if strategy == 'None' or strategy == 'none': 530 | return x 531 | 532 | if strategy: 533 | if param.aug_mode == 'M': # original 534 | for p in strategy.split('_'): 535 | for f in AUGMENT_FNS[p]: 536 | x = f(x, param) 537 | elif param.aug_mode == 'S': 538 | pbties = strategy.split('_') 539 | set_seed_DiffAug(param) 540 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 541 | for f in AUGMENT_FNS[p]: 542 | x = f(x, param) 543 | else: 544 | exit('Error ZH: unknown augmentation mode.') 545 | x = x.contiguous() 546 | return x 547 | 548 | 549 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans. 
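# Each rand_* function below draws its random parameters through
# set_seed_DiffAug(param): when DiffAugment() is given a seed other than -1,
# torch is re-seeded before every draw so the sampled parameters are
# reproducible, and param.batchmode is set so that the first sample's
# parameters are copied across the batch and every image receives the
# identical transform.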
550 | def rand_scale(x, param): 551 | # x>1, max scale 552 | # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times 553 | ratio = param.ratio_scale 554 | set_seed_DiffAug(param) 555 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 556 | set_seed_DiffAug(param) 557 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 558 | theta = [[[sx[i], 0, 0], 559 | [0, sy[i], 0],] for i in range(x.shape[0])] 560 | theta = torch.tensor(theta, dtype=torch.float) 561 | if param.batchmode: # batch-wise: 562 | theta[:] = theta[0] 563 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 564 | x = F.grid_sample(x, grid, align_corners=True) 565 | return x 566 | 567 | 568 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree 569 | ratio = param.ratio_rotate 570 | set_seed_DiffAug(param) 571 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 572 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 573 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 574 | theta = torch.tensor(theta, dtype=torch.float) 575 | if param.batchmode: # batch-wise: 576 | theta[:] = theta[0] 577 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 578 | x = F.grid_sample(x, grid, align_corners=True) 579 | return x 580 | 581 | 582 | def rand_flip(x, param): 583 | prob = param.prob_flip 584 | set_seed_DiffAug(param) 585 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 586 | if param.batchmode: # batch-wise: 587 | randf[:] = randf[0] 588 | return torch.where(randf < prob, x.flip(3), x) 589 | 590 | 591 | def rand_brightness(x, param): 592 | ratio = param.brightness 593 | set_seed_DiffAug(param) 594 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 595 | if param.batchmode: # batch-wise: 596 | randb[:] = randb[0] 597 | x = x + (randb - 0.5)*ratio 598 | return x 599 | 600 | 601 | def rand_saturation(x, param): 602 | ratio = param.saturation 603 | x_mean = x.mean(dim=1, keepdim=True) 604 | set_seed_DiffAug(param) 605 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 606 | if param.batchmode: # batch-wise: 607 | rands[:] = rands[0] 608 | x = (x - x_mean) * (rands * ratio) + x_mean 609 | return x 610 | 611 | 612 | def rand_contrast(x, param): 613 | ratio = param.contrast 614 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 615 | set_seed_DiffAug(param) 616 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 617 | if param.batchmode: # batch-wise: 618 | randc[:] = randc[0] 619 | x = (x - x_mean) * (randc + ratio) + x_mean 620 | return x 621 | 622 | 623 | def rand_crop(x, param): 624 | # The image is padded on its surrounding and then cropped. 
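# Concretely: the batch is zero-padded by one pixel, a per-sample offset of up
# to ratio_crop_pad of the spatial size is drawn, and pixels are re-indexed
# through a meshgrid; regions shifted in from outside the original canvas are
# filled with zeros.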
625 | ratio = param.ratio_crop_pad 626 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 627 | set_seed_DiffAug(param) 628 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 629 | set_seed_DiffAug(param) 630 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 631 | if param.batchmode: # batch-wise: 632 | translation_x[:] = translation_x[0] 633 | translation_y[:] = translation_y[0] 634 | grid_batch, grid_x, grid_y = torch.meshgrid( 635 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 636 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 637 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 638 | ) 639 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 640 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 641 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 642 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 643 | return x 644 | 645 | 646 | def rand_cutout(x, param): 647 | ratio = param.ratio_cutout 648 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 649 | set_seed_DiffAug(param) 650 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 651 | set_seed_DiffAug(param) 652 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 653 | if param.batchmode: # batch-wise: 654 | offset_x[:] = offset_x[0] 655 | offset_y[:] = offset_y[0] 656 | grid_batch, grid_x, grid_y = torch.meshgrid( 657 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 658 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 659 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 660 | ) 661 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 662 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 663 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 664 | mask[grid_batch, grid_x, grid_y] = 0 665 | x = x * mask.unsqueeze(1) 666 | return x 667 | 668 | 669 | AUGMENT_FNS = { 670 | 'color': [rand_brightness, rand_saturation, rand_contrast], 671 | 'crop': [rand_crop], 672 | 'cutout': [rand_cutout], 673 | 'flip': [rand_flip], 674 | 'scale': [rand_scale], 675 | 'rotate': [rand_rotate], 676 | } 677 | --------------------------------------------------------------------------------
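For quick reference, here is a minimal usage sketch (not part of the repository) of the differentiable-augmentation helpers defined above. It assumes they are importable from one of the utility modules in `utils/` (e.g. `utils.utils_gsam`); adjust the import to wherever this file lives in your checkout. The call mirrors the one made inside `epoch()` when `args.dsa` is enabled.
```
import torch
from utils.utils_gsam import DiffAugment, ParamDiffAug  # assumed module path

# Dummy CIFAR-sized batch; strategy tokens must be keys of AUGMENT_FNS:
# 'color', 'crop', 'cutout', 'flip', 'scale', 'rotate'.
images = torch.rand(8, 3, 32, 32)
param = ParamDiffAug()                     # default DSA hyper-parameters
augmented = DiffAugment(images, 'color_crop_flip', seed=-1, param=param)
assert augmented.shape == images.shape     # every op preserves the input shape
```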